Reinforcement learning in pure JAX.
Dopamax
Dopamax is a library containing pure JAX implementations of common reinforcement learning algorithms. Everything is implemented in JAX, including the environments. This allows for extremely fast training and evaluation of agents, because the entire loop of environment simulation, agent interaction, and policy updates can be compiled as a single XLA program and executed on CPUs, GPUs, or TPUs. More specifically, the implementations in Dopamax follow the Anakin Podracer architecture -- see this paper for more details.
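To make the "single XLA program" idea concrete, here is a minimal sketch in plain JAX. It does not use Dopamax's API: the toy environment, the one-parameter policy, the learning rate, and every function name below are hypothetical and chosen only for illustration. It shows how environment simulation (`jax.lax.scan`), batching over environments (`jax.vmap`), and a gradient update can all sit inside one `jax.jit`-compiled function.

```python
import jax
import jax.numpy as jnp


def env_step(state, action):
    # Hypothetical toy environment: the state is a scalar position, the action
    # shifts it, and the reward is the negative distance from the origin.
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward


def policy(param, state):
    # Hypothetical linear policy with a single learnable parameter.
    return -param * state


def rollout_return(param, init_state, num_steps=8):
    def step(state, _):
        action = policy(param, state)
        next_state, reward = env_step(state, action)
        return next_state, reward

    # scan unrolls the agent-environment loop inside the compiled program.
    _, rewards = jax.lax.scan(step, init_state, None, length=num_steps)
    return rewards.sum()


def loss_fn(param, init_states):
    # vmap runs one rollout per initial state, i.e. a batch of environments.
    returns = jax.vmap(rollout_return, in_axes=(None, 0))(param, init_states)
    return -returns.mean()


@jax.jit
def train_step(param, init_states):
    # Simulation, differentiation, and the update are traced and compiled together.
    loss, grad = jax.value_and_grad(loss_fn)(param, init_states)
    return param - 0.01 * grad, loss


init_states = jnp.linspace(-1.0, 1.0, 16)
param = jnp.zeros(())
losses = []
for _ in range(5):
    param, loss = train_step(param, init_states)
    losses.append(float(loss))
```

Because the environment is itself a JAX function, no data leaves the accelerator between simulation and the parameter update; only the scalar loss is pulled back to the host for logging.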
Supported Algorithms
- Proximal Policy Optimization (PPO)
- Deep Q-Network (DQN)
- Deep Deterministic Policy Gradients (DDPG)
- Twin Delayed DDPG (TD3)
- Soft Actor-Critic (SAC)
- AlphaZero
Installation
Dopamax can be installed with:
pip install dopamax
This will install the dopamax Python package, as well as a command-line interface (CLI) for training and evaluation.
Note that only the CPU version of JAX is installed by default. If you would like to use a GPU or TPU, you will need to
install the appropriate version of JAX. See the
JAX installation instructions.
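After installing an accelerator-enabled build of JAX, a quick way to check which backend is actually in use is to ask JAX directly. This is a small sketch that relies only on JAX's own `jax.devices()` and `jax.default_backend()` functions; it says nothing about how Dopamax selects devices.

```python
import jax

# List the devices JAX can see; a CPU-only install reports only CPU devices.
devices = jax.devices()

# Name of the backend JAX will use by default, e.g. "cpu".
backend = jax.default_backend()

print(f"backend={backend}, devices={devices}")
```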
[!NOTE]
The above command will install the latest "release" of Dopamax, which may not necessarily align with the latest commit in the main branch. To install the version found in the main branch of this repository, you can use:
pip install git+https://github.com/rystrauss/dopamax.git
Usage
After installation, the Dopamax CLI can be used to train and evaluate agents:
dopamax --help
Dopamax uses Weights and Biases (W&B) for logging and artifact management. Before using the CLI
for training and evaluation, you must first make sure you have a W&B account (it's free) and have authenticated
with wandb login.
Training
Agents can be trained using the dopamax train command, to which you must provide a configuration file. The
configuration file is a YAML file that specifies the agent, environment, and training hyperparameters. You can find
examples in the examples directory. For example, to train a PPO agent on the CartPole environment, you would
run:
dopamax train --config examples/ppo-cartpole/config.yaml
Note that all of the example config files have a random seed specified, so you will get the same result every time you run the command. The seeds provided in the examples are known to result in a successful run (with the given hyperparameters). To get different results on each run, you can remove the seed from the config file.
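The reproducibility described here follows from JAX's explicit random number handling: the same seed produces the same PRNG key, and the same key always produces the same samples. The sketch below illustrates that property with plain JAX; how Dopamax consumes the seed internally is not shown in this document, and `sample_initial_weights` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sample_initial_weights(seed):
    # Hypothetical helper: derive a key from the seed and draw four values.
    key = jax.random.PRNGKey(seed)
    return jax.random.normal(key, (4,))


a = sample_initial_weights(0)
b = sample_initial_weights(0)
c = sample_initial_weights(1)

same_seed_matches = bool(jnp.array_equal(a, b))
different_seed_matches = bool(jnp.array_equal(a, c))
print(same_seed_matches, different_seed_matches)
```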
Evaluation
Once you have trained some agents, you can evaluate them using the dopamax evaluate command. This will allow you to
specify a W&B agent artifact that you'd like to evaluate (these artifacts are produced by the training runs and
contain the agent hyperparameters and weights from the end of training). For example, to evaluate a PPO agent trained
on CartPole, you might use a command like:
dopamax evaluate --agent_artifact CartPole-PPO-agent:v0 --num_episodes 100
where --num_episodes 100 signals that you would like to roll out the agent's policy for 100 episodes. The minimum,
mean, and maximum episode reward will be logged back to W&B. If you would additionally like to render the episodes and
have them logged back to W&B, you can provide the --render flag. Note that this usually slows evaluation down
significantly, since environment rendering is not a pure JAX function and requires callbacks to the host.
You should usually only use the --render flag with a small number of episodes.
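As a rough illustration of the statistics reported by evaluation, the sketch below rolls out 100 episodes in parallel and reduces the per-episode rewards to a minimum, mean, and maximum. It is not Dopamax's evaluation code: the episode function is a hypothetical stand-in that sums random per-step rewards, and the dictionary keys are illustrative.

```python
import jax
import jax.numpy as jnp


def episode_reward(key, num_steps=10):
    # Hypothetical stand-in for one episode: sum of uniform rewards in [0, 1).
    return jax.random.uniform(key, (num_steps,)).sum()


@jax.jit
def evaluate(key):
    # One PRNG key per episode, so all 100 episodes run as a single batch.
    keys = jax.random.split(key, 100)
    rewards = jax.vmap(episode_reward)(keys)
    return {"min": rewards.min(), "mean": rewards.mean(), "max": rewards.max()}


stats = {name: float(value) for name, value in evaluate(jax.random.PRNGKey(0)).items()}
print(stats)
```

Everything here stays inside one compiled function, which is why evaluation without rendering is cheap; a renderer would have to hand each frame back to Python on the host.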
See Also
Some of the JAX-native packages that Dopamax relies on:
File details
Details for the file dopamax-0.2.1.tar.gz.
File metadata
- Download URL: dopamax-0.2.1.tar.gz
- Upload date:
- Size: 38.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0dad6584f5ae5b3c14a8f853a7309052113cc1cd3d70a0a395424dc0f22fac83 |
| MD5 | f75cccc01c15da5f978c040720c16bde |
| BLAKE2b-256 | 8b26e85406c07be073333004dbb03cb3b55f23bebde6f8036cfc401c1406d40f |
File details
Details for the file dopamax-0.2.1-py3-none-any.whl.
File metadata
- Download URL: dopamax-0.2.1-py3-none-any.whl
- Upload date:
- Size: 52.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2124e8f41c9184e70d583f2dfa771dbb76e7c02f529738f324db407ef32731f9 |
| MD5 | 0ab4eb0fb8b7dfb3c9ef09d595fc1057 |
| BLAKE2b-256 | 1725ccf2afd46ff82156df09687fad0c943d9b0e6b10504f7a247642089f51e3 |