JAX-compatible version of OpenAI's gym
Gymnax - Classic Gym Environments in JAX
Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? `gymnax` brings the power of `jit` and `vmap` to classic OpenAI gym environments.
Basic `gymnax` API Usage :stew:
- Classic OpenAI gym wrapper including `gymnax.make`, `env.reset`, `env.step`:
import jax
import gymnax
rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)
env, env_params = gymnax.make("Pendulum-v1")
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
Episode Rollouts, Vectorization & Acceleration
- Easy composition of JAX primitives (e.g. `jit`, `vmap`, `pmap`):
import jax.numpy as jnp

# NOTE: `env` comes from gymnax.make(...) above. `network` is not defined in
# this snippet; it is assumed to be a Flax-style module whose apply() maps
# observations to actions.
def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = network.apply({"params": policy_params}, obs, rng=rng_net)
        next_o, next_s, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_o.squeeze(), next_s, policy_params, rng]
        return carry, [reward, done]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        [jnp.zeros((num_env_steps, 2))],
    )
    # Return masked sum of rewards accumulated by agent in episode
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)

# Jit-Compiled Episode Rollout
jit_rollout = jax.jit(rollout, static_argnums=3)

# Vmap across random keys for Batch Rollout
batch_rollout = jax.vmap(jit_rollout, in_axes=(0, None, None, None))
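The rollout above depends on a `gymnax` environment and a policy network. To try the same `lax.scan` + `jit` + `vmap` pattern in isolation, here is a minimal sketch that swaps in a hypothetical one-dimensional toy environment (`toy_reset`, `toy_step`) and a linear policy with a single `gain` parameter. None of these names are part of `gymnax`; they only illustrate the composition.

```python
import jax
import jax.numpy as jnp


def toy_reset(key):
    """Hypothetical 1-D environment: start at a random position in [-1, 1]."""
    return jax.random.uniform(key, (), minval=-1.0, maxval=1.0)


def toy_step(pos, action):
    """Move by a scaled action; reward is the negative distance to the origin."""
    new_pos = pos + 0.1 * action
    reward = -jnp.abs(new_pos)
    done = jnp.abs(new_pos) > 2.0
    return new_pos, reward, done


def toy_rollout(key, gain, num_steps):
    """Sum the rewards collected before the first `done`, looping with lax.scan."""
    pos = toy_reset(key)

    def body(pos, _):
        action = -gain * pos  # linear policy, stands in for a network
        new_pos, reward, done = toy_step(pos, action)
        return new_pos, (reward, done)

    _, (rewards, dones) = jax.lax.scan(body, pos, None, length=num_steps)
    mask = jnp.cumsum(dones) < 1  # same masking idea as in the rollout above
    return jnp.sum(rewards * mask)


# Vectorize over random keys, share the policy gain and the step count,
# then compile the whole batch rollout (the step count must be static).
batch_toy_rollout = jax.jit(
    jax.vmap(toy_rollout, in_axes=(0, None, None)), static_argnums=2
)

keys = jax.random.split(jax.random.PRNGKey(0), 8)
returns = batch_toy_rollout(keys, 1.0, 50)  # one return per key, shape (8,)
```

Here the vmapped function is jitted as a whole, whereas the snippet above vmaps an already jitted function; both orderings compose.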
- Vectorization over different environment parametrizations:
env.step(key_step, state, action, env_params)
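As a rough illustration of what vectorizing over parametrizations can look like, the sketch below uses a hypothetical `toy_step` function whose parameters live in a plain dictionary. The dictionary, its keys (`dt`, `max_speed`) and the toy dynamics are assumptions made for this example and not the `gymnax` parameter format. Whether a given `env_params` object can be batched the same way depends on it being a pytree of arrays, which is assumed rather than verified here.

```python
import jax
import jax.numpy as jnp


def toy_step(state, action, params):
    """Hypothetical point-mass step; `params` holds the environment 'hyperparameters'."""
    velocity = jnp.clip(
        state["velocity"] + params["dt"] * action,
        -params["max_speed"],
        params["max_speed"],
    )
    position = state["position"] + params["dt"] * velocity
    return {"position": position, "velocity": velocity}


state = {"position": jnp.array(0.0), "velocity": jnp.array(0.0)}
action = jnp.array(10.0)

# Three parametrizations that differ only in the speed limit.
batched_params = {
    "dt": jnp.full((3,), 0.1),
    "max_speed": jnp.array([0.5, 1.0, 2.0]),
}

# Map over the leading axis of every leaf in `params`; state and action are shared.
step_over_params = jax.jit(jax.vmap(toy_step, in_axes=(None, None, 0)))
next_state = step_over_params(state, action, batched_params)
```

Each leaf of `next_state` now has a leading axis of size 3, one entry per parametrization.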
Implemented Accelerated Environments :earth_africa:
Classic Control OpenAI gym environments.
Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
---|---|---|---|
Pendulum-v0 | :heavy_check_mark: | :heavy_check_mark: | |
CartPole-v0 | :heavy_check_mark: | :heavy_check_mark: | |
MountainCar-v0 | :heavy_check_mark: | :heavy_check_mark: | |
MountainCarContinuous-v0 | :heavy_check_mark: | :heavy_check_mark: | |
Acrobot-v1 | :heavy_check_mark: | :heavy_check_mark: | |
DeepMind's BSuite environments.
Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
---|---|---|---|
Catch-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
DeepSea-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
MemoryChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
UmbrellaChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
DiscountingChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
MNISTBandit-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
SimpleBandit-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
K. Young's and T. Tian's MinAtar environments.
Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
---|---|---|---|
Asterix-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
Breakout-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
Freeway-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
Seaquest-MinAtar | :x: | :x: | |
SpaceInvaders-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
Miscellaneous Environments.
Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
---|---|---|---|
BernoulliBandit-misc | :heavy_check_mark: | :heavy_check_mark: | |
GaussianBandit-misc | :heavy_check_mark: | :heavy_check_mark: | |
FourRooms-misc | :heavy_check_mark: | :heavy_check_mark: | |
Installation :memo:
`gymnax` can be installed directly from PyPI:
pip install gymnax
Alternatively, you can clone this repository and install `gymnax` manually:
git clone https://github.com/RobertTLange/gymnax.git
cd gymnax
pip install -e .
Benchmarking Details :train:
Examples :school_satchel:
- :notebook: Environment API - Check out the API and accelerated control environments.
- :notebook: Anakin Agent - Check out DeepMind's Anakin agent with `gymnax`'s `Catch-bsuite` environment.
- :notebook: CMA-ES - CMA-ES in JAX with vectorized population evaluation.
Acknowledgements & Citing `gymnax` :pencil2:
To cite this repository:
@software{gymnax2021github,
author = {Robert Tjarko Lange},
title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
url = {http://github.com/RobertTLange/gymnax},
version = {0.0.1},
year = {2021},
}
Much of the design of `gymnax` has been inspired by the classic OpenAI gym RL environment API and DeepMind's JAX ecosystem. I am grateful to the JAX team and Matteo Hessel for their support and motivating words. Finally, a big thank you goes out to the TRC team at Google for granting me TPU quota for benchmarking `gymnax`.
Notes, Development & Questions :question:
- If you find a bug or want a new feature, feel free to contact me @RobertTLange or create an issue :hugs:
- You can check out the history of release modifications in `CHANGELOG.md` (added, changed, fixed).
- You can find a set of open milestones in `CONTRIBUTING.md`.
Design Notes (control flow, random numbers, episode termination).
- Each step transition requires you to pass a set of environment parameters, `env.step(rng, state, action, env_params)`, which specify the 'hyperparameters' of the environment.
- `gymnax` automatically resets an episode after termination. This ensures that trajectory rollouts with a fixed number of steps continue rolling out transitions.
- If you want to calculate evaluation returns, simply mask the sum using the binary discount vector.
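Because episodes are reset automatically, a fixed-length rollout can contain rewards from more than one episode. The sketch below shows one way to compute the return of the first episode only by masking. It assumes that the per-step termination flags (`dones`) play the role of the binary discount vector (discount = 1 - done); `masked_return` and its `include_terminal` flag are hypothetical helpers, not part of `gymnax`.

```python
import jax.numpy as jnp


def masked_return(rewards, dones, include_terminal=True):
    """Sum rewards of the first episode in a fixed-length rollout."""
    dones = dones.astype(jnp.int32)
    seen = jnp.cumsum(dones)  # terminations seen up to and including each step
    if include_terminal:
        # Count only terminations strictly before each step, so the reward
        # received on the terminal step itself is kept.
        seen = seen - dones
    mask = seen < 1
    return jnp.sum(rewards * mask)


rewards = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
dones = jnp.array([False, False, True, False, False])

with_terminal = masked_return(rewards, dones)
without_terminal = masked_return(rewards, dones, include_terminal=False)
```

With `include_terminal=False` the mask is `cumsum(dones) < 1`, which also drops the reward received on the terminal step; which variant is appropriate depends on when the environment emits its final reward.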