
Lightweight library of RL algorithms in Jax

Project description

Rejax
Fully Vectorizable Reinforcement Learning Algorithms in Jax!


If you're new to rejax and want to learn more about it, 📸 take a tour in the introductory Colab notebook.

Here's what you can expect:

⚡ Vectorize training for incredible speedups!

  • Use jax.jit on the whole train function to run training exclusively on your GPU!
  • Use jax.vmap and jax.pmap on the initial seed or hyperparameters to train a whole batch of agents in parallel!
import jax
from rejax.algos import get_agent

# Get train function and initialize config for training
train_fn, config_cls = get_agent("sac")
train_config = config_cls.make(env="CartPole-v1", learning_rate=0.001)

# Jit the training function
jitted_train_fn = jax.jit(train_fn)

# Vmap training function over 300 initial seeds
vmapped_train_fn = jax.vmap(jitted_train_fn, in_axes=(None, 0))

# Train 300 agents!
keys = jax.random.split(jax.random.PRNGKey(0), 300)
train_state, evaluation = vmapped_train_fn(train_config, keys)
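
The example above vmaps over the initial seed. Vmapping over a hyperparameter works similarly; here is a minimal sketch that assumes the config's learning_rate is a dynamic (pytree) field, so that it can be traced through jax.vmap:

import jax
import jax.numpy as jnp

# Sketch: sweep three learning rates with a shared seed.
# Assumes learning_rate is a dynamic (non-static) field of the config.
def train_with_lr(learning_rate, rng):
    config = train_config.replace(learning_rate=learning_rate)
    return train_fn(config, rng)

learning_rates = jnp.array([1e-4, 3e-4, 1e-3])
sweep_train_fn = jax.jit(jax.vmap(train_with_lr, in_axes=(0, None)))
train_state, evaluation = sweep_train_fn(learning_rates, jax.random.PRNGKey(0))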

[Figures: speedup over cleanRL on Hopper and on Breakout]

Benchmarked on an A100 80GB GPU and an Intel Xeon 4215R CPU. Note that the hyperparameters were set to the default values of cleanRL, including buffer sizes. Shrinking the buffers can yield additional speedups due to better caching, and enables training even more agents in parallel.

🤖 Implemented algorithms

Algorithm | Discrete | Continuous | Notes
----------|----------|------------|------
PPO       | ✔        | ✔          |
SAC       | ✔        | ✔          | discrete version as in Christodoulou, 2019
DQN       | ✔        |            | incl. DDQN, Dueling DQN
DDPG      |          | ✔          |
TD3       |          | ✔          |

🛠 Easily extend and modify algorithms

The implementations focus on clarity! Modify the implemented algorithms by overriding isolated parts, such as the loss function, trajectory generation, or parameter updates. For example, turn DQN into Double DQN (DDQN) by writing

class DoubleDQN(DQN):
    @classmethod
    def update(cls, config, state, minibatch):
        # Calculate DDQN-specific targets
        targets = ddqn_targets(config, state, minibatch)

        # The loss function predicts Q-values and returns MSBE
        def loss_fn(params):
            ...
            return jnp.mean((targets - q_values) ** 2)

        # Calculate gradients
        grads = jax.grad(loss_fn)(state.q_ts.params)

        # Update train state
        q_ts = state.q_ts.apply_gradients(grads=grads)
        state = state.replace(q_ts=q_ts)
        return state

🔙 Flexible callbacks

Using callbacks, you can log to the console, to disk, to wandb, and much more, even while the whole train function is jitted! For example, run a jax.experimental.io_callback at regular intervals during training, or print the current policy's mean return:

def print_callback(config, state, rng):
    policy = make_act(config, state)         # Get current policy
    episode_returns = evaluate(policy, ...)  # Evaluate it
    jax.debug.print(                         # Print results
        "Step: {}. Mean return: {}",
        state.global_step,
        episode_returns.mean(),
    )
    return ()  # Must return a PyTree (e.g. an empty tuple rather than None)

config = config.replace(eval_callback=print_callback)

Callbacks have the signature callback(config, train_state, rng) -> PyTree and are called every eval_freq training steps with the config and the current train state. The outputs of the callback are aggregated over training and returned by the train function. The default callback runs a number of episodes in the training environment and returns their lengths and episodic returns, so the train function yields a training curve.
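
For instance, assuming the default callback's output is the pair of (episode lengths, episode returns) described above, stacked over the evaluation steps, the training curve can be read off directly from the value returned by the train function:

# Sketch: consume the default evaluation output.
# Assumes `evaluation` is a (lengths, returns) pair; exact shapes depend on
# the number of evaluations and evaluation episodes.
lengths, returns = evaluation
mean_returns = returns.mean(axis=-1)  # average over evaluation episodes
print(mean_returns)                   # one entry per evaluation step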

Importantly, the callback is jit-compiled along with the rest of the algorithm. However, you can use one of Jax's callbacks, such as jax.experimental.io_callback, to implement model checkpointing, logging to wandb, and more, all while maintaining the advantages of a completely jittable training function.
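
As a sketch (the mean-return value is stubbed out here, and the wandb call is only indicated in a comment), such a callback could look like this:

import jax.numpy as jnp
from jax.experimental import io_callback

def host_log(step, mean_return):
    # Runs on the host, outside the jit trace, so arbitrary side effects are fine,
    # e.g. wandb.log({"mean_return": float(mean_return)}, step=int(step))
    # or writing a checkpoint to disk.
    print(f"step {int(step)}: mean return {float(mean_return):.2f}")

def logging_callback(config, state, rng):
    mean_return = jnp.float32(0.0)  # placeholder: evaluate the current policy here
    io_callback(host_log, None, state.global_step, mean_return)
    return ()  # must return a PyTree

config = config.replace(eval_callback=logging_callback)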

💞 Alternatives for end-to-end GPU training

Libraries:

  • Brax: along with several environments, Brax implements PPO and SAC within its environment interface

Single file implementations:

  • PureJaxRL implements PPO, recurrent PPO and DQN
  • Stoix features DQN, DDPG, TD3, SAC, PPO, as well as popular extensions and more

✍ Cite us!

@misc{rejax, 
  title={rejax}, 
  url={https://github.com/keraJLi/rejax}, 
  journal={keraJLi/rejax}, 
  author={Liesen, Jarek and Lu, Chris and Lange, Robert}, 
  year={2024}
} 

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

rejax-0.0.1.tar.gz (28.4 kB view details)

Uploaded Source

Built Distribution

rejax-0.0.1-py3-none-any.whl (42.1 kB view details)

Uploaded Python 3

File details

Details for the file rejax-0.0.1.tar.gz.

File metadata

  • Download URL: rejax-0.0.1.tar.gz
  • Upload date:
  • Size: 28.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for rejax-0.0.1.tar.gz
  • SHA256: 59a47d69ebaf6905e62fac94d69e8d706d8f0b209a284d218aebb9b7f85b8ccb
  • MD5: 22e7a2afd1f0bb454e437f8f5380f22f
  • BLAKE2b-256: ed3119c4b6e5f06a77d7e823622765f51e94d1f2822d45087358d59f507bfc98

See more details on using hashes here.

File details

Details for the file rejax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: rejax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 42.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.9.19

File hashes

Hashes for rejax-0.0.1-py3-none-any.whl
  • SHA256: 6b5e88b97b4d8a26c2203864d90d20e0cce8e7713540e9b23edb1ca91c62b530
  • MD5: 2e461cfc4f119f3f9458fe4ecdf5a87a
  • BLAKE2b-256: 3a32c519ac1b6877904a6cbadd66ee826f2e07ea16ed7e6a3dec2b31ff93e2d8

See more details on using hashes here.
