
JAX implementations of OpenAI's gym environments



Reinforcement Learning Environments in JAX 🌍

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar, and a collection of classic/meta RL tasks.

gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated and parallelized rollouts for different configurations (e.g. for meta RL). By executing both the environment and the policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (e.g. using evosax). We provide training pipelines and checkpoints for both PPO and ES in gymnax-blines. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

key = jax.random.key(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
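
The explicit key and parameter arguments mean that nothing random is hidden inside the environment object. The sketch below illustrates this with a hypothetical stand-in environment (the `reset` and `step` functions and the `noise_scale` and `max_steps` parameters are made up for illustration and are not part of gymnax); only the call signatures mirror the ones used above.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in environment (not part of gymnax): a 1D point whose
# reset/step functions mirror the call signatures used above.
def reset(key, params):
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = {"pos": pos, "t": jnp.array(0)}
    return pos, state


def step(key, state, action, params):
    noise = params["noise_scale"] * jax.random.normal(key, ())
    pos = state["pos"] + action + noise
    t = state["t"] + 1
    reward = -jnp.abs(pos)
    done = t >= params["max_steps"]
    return pos, {"pos": pos, "t": t}, reward, done, {}


params = {"noise_scale": 0.1, "max_steps": 10}
key = jax.random.PRNGKey(0)
key_reset, key_step = jax.random.split(key)

# Same key and same params give identical results.
obs_a, state_a = reset(key_reset, params)
obs_b, state_b = reset(key_reset, params)
same_reset = bool(obs_a == obs_b)

# A different key gives a different transition from the same state.
obs_1, _, _, _, _ = step(key_step, state_a, 0.5, params)
obs_2, _, _, _, _ = step(jax.random.split(key_step)[0], state_a, 0.5, params)
different_step = bool(obs_1 != obs_2)
```

Because the state is passed in and returned explicitly, the same transition can be replayed from any stored state simply by reusing the key.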

Implemented Accelerated Environments 🏎️

| Environment Name | Reference | Source | 🤖 Ckpt (Return) | Secs/1M 🦶 A100 (2k 🌎) |
| --- | --- | --- | --- | --- |
| Acrobot-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -80) | 0.07 |
| Pendulum-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -130) | 0.07 |
| CartPole-v1 | Brockman et al. (2016) | Click | PPO, ES (R: 500) | 0.05 |
| MountainCar-v0 | Brockman et al. (2016) | Click | PPO, ES (R: -118) | 0.07 |
| MountainCarContinuous-v0 | Brockman et al. (2016) | Click | PPO, ES (R: 92) | 0.09 |
| Asterix-MinAtar | Young & Tian (2019) | Click | PPO (R: 15) | 0.92 |
| Breakout-MinAtar | Young & Tian (2019) | Click | PPO (R: 28) | 0.19 |
| Freeway-MinAtar | Young & Tian (2019) | Click | PPO (R: 58) | 0.87 |
| SpaceInvaders-MinAtar | Young & Tian (2019) | Click | PPO (R: 131) | 0.33 |
| Catch-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.15 |
| DeepSea-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0) | 0.22 |
| MemoryChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0.1) | 0.13 |
| UmbrellaChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.08 |
| DiscountingChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1.1) | 0.06 |
| MNISTBandit-bsuite | Osband et al. (2019) | Click | - | - |
| SimpleBandit-bsuite | Osband et al. (2019) | Click | - | - |
| FourRooms-misc | Sutton et al. (1999) | Click | PPO, ES (R: 1) | 0.07 |
| MetaMaze-misc | Micconi et al. (2020) | Click | ES (R: 32) | 0.09 |
| PointRobot-misc | Dorfman et al. (2021) | Click | ES (R: 10) | 0.08 |
| BernoulliBandit-misc | Wang et al. (2017) | Click | ES (R: 90) | 0.08 |
| GaussianBandit-misc | Lange & Sprekeler (2022) | Click | ES (R: 0) | 0.07 |
| Reacher-misc | Lenton et al. (2021) | Click | | |
| Swimmer-misc | Lenton et al. (2021) | Click | | |
| Pong-misc | Kirsch (2018) | Click | | |

* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using jit-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the gymnax-blines documentation.
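
For readers who want to take a comparable measurement themselves, here is a minimal timing sketch. It is not the benchmark script behind the table: the `rollout` function is a hypothetical toy environment with a random policy, and the 500 steps per worker is an assumption chosen so that 2000 workers yield 1M transitions. The points it illustrates are excluding compilation from the timed call and waiting for JAX's asynchronous dispatch with `block_until_ready`.

```python
import time

import jax
import jax.numpy as jnp

NUM_ENVS = 2000  # parallel environment workers
NUM_STEPS = 500  # steps per worker, so 2000 * 500 = 1M transitions


def rollout(key):
    """Random-policy rollout of a toy 1D environment (hypothetical stand-in)."""

    def body(carry, _):
        pos, key = carry
        key, key_act = jax.random.split(key)
        action = jax.random.uniform(key_act, (), minval=-1.0, maxval=1.0)
        pos = pos + 0.1 * action
        return (pos, key), -jnp.abs(pos)

    _, rewards = jax.lax.scan(body, (jnp.zeros(()), key), None, length=NUM_STEPS)
    return rewards


batch_rollout = jax.jit(jax.vmap(rollout))
keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)

# Warm-up call: triggers compilation, which should not be timed.
batch_rollout(keys).block_until_ready()

# Timed call: block_until_ready waits for the asynchronous computation.
start = time.perf_counter()
rewards = batch_rollout(keys).block_until_ready()
elapsed = time.perf_counter() - start
secs_per_million = elapsed * 1e6 / (NUM_ENVS * NUM_STEPS)
```

The resulting `secs_per_million` depends heavily on the device and on the environment, so it is only meaningful relative to other runs on the same hardware.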

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

For details on setting up JAX for your accelerator, see the JAX documentation.

Examples 📖

Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):

    # Jit-accelerated step transition
    jit_step = jax.jit(env.step)
    
    # map (vmap/pmap) across random keys for batch rollouts
    reset_key = jax.vmap(env.reset, in_axes=(0, None))
    step_key = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    
    # map (vmap/pmap) across env parameters (e.g. for meta-learning)
    reset_params = jax.vmap(env.reset, in_axes=(None, 0))
    step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
    

    For speed comparisons with standard vectorized NumPy environments check out gymnax-blines.
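
To make the `in_axes` patterns above concrete, here is a minimal sketch on a hypothetical stand-in environment (the `reset` and `step` functions and the `init_range` and `step_size` parameters are invented for illustration and are not gymnax code). It shows what gets a leading batch axis in each case: the keys and states when batching over keys, and every leaf of the parameter pytree when batching over parameters.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in (not a gymnax environment): reset draws a start
# position whose range is controlled by params["init_range"].
def reset(key, params):
    u = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    pos = params["init_range"] * u
    return pos, {"pos": pos}


def step(key, state, action, params):
    pos = state["pos"] + params["step_size"] * action
    reward = -jnp.abs(pos)
    done = jnp.abs(pos) > 2.0
    return pos, {"pos": pos}, reward, done, {}


key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 8)
params = {"init_range": jnp.array(1.0), "step_size": jnp.array(0.1)}

# Batch over keys: 8 environments sharing one parameter set.
obs, state = jax.vmap(reset, in_axes=(0, None))(keys, params)
actions = jnp.ones(8)
batched_step = jax.jit(jax.vmap(step, in_axes=(0, 0, 0, None)))
next_obs, state, reward, done, _ = batched_step(keys, state, actions, params)

# Batch over parameters: one key, 4 settings (leading axis on every leaf).
batched_params = {
    "init_range": jnp.array([0.5, 1.0, 2.0, 4.0]),
    "step_size": jnp.full((4,), 0.1),
}
obs_p, state_p = jax.vmap(reset, in_axes=(None, 0))(key, batched_params)
```

Since the key is shared in the second case, all four environments draw the same random number and differ only through their parameters, which is the kind of controlled comparison that is useful in meta-learning setups.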

  • Scan through entire episode rollouts: You can also lax.scan through entire reset, step episode loops for fast compilation:

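    def rollout(key_input, policy_params, env_params, steps_in_episode):
        """Roll out a gymnax episode with lax.scan.

        Assumes `env` and `model` are defined in the enclosing scope.
        `steps_in_episode` sets the scan length, so it must be a static
        value (e.g. a Python int) if this function is jitted.
        """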
        # Reset the environment
        key_reset, key_episode = jax.random.split(key_input)
        obs, state = env.reset(key_reset, env_params)
    
        def policy_step(state_input, tmp):
            """lax.scan compatible step transition in jax env."""
            obs, state, policy_params, key = state_input
            key, key_step, key_net = jax.random.split(key, 3)
            action = model.apply(policy_params, obs)
            next_obs, next_state, reward, done, _ = env.step(
                key_step, state, action, env_params
            )
            carry = [next_obs, next_state, policy_params, key]
            return carry, [obs, action, reward, next_obs, done]
    
        # Scan over episode step loop
        _, scan_out = jax.lax.scan(
            policy_step,
            [obs, state, policy_params, key_episode],
            (),
            steps_in_episode
        )
        # Unpack and return the per-step transitions collected by the scan
        obs, action, reward, next_obs, done = scan_out
        return obs, action, reward, next_obs, done
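
The rollout returns per-step rewards and done flags for a fixed number of steps. One way to reduce them to a single episode return is sketched below, assuming that steps after the first `done` should not count toward the episode. The `masked_return` helper is a hypothetical name introduced here, not a gymnax function.

```python
import jax
import jax.numpy as jnp


def masked_return(reward, done):
    """Sum rewards up to and including the first step where done is True."""
    done = done.astype(jnp.float32)
    # ended_before[t] is 1 if some earlier step (< t) already terminated.
    ended_before = jnp.clip(jnp.cumsum(done) - done, 0.0, 1.0)
    return jnp.sum(reward * (1.0 - ended_before))


# Toy trajectory: the episode ends at the third step, later steps are ignored.
reward = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
done = jnp.array([False, False, True, False, False])
episode_return = masked_return(reward, done)  # 1 + 2 + 3

# Batched trajectories of shape (num_episodes, num_steps) via vmap.
batch_return = jax.vmap(masked_return)(
    jnp.stack([reward, reward]),
    jnp.stack([done, jnp.zeros(5, dtype=bool)]),
)
```

The second trajectory in the batch never terminates, so all five rewards are summed for it.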
    
  • Built-in visualization tools: You can also smoothly generate GIF animations using the Visualizer tool, which covers all classic_control, MinAtar and most misc environments:

    import jax.numpy as jnp
    from gymnax.visualize import Visualizer
    
    state_seq, reward_seq = [], []
    key, key_reset = jax.random.split(key)
    obs, env_state = env.reset(key_reset, env_params)
    while True:
        state_seq.append(env_state)
        key, key_act, key_step = jax.random.split(key, 3)
        action = env.action_space(env_params).sample(key_act)
        next_obs, next_env_state, reward, done, info = env.step(
            key_step, env_state, action, env_params
        )
        reward_seq.append(reward)
        if done:
            break
        else:
            obs = next_obs
            env_state = next_env_state
    
    cum_rewards = jnp.cumsum(jnp.array(reward_seq))
    vis = Visualizer(env, env_params, state_seq, cum_rewards)
    vis.animate("docs/anim.gif")
    
  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.

  • Simple batch agent evaluation: Work-in-progress.

    import jax.numpy as jnp
    from gymnax.experimental import RolloutWrapper
    
    # Define rollout manager for pendulum env
    manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")
    
    # Simple single episode rollout for policy
    obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(key, policy_params)
    
    # Multiple rollouts for same network (different key, e.g. eval)
    key_batch = jax.random.split(key, 10)
    obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
        key_batch, policy_params
    )
    
    # Multiple rollouts for different networks + key (e.g. for ES)
    # Stack 5 copies of each parameter leaf along a new leading axis
    # (or pass distinct parameters per population member instead).
    batch_params = jax.tree.map(
        lambda x: jnp.broadcast_to(x, (5, *x.shape)), policy_params
    )
    obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
        key_batch, batch_params
    )
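
Stacking parameters for a population rollout amounts to adding a leading axis to every leaf of the parameter pytree. Below is a minimal sketch of one way to do this that works for leaves of any rank. The `stack_copies` helper and the example pytree are hypothetical and only serve to illustrate the expected shapes.

```python
import jax
import jax.numpy as jnp


def stack_copies(params, num_copies):
    """Add a leading population axis holding num_copies copies of each leaf."""
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (num_copies, *x.shape)), params
    )


# Hypothetical parameter pytree, including a rank-3 leaf.
policy_params = {
    "w": jnp.arange(6.0).reshape(2, 3),
    "b": jnp.zeros(3),
    "kernel": jnp.arange(24.0).reshape(2, 3, 4),
}
batch_params = stack_copies(policy_params, 5)
```

For evolutionary strategies the population members would normally differ, in which case each leaf already carries the leading population axis and no stacking is needed.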
    

Resources & Other Great Tools 📝

  • 💻 Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
  • 💻 envpool: Vectorized parallel environment execution engine.
  • 💻 Jumanji: A suite of diverse and challenging RL environments in JAX.
  • 💻 Pgx: JAX-based classic board game environments.

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.
