
A collection of RL algorithms written in JAX.


WARNING: Rljax is currently in beta and is being actively improved. Contributions are welcome :)

RL Algorithms in JAX

Rljax is a collection of RL algorithms written in JAX.

Setup

You can install the dependencies by running the following commands. To use GPUs, the NVIDIA driver and CUDA must already be installed.

pip install --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.55-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
pip install -e .
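
The first command above builds the jaxlib wheel URL from two pieces of local information: the CUDA release reported by nvcc -V and the Python version reported by python3 -V. As a rough illustration of what the two sed substitutions do, here is a minimal Python sketch. The function name jaxlib_wheel_url and its parameters are hypothetical and not part of Rljax or JAX; only the URL pattern and the jaxlib version 0.1.55 come from the command itself.

```python
import re


def jaxlib_wheel_url(nvcc_output: str, python_version_output: str,
                     jaxlib_version: str = "0.1.55") -> str:
    """Rebuild the wheel URL that the shell one-liner assembles.

    Hypothetical helper for illustration only.
    """
    # `nvcc -V` prints a line such as
    # "Cuda compilation tools, release 10.1, V10.1.243".
    cuda = re.search(r" release (\d+)\.(\d+),", nvcc_output)
    # `python3 -V` prints a line such as "Python 3.8.5".
    py = re.search(r"Python (\d+)\.(\d+)", python_version_output)
    if cuda is None or py is None:
        raise ValueError("could not parse the version output")

    cuda_tag = f"cuda{cuda.group(1)}{cuda.group(2)}"  # e.g. "cuda101"
    py_tag = f"cp{py.group(1)}{py.group(2)}"          # e.g. "cp38"
    return (
        "https://storage.googleapis.com/jax-releases/"
        f"{cuda_tag}/jaxlib-{jaxlib_version}-{py_tag}"
        "-none-manylinux2010_x86_64.whl"
    )


if __name__ == "__main__":
    print(jaxlib_wheel_url(
        "Cuda compilation tools, release 10.1, V10.1.243",
        "Python 3.8.5",
    ))
```

If the printed URL does not match a wheel that actually exists for your CUDA and Python versions, the pip command will fail, so printing the URL first can be a quick way to diagnose an installation problem.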

If you don't have a GPU, please execute the following instead.

pip install --upgrade jaxlib jax
pip install -e .

If you want to use the MuJoCo physics engine, please install mujoco-py.

Algorithms

Currently, the following algorithms are implemented.

  • Proximal Policy Optimization (PPO)
  • Deep Deterministic Policy Gradient (DDPG)
  • Twin Delayed DDPG (TD3)
  • Soft Actor-Critic (SAC)
  • Deep Q Network (DQN)
  • N-step return
  • Dueling Network
  • Double Q-Learning
  • Prioritized Experience Replay (PER)
  • Soft Actor-Critic for Discrete Settings (SAC-Discrete)
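
Two of the listed DQN extensions, N-step return and Double Q-Learning, combine into a single learning target. The following is a minimal pure-Python sketch of that target for one transition. It is not Rljax's implementation: the function names and argument layout are assumptions made for illustration, and a JAX version would typically operate on batched arrays instead of Python lists.

```python
from typing import Sequence, Tuple


def n_step_return(rewards: Sequence[float], dones: Sequence[bool],
                  gamma: float) -> Tuple[float, float]:
    """Accumulate up to n discounted rewards.

    Returns the partial return and the discount to apply to the
    bootstrapped value (0.0 if the episode ended within the n steps).
    """
    ret = 0.0
    discount = 1.0
    for reward, done in zip(rewards, dones):
        ret += discount * reward
        discount *= gamma
        if done:
            # Episode terminated: nothing to bootstrap from.
            return ret, 0.0
    return ret, discount


def double_q_target(partial_return: float, bootstrap_discount: float,
                    next_q_online: Sequence[float],
                    next_q_target: Sequence[float]) -> float:
    """Double Q-Learning target: the online network selects the action,
    the target network evaluates it."""
    best_action = max(range(len(next_q_online)),
                      key=lambda a: next_q_online[a])
    return partial_return + bootstrap_discount * next_q_target[best_action]


if __name__ == "__main__":
    ret, disc = n_step_return([1.0, 1.0, 1.0], [False, False, False], 0.5)
    print(double_q_target(ret, disc, [0.1, 0.9], [2.0, 4.0]))
```

Separating action selection from action evaluation is what distinguishes Double Q-Learning from the plain DQN target, which would take the maximum of the target network's values directly.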

We plan to implement the following algorithms in the future.

  • Quantile Regression DQN (QR-DQN)
  • Implicit Quantile Network (IQN)
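
For readers unfamiliar with the planned distributional methods, the core of QR-DQN is the quantile Huber loss, which penalizes over- and under-estimation asymmetrically for each quantile. Below is a small pure-Python sketch of that loss as it is commonly written for QR-DQN. Since these algorithms are not yet part of Rljax, every name here is hypothetical, and nothing in it reflects how Rljax will eventually implement them.

```python
from typing import List, Sequence


def quantile_midpoints(n: int) -> List[float]:
    """Midpoints of n equal-width quantile bins: (2i + 1) / (2n)."""
    return [(2 * i + 1) / (2 * n) for i in range(n)]


def huber(u: float, kappa: float = 1.0) -> float:
    """Huber loss: quadratic near zero, linear beyond kappa."""
    a = abs(u)
    if a <= kappa:
        return 0.5 * u * u
    return kappa * (a - 0.5 * kappa)


def quantile_huber_loss(pred_quantiles: Sequence[float],
                        target_samples: Sequence[float],
                        kappa: float = 1.0) -> float:
    """Sum over predicted quantiles of the mean loss over target samples."""
    taus = quantile_midpoints(len(pred_quantiles))
    total = 0.0
    for tau, theta in zip(taus, pred_quantiles):
        for z in target_samples:
            u = z - theta  # TD error for this (quantile, target) pair
            # Weight is tau when the prediction is too low (u >= 0)
            # and 1 - tau when it is too high (u < 0).
            weight = abs(tau - (1.0 if u < 0 else 0.0))
            total += weight * huber(u, kappa)
    return total / len(target_samples)


if __name__ == "__main__":
    print(quantile_huber_loss([0.0, 0.0], [1.0]))
```

The asymmetric weight is what pushes each output toward a different quantile of the return distribution instead of toward its mean.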

Our algorithms successfully learn both the discrete-action environment CartPole-v0 and the continuous-action environment InvertedPendulum-v2.
