Skip to main content

JAX-compatible version of Open AI's gym

Project description

Gymnax - Classic Gym Environments in JAX

PyversionsPyPI versionColab

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap to classic OpenAI gym environments.

Basic gymnax API Usage :stew:

  • Classic Open AI gym wrapper including gymnax.make, env.reset, env.step:
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

Episode Rollouts, Vectorization & Acceleration

  • Easy composition of JAX primitives (e.g. jit, vmap, pmap):
def rollout(rng_input, policy_params, env_params, num_env_steps):
      """Rollout a jitted gymnax episode with lax.scan."""
      # Reset the environment
      rng_reset, rng_episode = jax.random.split(rng_input)
      obs, state = env.reset(rng_reset, env_params)

      def policy_step(state_input, tmp):
          """lax.scan compatible step transition in jax env."""
          obs, state, policy_params, rng = state_input
          rng, rng_step, rng_net = jax.random.split(rng, 3)
          action = network.apply({"params": policy_params}, obs, rng=rng_net)
          next_o, next_s, reward, done, _ = env.step(
              rng_step, state, action, env_params
          )
          carry = [next_o.squeeze(), next_s, policy_params, rng]
          return carry, [reward, done]

      # Scan over episode step loop
      _, scan_out = jax.lax.scan(
          policy_step,
          [obs, state, policy_params, rng_episode],
          [jnp.zeros((num_env_steps, 2))],
      )
      # Return masked sum of rewards accumulated by agent in episode
      rewards, dones = scan_out[0], scan_out[1]
      rewards = rewards.reshape(num_env_steps, 1)
      ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
      return jnp.sum(rewards * ep_mask)
# Jit-Compiled Episode Rollout
jit_rollout = jax.jit(rollout, static_argnums=3)

# Vmap across random keys for Batch Rollout
batch_rollout = jax.vmap(jit_rollout, in_axes=(0, None, None, None))
  • Vectorization over different environment parametrizations:
env.step(key_step, state, action, env_params)

Implemented Accelerated Environments :earth_africa:

Classic Control OpenAI gym environments.
Environment Name Implemented Tested Single Step Speed Gain (JAX vs. NumPy)
Pendulum-v0 :heavy_check_mark: :heavy_check_mark:
CartPole-v0 :heavy_check_mark: :heavy_check_mark:
MountainCar-v0 :heavy_check_mark: :heavy_check_mark:
MountainCarContinuous-v0 :heavy_check_mark: :heavy_check_mark:
Acrobot-v1 :heavy_check_mark: :heavy_check_mark:
DeepMind's BSuite environments.
Environment Name Implemented Tested Single Step Speed Gain (JAX vs. NumPy)
Catch-bsuite :heavy_check_mark: :heavy_check_mark:
DeepSea-bsuite :heavy_check_mark: :heavy_check_mark:
MemoryChain-bsuite :heavy_check_mark: :heavy_check_mark:
UmbrellaChain-bsuite :heavy_check_mark: :heavy_check_mark:
DiscountingChain-bsuite :heavy_check_mark: :heavy_check_mark:
MNISTBandit-bsuite :heavy_check_mark: :heavy_check_mark:
SimpleBandit-bsuite :heavy_check_mark: :heavy_check_mark:
K. Young's and T. Tian's MinAtar environments.
Environment Name Implemented Tested Single Step Speed Gain (JAX vs. NumPy)
Asterix-MinAtar :heavy_check_mark: :heavy_check_mark:
Breakout-MinAtar :heavy_check_mark: :heavy_check_mark:
Freeway-MinAtar :heavy_check_mark: :heavy_check_mark:
Seaquest-MinAtar :x: :x:
SpaceInvaders-MinAtar :heavy_check_mark: :heavy_check_mark:
Miscellaneous Environments.
Environment Name Implemented Tested Single Step Speed Gain (JAX vs. NumPy)
BernoulliBandit-misc :heavy_check_mark: :heavy_check_mark:
GaussianBandit-misc :heavy_check_mark: :heavy_check_mark:
FourRooms-misc :heavy_check_mark: :heavy_check_mark:

Installation :memo:

gymnax can be directly installed from PyPi.

pip install gymnax

Alternatively, you can clone this repository and 'manually' install the gymnax:

git clone https://github.com/RobertTLange/gymnax.git
cd gymnax
pip install -e .

Benchmarking Details :train:

Examples :school_satchel:

  • :notebook: Environment API - Check out the API and accelerated control environments.
  • :notebook: Anakin Agent - Check out the DeepMind's Anakin agent with gymnax's Catch-bsuite environment.
  • :notebook: CMA-ES - CMA-ES in JAX with vectorized population evaluation.

Acknowledgements & Citing gymnax :pencil2:

To cite this repository:

@software{gymnax2021github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.1},
  year = {2021},
}

Much of the design of gymnax has been inspired by the classic OpenAI gym RL environment API and DeepMind's JAX eco-system. I am grateful to the JAX team and Matteo Hessel for their support and motivating words. Finally, a big thank you goes out to the TRC team at Google for granting me TPU quota for benchmarking gymnax.

Notes, Development & Questions :question:

  • If you find a bug or want a new feature, feel free to contact me @RobertTLange or create an issue :hugs:
  • You can check out the history of release modifications in CHANGELOG.md (added, changed, fixed).
  • You can find a set of open milestones in CONTRIBUTING.md.
Design Notes (control flow, random numbers, episode termination).
  1. Each step transition requires you to pass a set of environment parameters env.step(rng, state, action, env_params), which specify the 'hyperparameters' of the environment. You can
  2. gymnax automatically resets an episode after termination. This way we can ensure that trajectory rollouts with fixed amounts of steps continue rolling out transitions.
  3. If you want calculate evaluation returns simply mask the sum using the binary discount vector.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gymnax-0.0.1.tar.gz (39.2 kB view hashes)

Uploaded Source

Built Distribution

gymnax-0.0.1-py3-none-any.whl (60.1 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page