A JAX-compatible version of OpenAI's gym

Project description

Gymnax - Classic Gym Environments in JAX

Are you fed up with slow, CPU-bound RL environment loops? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap to classic OpenAI gym environments.

Basic gymnax API Usage :stew:

  • Classic OpenAI gym-style API including gymnax.make, env.reset and env.step:
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
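
Since env.reset and env.step are pure functions of their inputs, they can be jit-compiled directly. A minimal sketch, reusing the objects from above:

# Compile the step function once; subsequent calls avoid Python overhead.
jit_step = jax.jit(env.step)
n_obs, n_state, reward, done, _ = jit_step(key_step, state, action, env_params)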

Episode Rollouts, Vectorization & Acceleration

  • Easy composition of JAX primitives (e.g. jit, vmap, pmap):
import jax.numpy as jnp

def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, _):
        """lax.scan-compatible step transition in the jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        # `network` is assumed to be a policy module defined elsewhere
        action = network.apply({"params": policy_params}, obs, rng=rng_net)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, policy_params, rng]
        return carry, [reward, done]

    # Scan over the episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        None,
        length=num_env_steps,
    )
    # Return the masked sum of rewards accumulated by the agent in the episode
    rewards, dones = scan_out
    ep_mask = jnp.cumsum(dones) < 1
    return jnp.sum(rewards * ep_mask)
# Jit-Compiled Episode Rollout
jit_rollout = jax.jit(rollout, static_argnums=3)

# Vmap across random keys for Batch Rollout
batch_rollout = jax.vmap(jit_rollout, in_axes=(0, None, None, None))
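
For example, a usage sketch under the assumption that policy_params come from your own network initialization:

# Roll out 32 episodes of 200 steps in parallel across random seeds.
batch_keys = jax.random.split(rng, 32)
returns = batch_rollout(batch_keys, policy_params, env_params, 200)  # shape: (32,)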
  • Vectorization over different environment parametrizations: since env_params is an explicit argument, you can vmap over it like any other input to
env.step(key_step, state, action, env_params)
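
A minimal sketch, assuming env_params is a pytree whose leaves can carry a leading batch dimension:

import jax.numpy as jnp

# Hypothetical: evaluate one transition under 10 environment parametrizations.
# Keys, state and action are shared here purely for brevity.
batch_params = jax.tree_util.tree_map(lambda x: jnp.stack([x] * 10), env_params)
vmap_step = jax.vmap(env.step, in_axes=(None, None, None, 0))
n_obs, n_state, reward, done, _ = vmap_step(key_step, state, action, batch_params)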

Implemented Accelerated Environments :earth_africa:

Classic control OpenAI gym environments:

| Environment Name | Implemented | Tested | Single-Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| Pendulum-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| CartPole-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| MountainCar-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| MountainCarContinuous-v0 | :heavy_check_mark: | :heavy_check_mark: | |
| Acrobot-v1 | :heavy_check_mark: | :heavy_check_mark: | |

DeepMind's BSuite environments:

| Environment Name | Implemented | Tested | Single-Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| Catch-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| DeepSea-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| MemoryChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| UmbrellaChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| DiscountingChain-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| MNISTBandit-bsuite | :heavy_check_mark: | :heavy_check_mark: | |
| SimpleBandit-bsuite | :heavy_check_mark: | :heavy_check_mark: | |

K. Young's and T. Tian's MinAtar environments:

| Environment Name | Implemented | Tested | Single-Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| Asterix-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
| Breakout-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
| Freeway-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |
| Seaquest-MinAtar | :x: | :x: | |
| SpaceInvaders-MinAtar | :heavy_check_mark: | :heavy_check_mark: | |

Miscellaneous environments:

| Environment Name | Implemented | Tested | Single-Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| BernoulliBandit-misc | :heavy_check_mark: | :heavy_check_mark: | |
| GaussianBandit-misc | :heavy_check_mark: | :heavy_check_mark: | |
| FourRooms-misc | :heavy_check_mark: | :heavy_check_mark: | |

Installation :memo:

gymnax can be directly installed from PyPI.

pip install gymnax

Alternatively, you can clone this repository and 'manually' install gymnax:

git clone https://github.com/RobertTLange/gymnax.git
cd gymnax
pip install -e .
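
Note that gymnax builds on top of jax; depending on your accelerator (CPU/GPU/TPU), you may have to install jax and jaxlib separately, following the official JAX installation instructions.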

Benchmarking Details :train:

Examples :school_satchel:

  • :notebook: Environment API - Check out the API and accelerated control environments.
  • :notebook: Anakin Agent - Check out DeepMind's Anakin agent with gymnax's Catch-bsuite environment.
  • :notebook: CMA-ES - CMA-ES in JAX with vectorized population evaluation.

Acknowledgements & Citing gymnax :pencil2:

To cite this repository:

@software{gymnax2021github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.1},
  year = {2021},
}

Much of the design of gymnax has been inspired by the classic OpenAI gym RL environment API and DeepMind's JAX ecosystem. I am grateful to the JAX team and Matteo Hessel for their support and motivating words. Finally, a big thank you goes out to the TRC team at Google for granting me TPU quota for benchmarking gymnax.

Notes, Development & Questions :question:

  • If you find a bug or want a new feature, feel free to contact me @RobertTLange or create an issue :hugs:
  • You can check out the history of release modifications in CHANGELOG.md (added, changed, fixed).
  • You can find a set of open milestones in CONTRIBUTING.md.
Design Notes (control flow, random numbers, episode termination).
  1. Each step transition requires you to pass a set of environment parameters, env.step(rng, state, action, env_params), which specify the 'hyperparameters' of the environment. You can simply reuse the default env_params returned by gymnax.make, or pass in modified ones, e.g. to vectorize over different parametrizations as shown above.
  2. gymnax automatically resets an episode after termination, which ensures that rollouts with a fixed number of steps keep producing valid transitions.
  3. If you want to calculate evaluation returns, simply mask the summed rewards using the binary done vector, as in the sketch below.
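
A minimal sketch of point 3, mirroring the rollout example above (with rewards and dones collected from a fixed-length scan):

import jax.numpy as jnp

# Zero out every reward after the first done flag, then sum to obtain the
# return of the first episode within a fixed-length rollout.
ep_mask = jnp.cumsum(dones) < 1
episode_return = jnp.sum(rewards * ep_mask)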

