A collection of RL algorithms written in JAX.
WARNING: Rljax is currently in beta and under active development. Any contributions are welcome :)
RL Algorithms in JAX
Rljax is a collection of RL algorithms written in JAX.
Setup
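You can install the dependencies by executing the following. To use GPUs, the NVIDIA driver and CUDA must be installed. The first command selects the jaxlib 0.1.55 wheel that matches your CUDA version (read from `nvcc -V`) and your Python version.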
pip install --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.55-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
pip install -e .
If you don't have a GPU, please execute the following instead.
pip install --upgrade jaxlib jax
pip install -e .
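After either install, a quick sanity check is to list the devices JAX can see and run a small jit-compiled function. This is a minimal sketch using only core JAX (`jax.devices`, `jax.jit`), not part of Rljax itself; the function name `squared_norm` is purely illustrative. On a correctly configured GPU machine the device list should include a GPU device; otherwise only a CPU device appears.

```python
import jax
import jax.numpy as jnp

# List the devices JAX detected (CPU only, or GPU if jaxlib was built with CUDA).
print(jax.devices())


@jax.jit
def squared_norm(x):
    # Compiled with XLA on the first call, then reused from the cache.
    return jnp.sum(x ** 2)


# 0 + 1 + 4 + 9 = 14.0
print(squared_norm(jnp.arange(4.0)))
```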
If you want to use the MuJoCo physics engine (needed for environments such as InvertedPendulum-v2), please install mujoco-py.
Algorithms
Currently, the following algorithms have been implemented.
- Proximal Policy Optimization (PPO)
- Deep Deterministic Policy Gradient (DDPG)
- Twin Delayed DDPG (TD3)
- Soft Actor-Critic (SAC)
- Deep Q Network (DQN)
  - N-step return
  - Dueling Network
  - Double Q-Learning
  - Prioritized Experience Replay (PER)
- Soft Actor-Critic for Discrete Settings (SAC-Discrete)
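Two of the DQN extensions listed above, N-step return and Double Q-Learning, combine into a single bootstrapped target. The sketch below is written in plain JAX and is not Rljax's API; the function name `n_step_double_q_target` and its argument layout are hypothetical. It assumes rewards after an episode terminates are zero-padded and that `done` is 1.0 when the episode ended within the n steps. The online network picks the next action and the target network evaluates it. Dueling Network (an architecture change) and PER (a sampling change) do not affect this target, so they are not shown.

```python
import jax
import jax.numpy as jnp


@jax.jit
def n_step_double_q_target(rewards, done, q_next_online, q_next_target, gamma=0.99):
    """Hypothetical helper: n-step Double DQN target for a batch.

    rewards:        [batch, n] rewards r_t, ..., r_{t+n-1} (zero-padded after termination)
    done:           [batch] 1.0 if the episode ended within the n steps, else 0.0
    q_next_online:  [batch, num_actions] Q(s_{t+n}, .) from the online network
    q_next_target:  [batch, num_actions] Q(s_{t+n}, .) from the target network
    """
    n = rewards.shape[1]  # static under jit
    # Discounted sum of the n rewards: sum_k gamma^k * r_{t+k}.
    discounts = gamma ** jnp.arange(n)
    n_step_return = jnp.sum(rewards * discounts, axis=1)
    # Double Q-learning: select the action with the online net ...
    best_action = jnp.argmax(q_next_online, axis=1)
    # ... and evaluate that action with the target net.
    bootstrap = jnp.take_along_axis(q_next_target, best_action[:, None], axis=1)[:, 0]
    # Do not bootstrap past a terminal state.
    return n_step_return + (1.0 - done) * (gamma ** n) * bootstrap
```

For example, with rewards `[1, 1, 1]`, `gamma=0.5`, online values `[0, 2]`, and target values `[5, 3]`, the online net selects action 1, so the target is 1 + 0.5 + 0.25 + 0.5^3 * 3 = 2.125. Using the target net to evaluate the online net's choice is what reduces the overestimation bias of the plain max over target values, which here would have used 5 instead of 3.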
We plan to implement the following algorithms in the future.
- Quantile Regression DQN (QR-DQN)
- Implicit Quantile Network (IQN)
The figures below show our algorithms successfully learning the discrete-action environment CartPole-v0 and the continuous-action environment InvertedPendulum-v2.
[Figures: CartPole-v0 and InvertedPendulum-v2]
Hashes for rljax-0.0.2-py2.py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | e586b6a623c93c39549be98cfa7124edb814fe7bef7d0167f6ca0c70e64fe2e7 |
| MD5 | b67e43481921ba0b64bfab974f054d9e |
| BLAKE2b-256 | 900783bdf11ef1e0b9004a9bcccccb8157c82f9e93050b49172c0aed3a91e665 |
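To check a downloaded wheel against the published digests, a minimal sketch with Python's standard `hashlib` module could look like the following. The helper names `file_digests` and `verify_wheel` are hypothetical, and the wheel is assumed to be a local file. BLAKE2b-256 corresponds to BLAKE2b with a 32-byte digest.

```python
import hashlib

# Published digests for rljax-0.0.2-py2.py3-none-any.whl (copied from the hash table).
EXPECTED = {
    "sha256": "e586b6a623c93c39549be98cfa7124edb814fe7bef7d0167f6ca0c70e64fe2e7",
    "md5": "b67e43481921ba0b64bfab974f054d9e",
    "blake2b_256": "900783bdf11ef1e0b9004a9bcccccb8157c82f9e93050b49172c0aed3a91e665",
}


def file_digests(path, chunk_size=1 << 20):
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file, reading it in chunks."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        # BLAKE2b-256 is BLAKE2b truncated to a 32-byte (256-bit) digest.
        "blake2b_256": hashlib.blake2b(digest_size=32),
    }
    with open(path, "rb") as f:
        # Read until f.read returns b"" (end of file).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify_wheel(path):
    """Return True only if every digest of the file matches the published value."""
    return file_digests(path) == EXPECTED
```

For example, `verify_wheel("rljax-0.0.2-py2.py3-none-any.whl")` should return True for an untampered download in the current directory. For SHA256 alone, `pip hash <file>` prints the same digest.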