JAX-compatible version of OpenAI's gym environments


Classic Gym Environments in JAX 🌍

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar and a collection of classic RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables parallelized rollouts for different configurations (e.g. for meta RL). Finally, we provide training pipelines and checkpoints for both PPO and ES in the gymnax-blines repository. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
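
Because every source of randomness is passed in as an explicit key, reusing a key reproduces the same reset or action sample. The sketch below illustrates this with plain `jax.random` calls only. `sample_action` is a hypothetical stand-in for an action-space sampler, not a gymnax function, and the action bounds are assumed for illustration.

```python
import jax
import jax.numpy as jnp


def sample_action(key, low=-2.0, high=2.0):
    """Hypothetical stand-in for an action-space sampler (not gymnax code)."""
    return jax.random.uniform(key, shape=(1,), minval=low, maxval=high)


rng = jax.random.PRNGKey(0)
rng, key_a, key_b = jax.random.split(rng, 3)

# Reusing a key reproduces the sample exactly; a fresh key gives a new draw.
a1 = sample_action(key_a)
a2 = sample_action(key_a)
b = sample_action(key_b)
```

This is why the usage example splits `rng` into separate keys for reset, action sampling and stepping: each call consumes its own key, and the whole interaction can be replayed from the initial seed.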

Implemented Accelerated Environments 🏎️

Environment Name | Reference | Source | Speed Up vs np | Trained 🤖 (Avg Return)
Pendulum-v1 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100)
CartPole-v1 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100)
MountainCar-v0 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100)
MountainCarContinuous-v0 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100)
Acrobot-v1 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100)
Catch-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
DeepSea-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
MemoryChain-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
UmbrellaChain-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
DiscountingChain-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
MNISTBandit-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
SimpleBandit-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100)
Asterix-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100)
Breakout-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100)
Freeway-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100)
Seaquest-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100)
SpaceInvaders-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100)
FourRooms-misc | Sutton et al. (1999) | Click | ~?x | PPO, ES (R: -100)
MetaMaze-misc | Micconi et al. (2020) | Click | ~?x | PPO, ES (R: -100)
PointRobot-misc | Dorfman et al. (2021) | Click | ~?x | PPO, ES (R: -100)
BernoulliBandit-misc | Wang et al. (2017) | - | - | PPO, ES (R: -100)
GaussianBandit-misc | Lange & Sprekeler (2022) | - | - | PPO, ES (R: -100)

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To run JAX on your accelerators (GPU/TPU), follow the installation instructions in the JAX documentation.
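
After installing, a quick way to check which backend JAX picked up is to list the visible devices. This is a minimal sketch using only standard JAX calls. The output depends on your machine.

```python
import jax

# List the devices JAX can see; a CPU-only install reports CPU devices only.
devices = jax.devices()
backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
print(f"backend={backend}, devices={devices}")
```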

Examples 📖

  • 📓 Environment API - Check out the API and how to train an Anakin agent on Catch-bsuite.
  • 📓 ES with gymnax - Using CMA-ES in JAX with vectorized population evaluations powered by gymnax.
  • 📓 Trained baselines - Check out the trained baseline agents in gymnax-blines.

Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):
# Jit-accelerated step transition
jit_step = jax.jit(env.step)

# vmap across random keys for batch rollouts
vreset_rng = jax.vmap(env.reset, in_axes=(0, None))
vstep_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))

# vmap across environment parameters (e.g. for meta-learning)
vreset_env = jax.vmap(env.reset, in_axes=(None, 0))
vstep_env = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
  • Scan through entire episode rollouts: You can also lax.scan through entire reset, step episode loops for fast compilation:
import jax
import jax.numpy as jnp

# Assumes `env` comes from gymnax.make(...) and `network` is a policy module
# exposing apply(params, obs); neither is defined in this snippet.
def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        # rng_net is unused here; it is available for stochastic policies
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = network.apply(policy_params, obs)
        next_o, next_s, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_o, next_s, policy_params, rng]
        return carry, [reward, done]

    # Scan over episode step loop (no per-step inputs, so only a length is given)
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        xs=None,
        length=num_env_steps,
    )
    # Return masked sum of rewards accumulated by agent in episode
    rewards, dones = scan_out[0], scan_out[1]
    # Keep every step up to and including the first terminal one
    dones = dones.astype(jnp.int32)
    ep_mask = (jnp.cumsum(dones) - dones) < 1
    return jnp.sum(rewards * ep_mask)
  • Super fast acceleration:

  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents and PPO/ES pipelines.
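
The vectorization and scan patterns listed above can be combined into a single jitted batch rollout. The following is a self-contained sketch under stated assumptions: `toy_reset` and `toy_step` are hypothetical functions that only mimic the `reset(key, params)` / `step(key, state, action, params)` signatures shown earlier and are not part of gymnax. The proportional controller stands in for a learned policy, and the parameter names `noise` and `limit` are made up for illustration.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy environment (NOT gymnax): a 1D point that should stay near 0.
def toy_reset(key, params):
    state = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return state, state  # (obs, state)


def toy_step(key, state, action, params):
    noise = params["noise"] * jax.random.normal(key, ())
    next_state = state + action + noise
    reward = -jnp.abs(next_state)
    done = jnp.abs(next_state) > params["limit"]
    return next_state, next_state, reward, done, {}


def rollout(key, params, num_steps):
    """Return the episode return, ignoring rewards after the first `done`."""
    key_reset, key_episode = jax.random.split(key)
    obs, state = toy_reset(key_reset, params)

    def step_fn(carry, _):
        obs, state, key, alive = carry
        key, key_step = jax.random.split(key)
        action = -0.5 * obs  # placeholder policy: proportional controller
        obs, state, reward, done, _ = toy_step(key_step, state, action, params)
        # Count the terminal step's reward, then mask everything after it.
        masked_reward = reward * alive
        alive = alive * (1.0 - done)
        return (obs, state, key, alive), masked_reward

    init = (obs, state, key_episode, jnp.ones(()))
    _, rewards = jax.lax.scan(step_fn, init, xs=None, length=num_steps)
    return rewards.sum()


def rollout_fixed_len(key, params):
    # The scan length must be a static Python int, so it is fixed here.
    return rollout(key, params, num_steps=20)


num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
params = {"noise": 0.1, "limit": 2.0}

# Batch over random keys only; the same parameters are shared by all rollouts.
batch_over_keys = jax.jit(jax.vmap(rollout_fixed_len, in_axes=(0, None)))
returns_keys = batch_over_keys(keys, params)

# Batch over keys and parameter settings at once (e.g. for meta-learning).
batched_params = {
    "noise": jnp.linspace(0.0, 0.5, num_envs),
    "limit": jnp.full((num_envs,), 2.0),
}
batch_over_both = jax.jit(jax.vmap(rollout_fixed_len, in_axes=(0, 0)))
returns_params = batch_over_both(keys, batched_params)
```

Carrying an `alive` flag through the scan keeps the loop length fixed, which is what lets the whole episode compile once, while still discarding rewards collected after termination.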

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.2},
  year = {2022},
}

We acknowledge financial support from the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.
