
Modular RL building blocks in JAX

RL-BLOX

This project contains modular implementations of various model-free and model-based RL algorithms. It provides both deep neural network-based and tabular representations of Q-values, policies, etc., which can be used interchangeably. The goal of this project is for the authors to learn by reimplementing RL algorithms, and eventually to provide an algorithmic toolbox for research purposes.

[!CAUTION] This library is still experimental and under development: you may encounter bugs or changing interfaces. If you run into bugs or other issues, please let us know via the issue tracker. If you are an RL developer and want to collaborate, feel free to contact us.

Design Principles

This project follows three design principles:

  1. Algorithms are functions!
  2. Algorithms are implemented in single files.
  3. Policies and value functions are data containers.
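
A minimal pure-Python sketch of what these principles mean in practice (the names below are illustrative, not the actual rl-blox API): the Q-values live in a plain data container, and the algorithm step is just a function operating on it.

```python
from dataclasses import dataclass


@dataclass
class TabularQ:
    """Plain data container for tabular Q-values (principle 3)."""

    table: list  # table[state][action] -> value


def q_learning_step(q, s, a, r, s_next, alpha=0.1, gamma=0.99):
    """One Q-learning update as a plain function (principle 1)."""
    td_target = r + gamma * max(q.table[s_next])
    q.table[s][a] += alpha * (td_target - q.table[s][a])
    return q


q = TabularQ(table=[[0.0, 0.0], [0.0, 0.0]])
q = q_learning_step(q, s=0, a=1, r=1.0, s_next=1)
print(q.table[0][1])  # 0.1
```

Because the container is plain data, a neural-network representation can replace the table without changing how the algorithm function is called.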

Dependencies

  1. Our environment interface is Gymnasium.
  2. We use JAX for everything.
  3. We use Chex to write reliable code.
  4. For optimization algorithms we use Optax.
  5. For probability distributions we use TensorFlow Probability.
  6. For all neural networks we use Flax NNX.
  7. To save checkpoints we use Orbax.

Installation

Install via PyPI

The easiest way to install is via PyPI:

pip install rl-blox

Install from source

Alternatively, for example if you want to develop extensions for the library, you can install rl-blox from source:

git clone git@github.com:mlaux1/rl-blox.git

After cloning the repository, it is recommended to install the library in editable mode:

pip install -e .

Optional dependencies

To run the provided examples, use pip install 'rl-blox[examples]'.

To install development dependencies, use pip install 'rl-blox[dev]'.

To enable logging with aim, use pip install 'rl-blox[logging]'.

You can install all optional dependencies (except logging) with pip install 'rl-blox[all]'.

Algorithm Implementations

We currently provide implementations of the following algorithms:

| Algorithm               | Original Paper |
|-------------------------|----------------|
| Monte Carlo             | link           |
| Q-learning              | link           |
| SARSA                   | link           |
| REINFORCE               | link           |
| Actor-Critic            | link           |
| Dyna-Q                  | link           |
| CMA-ES                  | link           |
| Double Q-learning       | link           |
| DQN                     | link           |
| Nature DQN              | link           |
| DDQN                    | link           |
| DDPG                    | link           |
| PPO                     | link           |
| TD3                     | link           |
| SAC                     | link           |
| PETS                    | link           |
| LAP                     | link           |
| TD7                     | link           |
| MR.Q                    | link           |
| Active Task Scheduling  | link           |
| SMT                     | link           |
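
As a quick illustration of how two of the tabular methods above differ (a pure-Python sketch, not the rl-blox implementation; all names are illustrative): Q-learning bootstraps its target from the greedy action in the next state, while SARSA bootstraps from the action that was actually taken.

```python
# Q stored as a dict mapping (state, action) -> value; two actions: 0 and 1.
ACTIONS = (0, 1)


def q_learning_target(Q, r, s_next, gamma=0.9):
    # Off-policy target: greedy over the next state's actions.
    return r + gamma * max(Q[(s_next, b)] for b in ACTIONS)


def sarsa_target(Q, r, s_next, a_next, gamma=0.9):
    # On-policy target: the action actually selected in s_next.
    return r + gamma * Q[(s_next, a_next)]


Q = {(s, a): 0.0 for s in (0, 1) for a in ACTIONS}
Q[(1, 1)] = 2.0  # one promising next-state action

print(q_learning_target(Q, r=1.0, s_next=1))       # 1 + 0.9 * 2.0 = 2.8
print(sarsa_target(Q, r=1.0, s_next=1, a_next=0))  # 1 + 0.9 * 0.0 = 1.0
```

The same targets with the same Q-values disagree whenever the behavior policy explores, which is exactly the off-policy/on-policy distinction.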

Getting Started

RL-BLOX relies on Gymnasium's environment interface. The following example trains a policy with the SAC algorithm.

import gymnasium as gym

from rl_blox.algorithm.sac import create_sac_state, train_sac
from rl_blox.logging.checkpointer import OrbaxCheckpointer
from rl_blox.logging.logger import AIMLogger, LoggerList

env_name = "Pendulum-v1"
env = gym.make(env_name)
seed = 1
verbose = 1
env = gym.wrappers.RecordEpisodeStatistics(env)

hparams_models = dict(
    policy_hidden_nodes=[128, 128],
    policy_learning_rate=3e-4,
    q_hidden_nodes=[512, 512],
    q_learning_rate=1e-3,
    seed=seed,
)
hparams_algorithm = dict(
    total_timesteps=11_000,
    buffer_size=11_000,
    gamma=0.99,
    learning_starts=5_000,
)

if verbose:
    print(
        "This example uses the AIM logger. You will not see any output on "
        "stdout. Run 'aim up' to analyze the progress."
    )
checkpointer = OrbaxCheckpointer("/tmp/rl-blox/sac_example/", verbose=verbose)
logger = LoggerList([
    AIMLogger(),
    # uncomment to store checkpoints
    # checkpointer,
])
logger.define_experiment(
    env_name=env_name,
    algorithm_name="SAC",
    hparams=hparams_models | hparams_algorithm,
)
logger.define_checkpoint_frequency("policy", 1_000)

sac_state = create_sac_state(env, **hparams_models)
sac_result = train_sac(
    env,
    sac_state.policy,
    sac_state.policy_optimizer,
    sac_state.q,
    sac_state.q_optimizer,
    logger=logger,
    **hparams_algorithm,
)
env.close()
policy = sac_result.policy

# Do something with the trained policy...

API Documentation

You can build the Sphinx documentation with

pip install -e '.[doc]'
cd doc
make html

The HTML documentation will be available under doc/build/html/index.html.

Contributing

If you wish to report bugs, please use the issue tracker. If you would like to contribute to RL-BLOX, open an issue or a pull request. The target branch for pull requests is the development branch, which is merged into the main branch for new releases. If you have questions about the software, please ask them in the discussion section.

The recommended workflow to add a new feature, add documentation, or fix a bug is the following:

  • Push your changes to a branch (e.g. feature/x, doc/y, or fix/z) of your fork of the RL-BLOX repository.
  • Open a pull request to the main branch.

It is forbidden to directly push to the main branch.

Please also check out our contribution guide!

Testing

Run the tests with

pip install -e '.[dev]'
pytest

Releases

Semantic Versioning

Semantic versioning must be used:

  • The major version is incremented when the API changes in a backwards-incompatible way.
  • The minor version is incremented when new functionality is added in a backwards-compatible manner.
  • The patch version is incremented for bugfixes, documentation changes, etc.
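
The rule above can be made concrete with a small, hypothetical helper (not part of rl-blox) that computes the next version for each kind of change:

```python
def bump_version(version: str, change: str) -> str:
    """Bump a 'major.minor.patch' string per semantic versioning.

    change: 'breaking' -> major, 'feature' -> minor, anything else -> patch.
    """
    major, minor, patch = map(int, version.split("."))
    if change == "breaking":  # backwards-incompatible API change
        return f"{major + 1}.0.0"
    if change == "feature":   # backwards-compatible new functionality
        return f"{major}.{minor + 1}.0"
    return f"{major}.{minor}.{patch + 1}"  # bugfix, documentation, etc.


print(bump_version("0.5.7", "feature"))  # 0.6.0
```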

Funding

This library is currently developed at the Robotics Group of the University of Bremen together with the Robotics Innovation Center of the German Research Center for Artificial Intelligence (DFKI) in Bremen.
