
JAX-compatible version of OpenAI's gym environments



Classic Gym Environments in JAX 🌍

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar and a collection of classic RL tasks. gymnax gives you explicit functional control over environment settings (random seed or hyperparameters), which enables parallelized rollouts for different configurations (e.g. for meta RL). Finally, we provide training pipelines and checkpoints for both PPO and ES in the gymnax-blines repository. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
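Because reset and step receive the random key and the parameters as explicit arguments, their behaviour is fully determined by those inputs. The minimal sketch below illustrates this with a hypothetical toy environment (`ToyParams`, `toy_reset` and `toy_step` are illustrative names, not part of gymnax) that only mimics the call signatures used above: the same inputs give the same transition, and a new configuration is simply a different `params` value.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyParams(NamedTuple):
    """Hypothetical environment hyperparameters (a NamedTuple is a JAX pytree)."""
    noise_scale: float = 0.5
    goal: float = 1.0


def toy_reset(key, params):
    """Hypothetical reset: sample a random start position (obs == state here)."""
    x0 = params.noise_scale * jax.random.normal(key)
    return x0, x0  # (obs, state)


def toy_step(key, state, action, params):
    """Hypothetical step: move by `action` plus noise; reward = -distance to goal."""
    x = state + action + params.noise_scale * jax.random.normal(key)
    reward = -jnp.abs(x - params.goal)
    return x, x, reward, reward > -0.1, {}  # (obs, state, reward, done, info)


key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
params = ToyParams()
obs, state = toy_reset(key_reset, params)

# Same key, state, action and params -> identical transition (no hidden RNG state).
r1 = toy_step(key_step, state, 1.0, params)[2]
r2 = toy_step(key_step, state, 1.0, params)[2]

# A new configuration is just a different argument, e.g. a noise-free variant.
det_params = params._replace(noise_scale=0.0)
_, det_state = toy_reset(key_reset, det_params)
r_det = toy_step(key_step, det_state, 1.0, det_params)[2]
# r_det == 0.0: start at 0, move by +1, goal at 1
```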

Implemented Accelerated Environments 🏎️

| Environment Name | Reference | Source | Speed-up vs. NumPy | Trained 🤖 (Avg. Return) |
| --- | --- | --- | --- | --- |
| Pendulum-v1 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100) |
| CartPole-v1 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100) |
| MountainCar-v0 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100) |
| MountainCarContinuous-v0 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100) |
| Acrobot-v1 | Brockman et al. (2016) | Click | ~?x | PPO, ES (R: -100) |
| Catch-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| DeepSea-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| MemoryChain-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| UmbrellaChain-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| DiscountingChain-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| MNISTBandit-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| SimpleBandit-bsuite | Osband et al. (2019) | Click | ~?x | PPO, ES (R: -100) |
| Asterix-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100) |
| Breakout-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100) |
| Freeway-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100) |
| Seaquest-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100) |
| SpaceInvaders-MinAtar | Young & Tian (2019) | Click | ~?x | PPO, ES (R: -100) |
| FourRooms-misc | Sutton et al. (1999) | Click | ~?x | PPO, ES (R: -100) |
| MetaMaze-misc | Miconi et al. (2020) | Click | ~?x | PPO, ES (R: -100) |
| PointRobot-misc | Dorfman et al. (2021) | Click | ~?x | PPO, ES (R: -100) |
| BernoulliBandit-misc | Wang et al. (2017) | - | - | PPO, ES (R: -100) |
| GaussianBandit-misc | Lange & Sprekeler (2022) | - | - | PPO, ES (R: -100) |

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To use JAX on your accelerators (GPU/TPU), see the JAX documentation for installation details.
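A quick way to confirm which devices JAX picked up after installation is the snippet below. It uses only standard JAX calls and does not depend on gymnax.

```python
import jax

# Devices JAX can see (a single CPU device on a CPU-only install).
print(jax.devices())

# Name of the default backend, typically "cpu", "gpu" or "tpu".
print(jax.default_backend())
```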

Examples 📖

  • 📓 Environment API - Check out the API and learn how to train an Anakin agent on Catch-bsuite.
  • 📓 ES with gymnax - Using CMA-ES in JAX with vectorized population evaluations powered by gymnax.
  • 📓 Trained baselines - Check out the trained baseline agents in gymnax-blines.

Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):
# Jit-accelerated step transition
jit_step = jax.jit(env.step)

# vmap across random keys for batch rollouts
vreset_rng = jax.vmap(env.reset, in_axes=(0, None))
vstep_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))

# vmap across environment parameters (e.g. for meta-learning)
vreset_env = jax.vmap(env.reset, in_axes=(None, 0))
vstep_env = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
  • Scan through entire episode rollouts: You can also run entire reset/step episode loops with lax.scan for fast compilation:
import jax
import jax.numpy as jnp

# Assumes `env` (from gymnax.make) and a policy `network` (e.g. a Flax module)
# are defined elsewhere. `num_env_steps` must be a static Python int under jit.
def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step = jax.random.split(rng)
        action = network.apply(policy_params, obs)
        next_o, next_s, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_o, next_s, policy_params, rng]
        return carry, [reward, done]

    # Scan over the episode step loop (no per-step inputs, so xs=None)
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        None,
        length=num_env_steps,
    )
    # Return masked sum of rewards: every step from the first `done`
    # onwards (including that terminal step) is zeroed out
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)
  • Super fast acceleration: [Figure: speed comparison]

  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents and PPO/ES pipelines.
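The jit and vmap compositions in the first bullet can be tried without gymnax. This minimal sketch assumes a hypothetical toy environment whose `toy_reset`/`toy_step` functions (illustrative names, not gymnax API) share the `reset(key, params)` and `step(key, state, action, params)` signatures, and nests the two vmap patterns inside one jitted call so that every (parameter setting, random key) pair is evaluated at once.

```python
import jax
import jax.numpy as jnp


def toy_reset(key, params):
    """Hypothetical env: random start point scaled by a scalar `params`."""
    x0 = params * jax.random.uniform(key)
    return x0, x0  # (obs, state)


def toy_step(key, state, action, params):
    """Hypothetical env: move by `action` plus parameter-scaled noise."""
    x = state + action + params * jax.random.normal(key)
    return x, x, -jnp.abs(x), jnp.abs(x) > 10.0, {}  # (obs, state, reward, done, info)


@jax.jit
def batch_reset_step(reset_keys, step_keys, params):
    """Reset and take one step for every (params value, key) pair."""

    def over_keys(p):
        # Inner vmap: batch over random keys, params shared (cf. vreset_rng/vstep_rng).
        _, state = jax.vmap(toy_reset, in_axes=(0, None))(reset_keys, p)
        actions = jnp.zeros(reset_keys.shape[0])
        return jax.vmap(toy_step, in_axes=(0, 0, 0, None))(step_keys, state, actions, p)

    # Outer vmap: batch over environment parameters (cf. vreset_env/vstep_env).
    return jax.vmap(over_keys)(params)


num_keys = 4
key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
reset_keys = jax.random.split(key_reset, num_keys)
step_keys = jax.random.split(key_step, num_keys)
params = jnp.linspace(0.1, 1.0, 3)  # three hypothetical parameter settings

obs, state, reward, done, info = batch_reset_step(reset_keys, step_keys, params)
print(reward.shape)  # (3, 4): one row per params value, one column per key
```

Swapping the toy functions for `env.reset`/`env.step` should follow the same pattern, provided the environment parameters are batched as a pytree along axis 0.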

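The masking logic of the scanned rollout can be checked in isolation. In this sketch a hypothetical toy environment (not part of gymnax) gives a reward of 1 per step, ends an episode after `MAX_STEPS` steps and then resets itself inside `toy_step`, and a constant action stands in for `network.apply`. `num_env_steps` is passed as a static argument because `lax.scan` needs a concrete length.

```python
from functools import partial

import jax
import jax.numpy as jnp

MAX_STEPS = 5  # hypothetical episode length of the toy environment


def toy_reset(key):
    """Hypothetical env: state is (position, time step); obs is the position."""
    state = (jnp.zeros((), jnp.float32), jnp.zeros((), jnp.int32))
    return state[0], state


def toy_step(key, state, action):
    """Reward of 1 per step; the episode ends after MAX_STEPS, then auto-resets."""
    x, t = state[0] + action, state[1] + 1
    done = t >= MAX_STEPS
    reward = jnp.ones((), jnp.float32)
    # Auto-reset: swap in a fresh state once the episode is done.
    obs_re, state_re = toy_reset(key)
    next_state = jax.tree_util.tree_map(
        lambda a, b: jnp.where(done, a, b), state_re, (x, t)
    )
    next_obs = jnp.where(done, obs_re, x)
    return next_obs, next_state, reward, done


@partial(jax.jit, static_argnums=1)
def rollout(rng, num_env_steps):
    """Scan a fixed number of steps and sum rewards before the first done."""
    rng_reset, rng_episode = jax.random.split(rng)
    obs, state = toy_reset(rng_reset)

    def policy_step(carry, _):
        obs, state, rng = carry
        rng, rng_step = jax.random.split(rng)
        action = jnp.float32(0.5)  # stand-in for network.apply(policy_params, obs)
        obs, state, reward, done = toy_step(rng_step, state, action)
        return (obs, state, rng), (reward, done)

    _, (rewards, dones) = jax.lax.scan(
        policy_step, (obs, state, rng_episode), None, length=num_env_steps
    )
    # Same masking rule as above: zero out rewards from the first done onwards.
    ep_mask = jnp.cumsum(dones) < 1
    return jnp.sum(rewards * ep_mask)


print(rollout(jax.random.PRNGKey(0), 8))  # 4.0
```

With 8 steps and episodes of 5 steps, only the first four rewards survive: as in the rollout above, the mask also drops the reward of the step on which `done` first becomes true.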
Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.2},
  year = {2022},
}

We acknowledge financial support from the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.



Download files

Source distribution: gymnax-0.0.2.tar.gz (19.4 kB)

Built distribution: gymnax-0.0.2-py3-none-any.whl (17.6 kB, Python 3)
