
JAX (Flax) Deep Learning Library


JAXDL: JAX (Flax) Deep Learning Library

Clean state-of-the-art JAX/Flax deep learning algorithm implementations:

If you use JAXDL in your work, please cite this repository as follows:

@misc{jaxdl,
  author = {Hart, Patrick},
  month = {10},
  doi = {10.5281/zenodo.5596512},
  title = {{JAXDL: JAX Deep Learning Algorithm Implementations.}},
  url = {https://github.com/patrickhart/jaxdl},
  year = {2021}
}

Results / Benchmark

Continuous Control From States

Benchmark plots (images not reproduced): HalfCheetah-v2, Ant-v2, Reacher-v2, Humanoid-v2.

Installation

Install JAXDL from PyPI with pip install jaxdl.
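To confirm that the package is installed in the active environment, the standard-library importlib.metadata module can report its version. This is a small sketch and not part of JAXDL: the helper name installed_version is hypothetical, and it only looks up the distribution named jaxdl without importing it.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "jaxdl") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing.

    Hypothetical helper: it queries package metadata only and does not
    import the package itself.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version()
    if found:
        print(f"jaxdl {found} is installed")
    else:
        print("jaxdl is not installed; run: pip install jaxdl")
```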

To use MuJoCo 2.1, install the mujoco-py fork with pip install git+https://github.com/nimrod-gileadi/mujoco-py and place the MuJoCo binaries in ~/.mujoco/mujoco210.
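Before starting a run, it can help to check that the binaries ended up where mujoco-py expects them. The sketch below is a hypothetical helper, not part of JAXDL or mujoco-py, and it assumes that the extracted MuJoCo 2.1 archive keeps its shared libraries in a bin/ subfolder of ~/.mujoco/mujoco210; adjust the check if your layout differs.

```python
from pathlib import Path
from typing import Optional


def mujoco210_present(home: Optional[Path] = None) -> bool:
    """Check for the MuJoCo 2.1 binaries under <home>/.mujoco/mujoco210.

    Assumption: the extracted MuJoCo 2.1 archive keeps its shared libraries
    in a bin/ subfolder. `home` defaults to the current user's home directory.
    """
    root = (home if home is not None else Path.home()) / ".mujoco" / "mujoco210"
    return root.is_dir() and (root / "bin").is_dir()


if __name__ == "__main__":
    if mujoco210_present():
        print("Found MuJoCo 2.1 binaries in ~/.mujoco/mujoco210")
    else:
        print("MuJoCo 2.1 binaries not found in ~/.mujoco/mujoco210")
```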

Examples / Getting Started

To get started, have a look at the examples folder.

To train a reinforcement learning agent, run:

python examples/run_rl.py \
  --mode=train \
  --env_name=Ant-v2 \
  --save_dir=./tmp/ \
  --config=./examples/configs/sac_config.py

To visualize the trained agent, run:

python examples/run_rl.py \
  --mode=visualize \
  --env_name=Ant-v2 \
  --save_dir=./tmp/ \
  --config=./examples/configs/sac_config.py
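The training and visualization commands differ only in --mode, so a script that drives several runs can build the argument list in one place. The sketch below is a hypothetical wrapper, not part of JAXDL: it uses only the flags shown above, its defaults mirror the example, and accepting only train and visualize is an assumption based on the two modes documented here. It is meant to be run from the repository root.

```python
import shlex
from typing import List


def run_rl_command(mode: str,
                   env_name: str = "Ant-v2",
                   save_dir: str = "./tmp/",
                   config: str = "./examples/configs/sac_config.py") -> List[str]:
    """Build the argv for examples/run_rl.py (hypothetical helper).

    Only the modes shown in this README are accepted.
    """
    if mode not in ("train", "visualize"):
        raise ValueError(f"unsupported mode: {mode!r}")
    return [
        "python", "examples/run_rl.py",
        f"--mode={mode}",
        f"--env_name={env_name}",
        f"--save_dir={save_dir}",
        f"--config={config}",
    ]


if __name__ == "__main__":
    # Print the shell form of each command; pass the list to
    # subprocess.run(cmd, check=True) to actually launch a run.
    for mode in ("train", "visualize"):
        print(shlex.join(run_rl_command(mode)))
```

Using the same save_dir for both modes keeps visualization pointed at the checkpoints written during training.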

TensorBoard

Monitor the training run by pointing TensorBoard at the directory passed as --save_dir (./tmp/ in the commands above):

tensorboard --logdir=./tmp/

Contributing

Contributions are welcome! This repository aims to provide clean and simple implementations, so please keep this in mind when contributing.



Download files


Source Distribution

jaxdl-0.0.4.tar.gz (254.1 kB)

File details

Details for the file jaxdl-0.0.4.tar.gz.

File metadata

  • Download URL: jaxdl-0.0.4.tar.gz
  • Upload date:
  • Size: 254.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.4+

File hashes

Hashes for jaxdl-0.0.4.tar.gz
Algorithm Hash digest
SHA256 68b8519389cb6c1306c32c7bfba2986a0bec982958ed708a8d3a2c2a8dfd2215
MD5 f779a62e1a0caca501294f51993a3bbe
BLAKE2b-256 b537400f12da711e4c37289f9b32a0b4ce1fde2ab5b35fea49038eab21a80dda

