JAX-compatible version of OpenAI's gym
Gymnax - Classic Gym Environments in JAX
Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap to classic OpenAI gym environments.
Basic gymnax API Usage :stew:
- Classic OpenAI gym wrapper including `gymnax.make`, `env.reset` and `env.step`:
import jax
import gymnax
rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)
env, env_params = gymnax.make("Pendulum-v1")
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
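These calls follow a functional pattern. `reset` and `step` return the environment state, you pass it back in explicitly, and randomness comes from an explicit PRNG key. Because of this, a transition is a pure function that `jax.jit` can compile. The toy sketch below illustrates that pattern only and is not gymnax's internal implementation. The class `ToyEnv`, the `EnvState` fields, the params dict keys and the dynamics are all hypothetical. It keeps the same five-element `step` return as the call above, with an empty info dict at the end.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvState(NamedTuple):
    # Hypothetical state: a scalar position and a step counter.
    pos: jnp.ndarray
    time: jnp.ndarray


class ToyEnv:
    """Hypothetical 1D environment mimicking a gymnax-style functional API."""

    def reset(self, key, params):
        # Sample an initial position in [-1, 1); no hidden mutable state is kept.
        pos = jax.random.uniform(key, minval=-1.0, maxval=1.0)
        state = EnvState(pos=pos, time=jnp.array(0))
        return state.pos[None], state

    def step(self, key, state, action, params):
        # Drift by `action` plus Gaussian noise scaled by params["noise"].
        noise = params["noise"] * jax.random.normal(key)
        pos = state.pos + action + noise
        new_state = EnvState(pos=pos, time=state.time + 1)
        reward = -jnp.abs(pos)
        done = new_state.time >= params["max_steps"]
        return pos[None], new_state, reward, done, {}


env = ToyEnv()
params = {"noise": 0.1, "max_steps": 10}  # hypothetical parameters
key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
obs, state = env.reset(key_reset, params)
# The whole transition can be jit-compiled because it is a pure function.
step_fn = jax.jit(env.step)
obs, state, reward, done, _ = step_fn(key_step, state, 0.5, params)
```

The state is an ordinary pytree (a `NamedTuple` here), so `jax.jit`, `jax.vmap` and `jax.lax.scan` can carry it without special handling.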
Episode Rollouts, Vectorization & Acceleration
- Easy composition of JAX primitives (e.g. `jit`, `vmap`, `pmap`):
import jax
import jax.numpy as jnp
# Assumes `env` (from gymnax.make) and a policy `network` (e.g. a Flax module) are defined elsewhere.
def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)
    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = network.apply({"params": policy_params}, obs, rng=rng_net)
        next_o, next_s, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_o.squeeze(), next_s, policy_params, rng]
        return carry, [reward, done]
    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        [jnp.zeros((num_env_steps, 2))],
    )
    # Return masked sum of rewards accumulated by agent in episode
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)
# Jit-Compiled Episode Rollout
jit_rollout = jax.jit(rollout, static_argnums=3)
# Vmap across random keys for Batch Rollout
batch_rollout = jax.vmap(jit_rollout, in_axes=(0, None, None, None))
- Vectorization over different environment parametrizations:
env.step(key_step, state, action, env_params)
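The environment parameters are an explicit argument. This means a single jitted step can be vectorized across parametrizations by giving only the parameters a batch axis in `jax.vmap`. The sketch below uses a hypothetical toy pendulum-like `step` function and a params dict with a `gravity` array. It does not call gymnax. It assumes only that the parameter container is a JAX pytree, and gymnax's own container type may differ between versions.

```python
import jax
import jax.numpy as jnp


def step(key, state, action, params):
    """Hypothetical toy pendulum-like update (not gymnax's implementation)."""
    theta, theta_dot = state
    dt = 0.05
    # Angular acceleration depends on the environment parameter `gravity`.
    theta_dot = theta_dot + (params["gravity"] * jnp.sin(theta) + action) * dt
    theta = theta + theta_dot * dt
    reward = -(theta ** 2)
    return (theta, theta_dot), reward


key = jax.random.PRNGKey(0)  # unused by the toy dynamics, kept for signature parity
state = (jnp.array(0.5), jnp.array(0.0))
action = jnp.array(0.0)

# Three parametrizations stacked along axis 0 of the `gravity` leaf.
params = {"gravity": jnp.array([9.8, 3.7, 1.6])}

# Share key, state and action across the batch; map only over the params pytree.
batched_step = jax.jit(jax.vmap(step, in_axes=(None, None, None, 0)))
(next_theta, next_theta_dot), rewards = batched_step(key, state, action, params)
```

Each output carries a leading axis of size 3, one entry per parametrization. Setting `in_axes=0` for the state and key as well would give a different state per parametrization.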
Implemented Accelerated Environments :earth_africa:
Classic Control OpenAI gym environments.
| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
|---|---|---|---|
| Pendulum-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| CartPole-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| MountainCar-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| MountainCarContinuous-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| Acrobot-v1 | :heavy_check_mark: | :heavy_check_mark: | |
DeepMind's BSuite environments.
| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
|---|---|---|---|
| Catch-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| DeepSea-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| MemoryChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| UmbrellaChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| DiscountingChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| MNISTBandit-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| SimpleBandit-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
K. Young's and T. Tian's MinAtar environments.
| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
|---|---|---|---|
| Asterix-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
| Breakout-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
| Freeway-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
| Seaquest-MinAtar | :x: | :x: | |
| SpaceInvaders-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
Miscellaneous Environments.
| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
|---|---|---|---|
| BernoulliBandit-misc | :heavy_check_mark: | :heavy_check_mark: | |
| GaussianBandit-misc | :heavy_check_mark: | :heavy_check_mark: | |
| FourRooms-misc | :heavy_check_mark: | :heavy_check_mark: | |
Installation :memo:
gymnax can be installed directly from PyPI:
pip install gymnax
Alternatively, you can clone this repository and install gymnax manually:
git clone https://github.com/RobertTLange/gymnax.git
cd gymnax
pip install -e .
Benchmarking Details :train:
Examples :school_satchel:
- :notebook: Environment API - Check out the API and accelerated control environments.
- :notebook: Anakin Agent - Check out DeepMind's Anakin agent with gymnax's `Catch-bsuite` environment.
- :notebook: CMA-ES - CMA-ES in JAX with vectorized population evaluation.
Acknowledgements & Citing gymnax :pencil2:
To cite this repository:
@software{gymnax2021github,
author = {Robert Tjarko Lange},
title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
url = {http://github.com/RobertTLange/gymnax},
version = {0.0.1},
year = {2021},
}
Much of the design of gymnax has been inspired by the classic OpenAI gym RL environment API and DeepMind's JAX ecosystem. I am grateful to the JAX team and Matteo Hessel for their support and motivating words. Finally, a big thank you goes out to the TRC team at Google for granting me TPU quota for benchmarking gymnax.
Notes, Development & Questions :question:
- If you find a bug or want a new feature, feel free to contact me @RobertTLange or create an issue :hugs:
- You can check out the history of release modifications in `CHANGELOG.md` (added, changed, fixed).
- You can find a set of open milestones in `CONTRIBUTING.md`.
Design Notes (control flow, random numbers, episode termination).
- Each step transition requires you to pass a set of environment parameters, `env.step(rng, state, action, env_params)`, which specify the 'hyperparameters' of the environment. You can …
- gymnax automatically resets an episode after termination. This way, trajectory rollouts with a fixed number of steps keep producing transitions.
- If you want to calculate evaluation returns, simply mask the sum using the binary discount vector.
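One way to express automatic resetting without Python control flow is to compute both the stepped state and a fresh reset state, then pick one leaf by leaf with `jnp.where`. A fixed-length `lax.scan` rollout can then run straight through episode boundaries, and a cumulative-sum mask over the `done` flags recovers the return of the first episode. The sketch below uses assumed toy dynamics. `toy_reset`, `toy_step_no_reset`, `auto_reset_step` and `run` are hypothetical names, and this is not necessarily how gymnax implements resetting internally.

```python
import jax
import jax.numpy as jnp


def toy_reset(key):
    # Hypothetical reset: random position in [0, 1), step counter at 0.
    return {"pos": jax.random.uniform(key), "t": jnp.array(0, dtype=jnp.int32)}


def toy_step_no_reset(state, action):
    # Hypothetical dynamics: move by `action`, terminate after 3 steps.
    new_state = {"pos": state["pos"] + action, "t": state["t"] + 1}
    done = new_state["t"] >= 3
    reward = jnp.float32(1.0)
    return new_state, reward, done


def auto_reset_step(key, state, action):
    """Step once; if the episode ended, replace the state with a fresh reset."""
    stepped, reward, done = toy_step_no_reset(state, action)
    fresh = toy_reset(key)
    # Leaf-wise select: both branches are computed, so there is no Python
    # control flow and the function stays usable under jit and lax.scan.
    next_state = jax.tree_util.tree_map(
        lambda r, s: jnp.where(done, r, s), fresh, stepped
    )
    return next_state, reward, done


def run(key, num_steps):
    """Fixed-length rollout that keeps going across episode boundaries."""
    def body(state, step_key):
        state, reward, done = auto_reset_step(step_key, state, 0.1)
        return state, (reward, done)

    key_init, key_steps = jax.random.split(key)
    keys = jax.random.split(key_steps, num_steps)
    _, (rewards, dones) = jax.lax.scan(body, toy_reset(key_init), keys)
    return rewards, dones


rewards, dones = jax.jit(run, static_argnums=1)(jax.random.PRNGKey(0), 7)

# Mask for the first episode: keep steps up to and including the first done.
done_int = dones.astype(jnp.int32)
first_episode_mask = (jnp.cumsum(done_int) - done_int) == 0
first_return = jnp.sum(jnp.where(first_episode_mask, rewards, 0.0))
```

The toy episodes terminate every third step, so `dones` is `[False, False, True, False, False, True, False]`. The mask subtracts `done_int` from the cumulative sum, which keeps the reward of the terminating transition. By contrast, the `jnp.cumsum(dones) < 1` mask in `rollout` above also drops that final reward.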
File details
Details for the file gymnax-0.0.1.tar.gz.
File metadata
- Download URL: gymnax-0.0.1.tar.gz
- Upload date:
- Size: 39.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 88f3814cee2e106c9b4b1f8fc10880d5d715be8a2386677067b1a4009897ce3f |
| MD5 | f6cac498b5f85dfe3c2eaa9fd7832706 |
| BLAKE2b-256 | 3ac8740b3f11f2abc119f69a02b3d53799d018aec399497d0637e452acaecc97 |
File details
Details for the file gymnax-0.0.1-py3-none-any.whl.
File metadata
- Download URL: gymnax-0.0.1-py3-none-any.whl
- Upload date:
- Size: 60.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dd6691436c774ffae2fe65ec3b88b17655bd6ed99d2596b031f738d328f76fe5 |
| MD5 | 6ab05a909d87763bae70412888f45a12 |
| BLAKE2b-256 | 86550539e5434e3ddfb26b3cc8a7f500dc07bb070451b719862d33bf7a945617 |