JAX-compatible version of OpenAI's gym environments
Classic Gym Environments in JAX 🌍
Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar and a collection of classic RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables parallelized rollouts for different configurations (e.g. for meta RL). Finally, we provide training pipelines and checkpoints for both PPO and ES in the gymnax-blines repository.
Basic gymnax API Usage 🍲
import jax
import gymnax
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)
env, env_params = gymnax.make("Pendulum-v1")
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
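The snippet above performs a single transition. A full episode threads the returned state and a fresh random key through repeated step calls. The following is a minimal sketch of that pattern only: to stay self-contained it uses a hypothetical toy environment (the functions `reset` and `step` and the `params` dictionary are illustrative stand-ins, not part of gymnax) with the same call signatures as `env.reset` and `env.step` shown above.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in environment (NOT part of gymnax): a 1D point that
# should be driven towards the origin. Only the call signatures mirror the
# reset/step calls shown above.
def reset(key, params):
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = {"pos": pos, "t": jnp.array(0, dtype=jnp.int32)}
    return pos, state  # (observation, state)


def step(key, state, action, params):
    noise = params["noise"] * jax.random.normal(key, ())
    pos = state["pos"] + action + noise
    t = state["t"] + 1
    reward = -jnp.abs(pos)
    done = t >= params["max_steps"]
    return pos, {"pos": pos, "t": t}, reward, done, {}


params = {"noise": 0.0, "max_steps": 5}  # illustrative hyperparameters
jit_step = jax.jit(step)

rng = jax.random.PRNGKey(0)
rng, key_reset = jax.random.split(rng)
obs, state = reset(key_reset, params)

total_reward, num_steps, done = 0.0, 0, False
while not done:
    # Split off a fresh key for every transition; never reuse a key.
    rng, key_step = jax.random.split(rng)
    action = -0.5 * obs  # simple proportional policy, for illustration only
    obs, state, reward, done, _ = jit_step(key_step, state, action, params)
    total_reward += float(reward)
    num_steps += 1

print(num_steps, total_reward)
```

Because the state is passed in and returned explicitly, nothing is stored inside the environment object, which is what makes the step function straightforward to jit.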
Implemented Accelerated Environments 🏎️
Installation ⏳
The latest gymnax release can be installed directly from PyPI:
pip install gymnax
If you want to get the most recent commit, please install directly from the repository:
pip install git+https://github.com/RobertTLange/gymnax.git@main
To run JAX on your accelerators (GPU/TPU), follow the installation instructions in the JAX documentation.
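After installing, a quick way to see which backend JAX picked up is to list the visible devices. This is only a sanity check, not a gymnax feature; what it prints depends on your machine and on the JAX build you installed.

```python
import jax

# List the devices JAX can see and the backend it will use by default.
devices = jax.devices()
backend = jax.default_backend()
print(backend, devices)
```

If only CPU devices are listed on a machine with a GPU or TPU, the installed JAX build most likely lacks accelerator support.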
Examples 📖
- 📓 Environment API - Check out the API, how to train an Anakin agent on Catch-bsuite.
- 📓 ES with gymnax - Using CMA-ES in JAX with vectorized population evaluations powered by gymnax.
- 📓 Trained baselines - Check out the trained baseline agents in gymnax-blines.
Key Selling Points 💵
- Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):
# Jit-accelerated step transition
jit_step = jax.jit(env.step)
# vmap across random keys for batch rollouts
vreset_rng = jax.vmap(env.reset, in_axes=(0, None))
vstep_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))
# vmap across environment parameters (e.g. for meta-learning)
vreset_env = jax.vmap(env.reset, in_axes=(None, 0))
vstep_env = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
- Scan through entire episode rollouts: You can also lax.scan through entire reset, step episode loops for fast compilation:
import jax.numpy as jnp

# `env` comes from gymnax.make(...); `network` is your own policy model.
def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = network.apply(policy_params, obs)
        next_o, next_s, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_o, next_s, policy_params, rng]
        return carry, [reward, done]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        [jnp.zeros((num_env_steps, 2))],
    )
    # Return masked sum of rewards accumulated by agent in episode
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)
- Super fast acceleration:
- Training pipelines & pretrained agents: Check out gymnax-blines for trained agents and PPO/ES pipelines.
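The vectorization and scan points above compose: a scanned rollout is an ordinary JAX function, so it can be vmapped over random keys or over environment parameters and then jitted as a whole. The sketch below illustrates this under stated assumptions. It uses a hypothetical toy environment (`reset`, `step` and the `params` dictionary are illustrative stand-ins, not gymnax code) and a fixed proportional policy in place of a trained network; the episode mask follows the same `cumsum(dones) < 1` rule as the rollout above, so rewards from the first `done` step onward are dropped.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical stand-in environment (NOT part of gymnax) that mirrors the
# reset(key, params) / step(key, state, action, params) call signatures.
def reset(key, params):
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = {"pos": pos, "t": jnp.array(0, dtype=jnp.int32)}
    return pos, state


def step(key, state, action, params):
    noise = params["noise"] * jax.random.normal(key, ())
    pos = state["pos"] + action + noise
    t = state["t"] + 1
    reward = -jnp.abs(pos)
    done = t >= params["max_steps"]
    return pos, {"pos": pos, "t": t}, reward, done, {}


def rollout(key, params, num_steps):
    """Scan one episode and return the masked sum of rewards."""
    key_reset, key_episode = jax.random.split(key)
    obs, state = reset(key_reset, params)

    def policy_step(carry, _):
        obs, state, key = carry
        key, key_step = jax.random.split(key)
        action = -0.5 * obs  # fixed illustrative policy
        next_obs, next_state, reward, done, _ = step(
            key_step, state, action, params
        )
        return (next_obs, next_state, key), (reward, done)

    _, (rewards, dones) = jax.lax.scan(
        policy_step, (obs, state, key_episode), None, length=num_steps
    )
    # Keep only the steps before the first `done`.
    ep_mask = jnp.cumsum(dones) < 1
    return jnp.sum(rewards * ep_mask)


# Fix the (static) scan length, then vectorize and compile.
rollout_10 = functools.partial(rollout, num_steps=10)

# Batch over random keys with shared parameters.
rollout_keys = jax.jit(jax.vmap(rollout_10, in_axes=(0, None)))
keys = jax.random.split(jax.random.PRNGKey(0), 8)
returns_keys = rollout_keys(keys, {"noise": 0.1, "max_steps": 5})

# Batch over environment parameters with a shared key (e.g. meta-learning).
rollout_params = jax.jit(jax.vmap(rollout_10, in_axes=(None, 0)))
params_batch = {
    "noise": jnp.array([0.0, 0.1, 0.5]),
    "max_steps": jnp.array([5, 5, 5]),
}
returns_params = rollout_params(jax.random.PRNGKey(1), params_batch)

print(returns_keys.shape, returns_params.shape)
```

The number of scan steps is bound with `functools.partial` because the scan length has to be a concrete Python integer at trace time, whereas the keys and parameters can be traced and batched.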
Acknowledgements & Citing gymnax ✏️
If you use gymnax in your research, please cite it as follows:
@software{gymnax2022github,
author = {Robert Tjarko Lange},
title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
url = {http://github.com/RobertTLange/gymnax},
version = {0.0.2},
year = {2022},
}
We acknowledge financial support from the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
Development 👷
You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.