WARNING: Rljax is currently in beta and under active development. Any contributions are welcome :)
RL Algorithms in JAX
Rljax is a collection of RL algorithms written in JAX.
Setup
If you have a CUDA-enabled GPU, you can install dependencies by executing the following.
pip install --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.55-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
pip install -r requirements.txt
If you don't have a GPU, please execute the following instead.
pip install --upgrade jaxlib jax
pip install -r requirements.txt
If you want to use the MuJoCo physics engine, please install mujoco-py.
Algorithms
Currently, the following algorithms have been implemented.
- Proximal Policy Optimization (PPO)
- Deep Deterministic Policy Gradient (DDPG)
- Twin Delayed DDPG (TD3)
- Soft Actor-Critic (SAC)
- Deep Q Network (DQN)
  - N-step return
  - Dueling Network
  - Double Q-Learning
  - Prioritized Experience Replay (PER)
- Soft Actor-Critic for Discrete Settings (SAC-Discrete)
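As an illustration of the Double Q-Learning extension listed above, here is a minimal NumPy sketch of the target computation (this is not Rljax's actual implementation; the function and variable names are illustrative):

```python
import numpy as np

def double_dqn_target(q_online_next, q_target_next, reward, done, gamma=0.99):
    """Double Q-learning target: the online network selects the next action,
    the target network evaluates it, which reduces overestimation bias."""
    # q_online_next, q_target_next: arrays of shape (batch, num_actions)
    a_star = np.argmax(q_online_next, axis=1)               # action selection
    q_eval = q_target_next[np.arange(len(a_star)), a_star]  # action evaluation
    return reward + gamma * (1.0 - done) * q_eval

# Tiny batch of two transitions, the second being terminal.
q_online = np.array([[1.0, 2.0], [3.0, 0.0]])
q_target = np.array([[0.5, 1.5], [2.5, 0.5]])
reward = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
targets = double_dqn_target(q_online, q_target, reward, done)
```

In Rljax these Q-values would come from Haiku/JAX networks rather than fixed arrays; the sketch only shows the decoupled select-then-evaluate step that distinguishes Double Q-Learning from vanilla DQN.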
We plan to implement the following algorithms in the future.
- Quantile Regression DQN (QR-DQN)
- Implicit Quantile Network (IQN)
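For reference, both planned algorithms regress a set of return quantiles with a quantile Huber loss. The NumPy sketch below (not Rljax code; names are illustrative) shows the asymmetric weighting by the quantile fraction tau that is central to QR-DQN:

```python
import numpy as np

def quantile_huber_loss(pred_quantiles, target_samples, kappa=1.0):
    """Quantile Huber loss: each predicted quantile is regressed toward
    every target sample, weighted asymmetrically by its fraction tau."""
    n = pred_quantiles.shape[-1]
    tau = (np.arange(n) + 0.5) / n                     # quantile midpoints
    # Pairwise TD errors, shape (num_targets, num_quantiles).
    u = target_samples[:, None] - pred_quantiles[None, :]
    huber = np.where(np.abs(u) <= kappa,
                     0.5 * u ** 2,
                     kappa * (np.abs(u) - 0.5 * kappa))
    # Under-predictions are penalized by tau, over-predictions by (1 - tau).
    weight = np.abs(tau[None, :] - (u < 0).astype(float))
    return (weight * huber).mean()

# A single quantile matching its target incurs zero loss.
loss_zero = quantile_huber_loss(np.array([1.0]), np.array([1.0]))
```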
The figures below show our algorithms successfully learning the discrete-action environment CartPole-v0 and the continuous-action environment InvertedPendulum-v2.
Hashes for rljax-0.0.1-py2.py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 828fa0c5690804aebdde6e708f52fd28d5b8a9eef9675d99454a8795c908acb1
MD5 | 45e54b6c5da7c8b67c37bebe82728613
BLAKE2b-256 | 5c2c6cd2bce8c9382d620b6944defcd700277d938152c9cf38461779eafd51a9