JAX (Flax) Deep Learning Library
JAXDL: JAX (Flax) Deep Learning Library
Clean state-of-the-art JAX/Flax deep learning algorithm implementations:
- Soft Actor-Critic (SAC) (arXiv:1812.05905)
- Twin Delayed DDPG (TD3) (arXiv:1802.09477)
- Transformer (arXiv:1706.03762; planned)
- Unified Graph Network Blocks (arXiv:1806.01261; planned)
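As a rough illustration of what the two implemented algorithms share, both SAC and TD3 bootstrap from the smaller of two target critic estimates. The framework-free sketch below is not taken from JAXDL: the function name `td_target` and its parameters are hypothetical, and the library's own JAX/Flax code may be organized differently.

```python
def td_target(reward, done, q1_next, q2_next,
              discount=0.99, alpha=0.0, next_log_prob=0.0):
    """Clipped double-Q bootstrap target for a single transition.

    Illustrative only; names and defaults are assumptions, not JAXDL's API.
    """
    # Taking the minimum of two critic estimates counteracts overestimation.
    min_q = min(q1_next, q2_next)
    # SAC subtracts alpha * log pi(a'|s') (entropy bonus); alpha=0 drops it.
    soft_value = min_q - alpha * next_log_prob
    # Do not bootstrap past a terminal state.
    not_done = 0.0 if done else 1.0
    return reward + discount * not_done * soft_value
```

With `alpha=0` this is the TD3-style target (TD3 additionally adds clipped noise to the target action before evaluating the critics, which is omitted here). A positive `alpha` together with the log-probability of the next action gives the SAC variant.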
If you use JAXDL in your work, please cite this repository as follows:
@misc{jaxdl,
  author = {Hart, Patrick},
  month = {10},
  doi = {10.5281/zenodo.5596512},
  title = {{JAXDL: JAX Deep Learning Algorithm Implementations.}},
  url = {https://github.com/patrickhart/jaxdl},
  year = {2021}
}
Results / Benchmark
Continuous Control From States
[Figures: benchmark results on HalfCheetah-v2, Ant-v2, Reacher-v2, and Humanoid-v2]
Installation
Install JAXDL from PyPI: `pip install jaxdl`.
To use MuJoCo 2.1, run `pip install git+https://github.com/nimrod-gileadi/mujoco-py` and place the MuJoCo binaries in `~/.mujoco/mujoco210`.
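Before training on a MuJoCo environment, it can help to confirm that the binaries are where they are expected. The following is a minimal sketch, not part of JAXDL: the helper name `mujoco_binaries_present` is hypothetical, and it only checks that the `~/.mujoco/mujoco210` directory exists, not that its contents are complete.

```python
from pathlib import Path


def mujoco_binaries_present(home=None):
    """Return True if <home>/.mujoco/mujoco210 exists as a directory.

    `home` defaults to the current user's home directory; passing it
    explicitly is a convenience for testing (an assumption of this sketch).
    """
    base = Path(home) if home is not None else Path.home()
    return (base / ".mujoco" / "mujoco210").is_dir()


if __name__ == "__main__":
    if mujoco_binaries_present():
        print("Found ~/.mujoco/mujoco210")
    else:
        print("Missing ~/.mujoco/mujoco210; place the MuJoCo 2.1 binaries there")
```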
Examples / Getting Started
To get started, have a look at the examples folder.
To train a reinforcement learning agent, run:
python examples/run_rl.py \
--mode=train \
--env_name=Ant-v2 \
--save_dir=./tmp/ \
--config=./examples/configs/sac_config.py
To visualize the trained agent, use:
python examples/run_rl.py \
--mode=visualize \
--env_name=Ant-v2 \
--save_dir=./tmp/ \
--config=./examples/configs/sac_config.py
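The two invocations above differ only in `--mode`, so they can be scripted from Python. This is a hedged sketch rather than part of JAXDL: the helper `build_run_rl_command` is hypothetical, it uses only the flags shown above, and it assumes `train` and `visualize` are the modes of interest (the script may accept others).

```python
import sys


def build_run_rl_command(mode,
                         env_name="Ant-v2",
                         save_dir="./tmp/",
                         config="./examples/configs/sac_config.py"):
    """Build the argument list for examples/run_rl.py (illustrative helper)."""
    # Only the two modes shown in this README are accepted here.
    if mode not in ("train", "visualize"):
        raise ValueError(f"unsupported mode: {mode!r}")
    return [
        sys.executable,  # run with the same interpreter as this script
        "examples/run_rl.py",
        f"--mode={mode}",
        f"--env_name={env_name}",
        f"--save_dir={save_dir}",
        f"--config={config}",
    ]

# To actually launch a run from the repository root:
#   import subprocess
#   subprocess.run(build_run_rl_command("train"), check=True)
```

Returning a list instead of a single string avoids shell quoting issues when the paths contain spaces.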
Tensorboard
Monitor the training run with TensorBoard, replacing /save_dir/ with the directory you passed as --save_dir:
tensorboard --logdir=/save_dir/
Contributing
Contributions are welcome! This repository aims to provide clean and simple implementations, so please keep that in mind when contributing.
Source Distribution
File details
Details for the file `jaxdl-0.0.4.tar.gz`.
File metadata
- Download URL: jaxdl-0.0.4.tar.gz
- Upload date:
- Size: 254.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.4+
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 68b8519389cb6c1306c32c7bfba2986a0bec982958ed708a8d3a2c2a8dfd2215 |
MD5 | f779a62e1a0caca501294f51993a3bbe |
BLAKE2b-256 | b537400f12da711e4c37289f9b32a0b4ce1fde2ab5b35fea49038eab21a80dda |