# rejax: a lightweight library of RL algorithms in JAX
If you're new to rejax and want to learn more about it, 📸 take a tour!
Here's what you can expect:
## ⚡ Vectorize training for incredible speedups!

- Use `jax.jit` on the whole train function to run training exclusively on your GPU!
- Use `jax.vmap` and `jax.pmap` on the initial seed or hyperparameters to train a whole batch of agents in parallel (see the examples below)!
```python
import jax

from rejax.algos import get_agent

# Get train function and initialize config for training
train_fn, config_cls = get_agent("sac")
train_config = config_cls.make(env="CartPole-v1", learning_rate=0.001)

# Jit the training function
jitted_train_fn = jax.jit(train_fn)

# Vmap training function over 300 initial seeds
vmapped_train_fn = jax.vmap(jitted_train_fn, in_axes=(None, 0))

# Train 300 agents!
keys = jax.random.split(jax.random.PRNGKey(0), 300)
train_state, evaluation = vmapped_train_fn(train_config, keys)
```
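The same pattern extends to hyperparameters. Below is a hypothetical sketch of a learning-rate sweep; it assumes the config is a PyTree (e.g. a flax struct with a dataclass-style `replace` method) whose numeric leaves can be stacked along a batch axis. None of this is prescribed by the rejax API shown above — it is just one way to batch configs for `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch: assumes train_config is a PyTree whose numeric leaves
# can be stacked, and that it exposes a dataclass-style `replace` method.
learning_rates = jnp.array([1e-4, 3e-4, 1e-3])

# Broadcast every config leaf along a new batch axis ...
batched_config = jax.tree_util.tree_map(
    lambda leaf: jnp.stack([leaf] * len(learning_rates)), train_config
)
# ... then overwrite the learning-rate leaf with the sweep values
batched_config = batched_config.replace(learning_rate=learning_rates)

# Vmap over the config axis while keeping the rng key fixed
sweep_train_fn = jax.vmap(jitted_train_fn, in_axes=(0, None))
train_state, evaluation = sweep_train_fn(batched_config, jax.random.PRNGKey(0))
```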
Benchmarks were run on an A100 80GB GPU and an Intel Xeon 4215R CPU. Note that the hyperparameters were set to the default values of CleanRL, including buffer sizes. Shrinking the buffers can yield additional speedups due to better caching, and enables training of even more agents in parallel.
## 🤖 Implemented algorithms
| Algorithm | Discrete | Continuous | Notes                                       |
| --------- | -------- | ---------- | ------------------------------------------- |
| PPO       | ✔        | ✔          |                                             |
| SAC       | ✔        | ✔          | discrete version as in Christodoulou, 2019  |
| DQN       | ✔        |            | incl. DDQN, Dueling DQN                     |
| DDPG      |          | ✔          |                                             |
| TD3       |          | ✔          |                                             |
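Any algorithm in the table can be fetched and trained just like SAC in the example above. The snippet below is only an illustrative sketch; "Pendulum-v1" is an example environment name, not a requirement.

```python
import jax
from rejax.algos import get_agent

# Fetch TD3 (continuous-only, per the table above).
# "Pendulum-v1" is an illustrative environment choice.
train_fn, config_cls = get_agent("td3")
config = config_cls.make(env="Pendulum-v1")
train_state, evaluation = jax.jit(train_fn)(config, jax.random.PRNGKey(0))
```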
## 🛠 Easily extend and modify algorithms

The implementations focus on clarity! Easily modify the implemented algorithms by overriding isolated parts, such as the loss function, trajectory generation, or parameter updates. For example, turn DQN into DDQN by writing:
```python
class DoubleDQN(DQN):
    @classmethod
    def update(cls, config, state, minibatch):
        # Calculate DDQN-specific targets
        targets = ddqn_targets(config, state, minibatch)

        # The loss function predicts Q-values and returns MSBE
        def loss_fn(params):
            ...  # predict q_values for the minibatch using params (elided)
            return jnp.mean((targets - q_values) ** 2)

        # Calculate gradients
        grads = jax.grad(loss_fn)(state.q_ts.params)

        # Update train state
        q_ts = state.q_ts.apply_gradients(grads=grads)
        state = state.replace(q_ts=q_ts)
        return state
```
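The helper `ddqn_targets` referenced above is not shown. Here is a minimal sketch of what such a function could compute, assuming minibatches with `next_obs`, `reward`, and `done` fields and a target-network parameter copy stored as `state.q_target_params`; these names are illustrative assumptions, not rejax internals.

```python
import jax.numpy as jnp

def ddqn_targets(config, state, minibatch):
    # Assumed field names throughout; this is a sketch, not the rejax API.
    # The online network selects the greedy action at the next state ...
    q_online = state.q_ts.apply_fn(state.q_ts.params, minibatch.next_obs)
    best_action = jnp.argmax(q_online, axis=-1)

    # ... while the target network evaluates that action (decoupling
    # selection from evaluation is what distinguishes DDQN from DQN)
    q_target = state.q_ts.apply_fn(state.q_target_params, minibatch.next_obs)
    q_next = jnp.take_along_axis(q_target, best_action[:, None], axis=-1).squeeze(-1)

    # One-step TD target, masking out bootstrapping on terminal transitions
    return minibatch.reward + config.gamma * (1.0 - minibatch.done) * q_next
```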
## 🔙 Flexible callbacks

Using callbacks, you can log to the console, disk, wandb, and much more — even when the whole train function is jitted! For example, run a `jax.experimental.io_callback` at regular intervals during training, or print the current policy's mean return:
```python
def print_callback(config, state, rng):
    policy = make_act(config, state)          # Get current policy
    episode_returns = evaluate(policy, ...)   # Evaluate it
    jax.debug.print(                          # Print results
        "Step: {}. Mean return: {}",
        state.global_step,
        episode_returns.mean(),
    )
    return ()  # Must return a PyTree (None is not a PyTree)

config = config.replace(eval_callback=print_callback)
```
Callbacks have the signature `callback(config, train_state, rng) -> PyTree` and are called every `eval_freq` training steps with the config and current train state. The outputs of the callback are aggregated over training and returned by the train function. The default callback runs a number of episodes in the training environment and returns their lengths and episodic returns, so that the train function returns a training curve.

Importantly, this function is jit-compiled along with the rest of the algorithm. However, you can use one of JAX's callbacks, such as `jax.experimental.io_callback`, to implement model checkpointing, logging to wandb, and more, all while maintaining the advantages of a completely jittable training function.
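For illustration, here is a hedged sketch of a wandb-logging callback. It reuses `make_act` and `evaluate` from the snippet above (including the elided `...` arguments); `wandb.log` and `io_callback` are real APIs, but the surrounding structure is an assumption in the spirit of the examples above, not a definitive implementation.

```python
from jax.experimental import io_callback
import wandb

def wandb_callback(config, state, rng):
    policy = make_act(config, state)          # As in print_callback above
    episode_returns = evaluate(policy, ...)   # Elided arguments, as above

    def log(step, mean_return):
        # Executed on the host, outside the jitted computation
        wandb.log({"mean_return": float(mean_return)}, step=int(step))

    # io_callback ships the traced values to the host; the second argument
    # is None because `log` returns nothing
    io_callback(log, None, state.global_step, episode_returns.mean())
    return ()  # Must return a PyTree

config = config.replace(eval_callback=wandb_callback)
```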
## 💞 Alternatives in end-to-end GPU training

Libraries:
- Brax: along with several environments, Brax implements PPO and SAC within its environment interface

Single-file implementations:
- PureJaxRL implements PPO, recurrent PPO, and DQN
- Stoix features DQN, DDPG, TD3, SAC, PPO, as well as popular extensions and more
## ✍ Cite us!

```bibtex
@misc{rejax,
  title   = {rejax},
  url     = {https://github.com/keraJLi/rejax},
  journal = {keraJLi/rejax},
  author  = {Liesen, Jarek and Lu, Chris and Lange, Robert},
  year    = {2024}
}
```