
A JAX-compatible version of OpenAI's gym environments



Classic Gym Environments in JAX 🌍

Are you fed up with slow CPU-based RL environment loops? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables parallelized rollouts for different configurations (e.g. for meta-RL). Finally, we provide training pipelines and checkpoints for both PPO and ES in the gymnax-blines repository. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

# Split the initial key into separate keys for reset, action sampling and stepping
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment and its default parameters
env, env_params = gymnax.make("Pendulum-v1")

# Reset, sample a random action and perform one step transition
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
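
Because env_params is passed explicitly, environment hyperparameters are plain data rather than hidden object state, so you can override individual settings before a rollout. A minimal sketch, assuming the parameter struct exposes a max_steps_in_episode field and a dataclass-style .replace() method (both present in recent gymnax releases, but check the field names of your version):

import jax
import gymnax

env, env_params = gymnax.make("Pendulum-v1")

# Assumed field name: shorten episodes to 100 steps via a functional update
short_params = env_params.replace(max_steps_in_episode=100)

key = jax.random.PRNGKey(0)
obs, state = env.reset(key, short_params)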

Implemented Accelerated Environments 🏎️

Environment Name         | Reference                | Source | Speed-up (vs. np) | 🤖 ckpt (Return)
Pendulum-v1              | Brockman et al. (2016)   | Click  | ~?x               | PPO, ES (R: -100)
CartPole-v1              | Brockman et al. (2016)   | Click  | ~?x               | PPO, ES (R: -100)
MountainCar-v0           | Brockman et al. (2016)   | Click  | ~?x               | PPO, ES (R: -100)
MountainCarContinuous-v0 | Brockman et al. (2016)   | Click  | ~?x               | PPO, ES (R: -100)
Acrobot-v1               | Brockman et al. (2016)   | Click  | ~?x               | PPO, ES (R: -100)
Catch-bsuite             | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
DeepSea-bsuite           | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
MemoryChain-bsuite       | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
UmbrellaChain-bsuite     | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
DiscountingChain-bsuite  | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
MNISTBandit-bsuite       | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
SimpleBandit-bsuite      | Osband et al. (2019)     | Click  | ~?x               | PPO, ES (R: -100)
Asterix-MinAtar          | Young & Tian (2019)      | Click  | ~?x               | PPO, ES (R: -100)
Breakout-MinAtar         | Young & Tian (2019)      | Click  | ~?x               | PPO, ES (R: -100)
Freeway-MinAtar          | Young & Tian (2019)      | Click  | ~?x               | PPO, ES (R: -100)
Seaquest-MinAtar         | Young & Tian (2019)      | Click  | ~?x               | PPO, ES (R: -100)
SpaceInvaders-MinAtar    | Young & Tian (2019)      | Click  | ~?x               | PPO, ES (R: -100)
FourRooms-misc           | Sutton et al. (1999)     | Click  | ~?x               | PPO, ES (R: -100)
MetaMaze-misc            | Miconi et al. (2020)     | Click  | ~?x               | PPO, ES (R: -100)
PointRobot-misc          | Dorfman et al. (2021)    | Click  | ~?x               | PPO, ES (R: -100)
BernoulliBandit-misc     | Wang et al. (2017)       | -      | -                 | PPO, ES (R: -100)
GaussianBandit-misc      | Lange & Sprekeler (2022) | -      | -                 | PPO, ES (R: -100)
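
Each environment is registered under the name in the first column. To check what your installed version provides, you can query the registry; this sketch assumes gymnax exposes a registered_envs list alongside make (present in recent releases, but verify against your version):

import gymnax

# Assumed attribute: list of all registered environment IDs
print(gymnax.registered_envs)

# Any listed ID can then be instantiated
env, env_params = gymnax.make("Breakout-MinAtar")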

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To run JAX on your accelerators (GPU/TPU), see the installation instructions in the JAX documentation.
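
For example, a CUDA-enabled build can typically be installed via JAX's cuda extra (command current as of recent JAX releases; consult the JAX documentation for your exact platform and CUDA version):

pip install -U "jax[cuda12]"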


Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):
# Jit-accelerated step transition
jit_step = jax.jit(env.step)

# vmap across random keys for batch rollouts
vreset_rng = jax.vmap(env.reset, in_axes=(0, None))
vstep_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))

# vmap across environment parameters (e.g. for meta-learning)
vreset_env = jax.vmap(env.reset, in_axes=(None, 0))
vstep_env = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
  • Scan through entire episode rollouts: You can also lax.scan through entire reset/step episode loops for fast compilation (a batched usage sketch follows this list):
import jax
import jax.numpy as jnp

def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, _):
        """lax.scan-compatible step transition in a jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        # `network` is assumed to be in scope (e.g. a flax/haiku policy);
        # rng_net is reserved for sampling from stochastic policies.
        action = network.apply(policy_params, obs)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, policy_params, rng]
        return carry, [reward, done]

    # Scan over the episode step loop (no per-step inputs, fixed length)
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        None,
        length=num_env_steps,
    )
    # Return masked sum of rewards accumulated by the agent in the episode
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)
  • Super-fast acceleration: per-environment speed-ups over the np-based originals are listed in the table above.

  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents and PPO/ES pipelines.
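
To make the composition above concrete, here is a minimal sketch that vmaps the scanned rollout from the list above across a batch of seeds and jits the result. The rollout function and a hypothetical policy_params pytree of network weights are assumed to be in scope; batch size and step count are arbitrary:

import jax
import gymnax

env, env_params = gymnax.make("CartPole-v1")

# One PRNG key per episode in the batch
rng = jax.random.PRNGKey(0)
batch_rng = jax.random.split(rng, 2048)

# Share params across the batch; keep the step count static so the
# reshapes inside rollout see a concrete integer
batch_rollout = jax.jit(
    jax.vmap(rollout, in_axes=(0, None, None, None)),
    static_argnums=3,
)
returns = batch_rollout(batch_rng, policy_params, env_params, 200)  # shape (2048,)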

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.2},
  year = {2022},
}

We acknowledge financial support from the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.
