
JAX-compatible version of OpenAI's gym environments



Reinforcement Learning Environments in JAX 🌍

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of environments, including classic control, bsuite, MinAtar and a collection of classic and meta-RL tasks. gymnax gives you explicit, functional control over environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta-RL). Because both the environment and the policy run on the accelerator, gymnax facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (e.g. using evosax). We provide training pipelines & checkpoints for both PPO & ES in gymnax-blines. Get started here 👉 Colab.
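To illustrate what explicit, functional control of environment settings buys you, here is a minimal sketch in plain JAX. It does not use gymnax: `EnvParams` and `step` are a hypothetical toy 1-D task that only mimics the convention of passing the random key and the settings as explicit arguments. Because the settings are ordinary inputs, `jax.vmap` can evaluate one transition under several configurations at once.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class EnvParams(NamedTuple):
    """Hypothetical settings of a toy 1-D point-mass task (not a gymnax class)."""
    max_speed: jnp.ndarray
    goal: jnp.ndarray


def step(key, pos, action, params):
    """Pure transition: the key and the settings are explicit inputs."""
    noise = 0.01 * jax.random.normal(key)
    velocity = jnp.clip(action, -params.max_speed, params.max_speed)
    new_pos = pos + velocity + noise
    reward = -jnp.abs(new_pos - params.goal)
    return new_pos, reward


# Four configurations stored as one pytree with a leading batch axis.
batch_params = EnvParams(
    max_speed=jnp.array([0.1, 0.2, 0.5, 1.0]),
    goal=jnp.array([1.0, 1.0, -1.0, -1.0]),
)

# Share key, position and action; map only over the settings (axis 0 of each leaf).
step_over_params = jax.jit(jax.vmap(step, in_axes=(None, None, None, 0)))
new_pos, reward = step_over_params(
    jax.random.PRNGKey(0), jnp.float32(0.0), jnp.float32(1.0), batch_params
)
print(new_pos.shape)  # (4,)
```

Since the key is shared, all four configurations see the same noise sample, so differences in `new_pos` come only from `max_speed`.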

Basic gymnax API Usage 🍲

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
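Because every call above takes its random key (and the environment parameters) as an explicit argument, a transition is a pure function of its inputs: the same key, state and action give the same result, and the functions can be wrapped in `jax.jit`. The sketch below illustrates this with a hypothetical toy environment in plain JAX (`toy_reset` and `toy_step` are not gymnax functions); it only mirrors the `reset(key, params)` / `step(key, state, action, params)` argument order and the return order shown above.

```python
import jax
import jax.numpy as jnp


def toy_reset(key, params):
    """Hypothetical reset: sample a start position; here obs and state coincide."""
    state = jax.random.uniform(key, minval=-params["bound"], maxval=params["bound"])
    return state, state  # (obs, state), the same order as env.reset above


def toy_step(key, state, action, params):
    """Hypothetical step with the same argument and return order as env.step."""
    next_state = state + action + params["noise"] * jax.random.normal(key)
    reward = -jnp.abs(next_state)
    done = jnp.abs(next_state) > params["bound"]
    return next_state, next_state, reward, done, {}


params = {"bound": 2.0, "noise": 0.1}
jit_reset = jax.jit(toy_reset)
jit_step = jax.jit(toy_step)

key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
obs, state = jit_reset(key_reset, params)

# Same key and inputs: identical transition, since there is no hidden RNG state.
out_a = jit_step(key_step, state, 0.5, params)
out_b = jit_step(key_step, state, 0.5, params)
# A different key draws different noise.
out_c = jit_step(jax.random.PRNGKey(1), state, 0.5, params)
print(bool(jnp.array_equal(out_a[0], out_b[0])))  # True
```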

Implemented Accelerated Environments 🏎️

Environment Name | Reference | Source | 🤖 Ckpt (Return) | Secs/1M 🦶 on A100 (2k 🌎)
Acrobot-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -80) | 0.07
Pendulum-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -130) | 0.07
CartPole-v1 | Brockman et al. (2016) | Click | PPO, ES (R: 500) | 0.05
MountainCar-v0 | Brockman et al. (2016) | Click | PPO, ES (R: -118) | 0.07
MountainCarContinuous-v0 | Brockman et al. (2016) | Click | PPO, ES (R: 92) | 0.09
Asterix-MinAtar | Young & Tian (2019) | Click | PPO (R: 15) | 0.92
Breakout-MinAtar | Young & Tian (2019) | Click | PPO (R: 28) | 0.19
Freeway-MinAtar | Young & Tian (2019) | Click | PPO (R: 58) | 0.87
SpaceInvaders-MinAtar | Young & Tian (2019) | Click | PPO (R: 131) | 0.33
Catch-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.15
DeepSea-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0) | 0.22
MemoryChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0.1) | 0.13
UmbrellaChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.08
DiscountingChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1.1) | 0.06
MNISTBandit-bsuite | Osband et al. (2019) | Click | - | -
SimpleBandit-bsuite | Osband et al. (2019) | Click | - | -
FourRooms-misc | Sutton et al. (1999) | Click | PPO, ES (R: 1) | 0.07
MetaMaze-misc | Miconi et al. (2020) | Click | ES (R: 32) | 0.09
PointRobot-misc | Dorfman et al. (2021) | Click | ES (R: 10) | 0.08
BernoulliBandit-misc | Wang et al. (2017) | Click | ES (R: 90) | 0.08
GaussianBandit-misc | Lange & Sprekeler (2022) | Click | ES (R: 0) | 0.07

* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU, using jit-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on other hardware (CPU, RTX 2080Ti) and with MLP policies, please refer to the gymnax-blines documentation.
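The setup in this footnote (1M random-policy transitions split across 2000 jit-compiled workers) corresponds to 500 scan steps per worker. The following is a minimal sketch of how such a measurement can be structured in JAX, using a hypothetical random-walk `toy_step` in place of a real gymnax environment; it is not the gymnax-blines benchmark code. It times the second call so that compilation is excluded; whether the table's numbers include compile time is not stated here.

```python
import time

import jax
import jax.numpy as jnp

NUM_ENVS = 2_000                    # parallel environment workers
NUM_STEPS = 1_000_000 // NUM_ENVS   # 500 steps each -> 1M transitions in total


def toy_step(key, state):
    """Hypothetical stand-in for env.step under a random policy (1-D random walk)."""
    key_act, key_noise = jax.random.split(key)
    action = jax.random.choice(key_act, jnp.array([-1.0, 1.0]))
    next_state = state + action + 0.01 * jax.random.normal(key_noise)
    return next_state, -jnp.abs(next_state)


def rollout(key, init_state):
    """Run NUM_STEPS transitions for one worker with lax.scan."""
    keys = jax.random.split(key, NUM_STEPS)
    _, rewards = jax.lax.scan(lambda s, k: toy_step(k, s), init_state, keys)
    return rewards


batched_rollout = jax.jit(jax.vmap(rollout))

keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)
init_states = jnp.zeros(NUM_ENVS)

# First call compiles; block_until_ready waits for the asynchronous result.
batched_rollout(keys, init_states).block_until_ready()

start = time.perf_counter()
rewards = batched_rollout(keys, init_states).block_until_ready()
elapsed = time.perf_counter() - start
print(rewards.size)  # 1000000
print(f"{elapsed:.3f} s per 1M transitions")
```

Replacing `toy_step` with a vmapped environment step and randomly sampled actions would keep the same structure; absolute timings depend on the environment and the hardware.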

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To run JAX on your accelerators (GPU/TPU), follow the installation instructions in the JAX documentation.
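After installing a JAX build for your hardware, a quick sanity check is to ask JAX which backend and devices it sees; if only a CPU device is listed on a GPU machine, the accelerator build is probably not installed correctly. A minimal sketch using only standard JAX calls:

```python
import jax
import jax.numpy as jnp

# Name of the platform JAX uses by default (e.g. "cpu" or "gpu").
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)

# A small computation to confirm the default device actually executes work.
x = jnp.arange(4.0)
total = float(jnp.sum(x * x))
print(total)  # 14.0
```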

Examples 📖

Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):

    # Jit-accelerated step transition
    jit_step = jax.jit(env.step)
    
    # map (vmap/pmap) across random keys for batch rollouts
    reset_rng = jax.vmap(env.reset, in_axes=(0, None))
    step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    
    # map (vmap/pmap) across env parameters (e.g. for meta-learning)
    reset_params = jax.vmap(env.reset, in_axes=(None, 0))
    step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
    

    For speed comparisons with standard vectorized NumPy environments check out gymnax-blines.

  • Scan through entire episode rollouts: You can also use lax.scan to run entire episode loops (a reset followed by repeated steps), which keeps compilation fast:

    def rollout(rng_input, policy_params, env_params, steps_in_episode):
        """Rollout a jitted gymnax episode with lax.scan.

        Assumes `env` is a gymnax environment and `model` a policy network
        (e.g. a Flax module) defined elsewhere.
        """
        # Reset the environment
        rng_reset, rng_episode = jax.random.split(rng_input)
        obs, state = env.reset(rng_reset, env_params)
    
        def policy_step(state_input, _unused):
            """lax.scan compatible step transition in jax env."""
            obs, state, policy_params, rng = state_input
            rng, rng_step = jax.random.split(rng)
            action = model.apply(policy_params, obs)
            next_obs, next_state, reward, done, _ = env.step(
                rng_step, state, action, env_params
            )
            carry = [next_obs, next_state, policy_params, rng]
            return carry, [obs, action, reward, next_obs, done]
    
        # Scan over the episode step loop; `length` fixes the number of steps
        _, scan_out = jax.lax.scan(
            policy_step,
            [obs, state, policy_params, rng_episode],
            None,
            length=steps_in_episode,
        )
        # Return the collected transitions, stacked along the time axis
        obs, action, reward, next_obs, done = scan_out
        return obs, action, reward, next_obs, done
    
    # steps_in_episode sets the scan length, so it must be static under jit
    jit_rollout = jax.jit(rollout, static_argnums=3)
    
  • Built-in visualization tools: You can also generate GIF animations with the Visualizer tool, which covers all classic_control, MinAtar and most misc environments:

    import jax
    import jax.numpy as jnp
    from gymnax.visualize import Visualizer
    
    # Assumes `rng`, `env` and `env_params` from the basic usage example
    state_seq, reward_seq = [], []
    rng, rng_reset = jax.random.split(rng)
    obs, env_state = env.reset(rng_reset, env_params)
    while True:
        state_seq.append(env_state)
        rng, rng_act, rng_step = jax.random.split(rng, 3)
        action = env.action_space(env_params).sample(rng_act)
        next_obs, next_env_state, reward, done, info = env.step(
            rng_step, env_state, action, env_params
        )
        reward_seq.append(reward)
        if done:
            break
        obs = next_obs
        env_state = next_env_state
    
    # Cumulative rewards over the episode, passed to the Visualizer
    cum_rewards = jnp.cumsum(jnp.array(reward_seq))
    vis = Visualizer(env, env_params, state_seq, cum_rewards)
    vis.animate("docs/anim.gif")
    
  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.

  • Simple batch agent evaluation: Work-in-progress.

    import jax
    import jax.numpy as jnp
    from gymnax.experimental import RolloutWrapper
    
    # Define rollout manager for pendulum env (`model` is your policy network)
    manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")
    
    # Simple single episode rollout for policy
    obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)
    
    # Multiple rollouts for same network (different rng, e.g. eval)
    rng_batch = jax.random.split(rng, 10)
    obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
        rng_batch, policy_params
    )
    
    # Multiple rollouts for different networks + rng (e.g. for ES).
    # Here the same parameters are stacked 5 times; in practice, stack the
    # parameters of 5 different networks along a new leading axis.
    batch_params = jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, (5, *x.shape)), policy_params
    )
    obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
        rng_batch, batch_params
    )
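Scan-based rollouts like the `lax.scan` example above run for a fixed number of steps, so their outputs can contain transitions from after an episode has ended. Assuming the environment resets automatically when `done` is set (gymnax's `step` does this), one way to recover the return of the first episode is to mask every step after the first `done`. A minimal sketch; `masked_return` is a hypothetical helper, and the reward/done arrays are made-up inputs standing in for the scan outputs:

```python
import jax
import jax.numpy as jnp


def masked_return(rewards, dones):
    """Sum rewards up to and including the first done; later steps belong to a
    new episode after the automatic reset and are masked out."""
    dones = dones.astype(jnp.float32)
    # True for every step before (or at) the first done flag.
    valid = (jnp.cumsum(dones) - dones) == 0
    return jnp.sum(rewards * valid)


# Outputs of two fixed-length rollouts, stacked along axis 0 (e.g. via vmap over keys).
rewards = jnp.array([[1.0, 1.0, 1.0, 1.0],
                     [1.0, 1.0, 1.0, 1.0]])
dones = jnp.array([[False, True, False, False],
                   [False, False, False, False]])
returns = jax.vmap(masked_return)(rewards, dones)
print(returns)  # [2. 4.]
```

An episode that never terminates within the scan (second row) keeps all of its rewards.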
    

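Preparing parameters for a population rollout amounts to adding a leading population axis to every leaf of the parameter pytree. A minimal sketch with a hypothetical linear policy (`policy_params` and `forward` are illustrative, not gymnax code): it broadcasts each leaf to a population axis, perturbs each member with Gaussian noise so the networks actually differ (standing in for an ES strategy's samples), and maps the forward pass over the population.

```python
import jax
import jax.numpy as jnp

POP_SIZE = 5

# Hypothetical single-network parameters (a tiny linear policy) as a pytree.
policy_params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}

# Give every leaf a leading population axis of size POP_SIZE (works for any rank).
batch_params = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (POP_SIZE, *x.shape)), policy_params
)

# Perturb each member differently (illustrative noise, not a full ES update).
leaves, treedef = jax.tree_util.tree_flatten(batch_params)
keys = jax.random.split(jax.random.PRNGKey(0), len(leaves))
noisy = [x + 0.1 * jax.random.normal(k, x.shape) for x, k in zip(leaves, keys)]
batch_params = jax.tree_util.tree_unflatten(treedef, noisy)


def forward(params, obs):
    """Linear policy: obs (3,) -> action (1,)."""
    return obs @ params["w"] + params["b"]


obs = jnp.ones((3,))
# Map over the population axis of the parameters, sharing the observation.
actions = jax.vmap(forward, in_axes=(0, None))(batch_params, obs)
print(actions.shape)  # (5, 1)
```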
Resources & Other Great Tools 📝

  • 💻 Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
  • 💻 envpool: Vectorized parallel environment execution engine.

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.



Download files


Source Distribution

gymnax-0.0.5.tar.gz (52.8 kB)

Uploaded Source

Built Distribution

gymnax-0.0.5-py3-none-any.whl (78.6 kB)

Uploaded Python 3

File details

Details for the file gymnax-0.0.5.tar.gz.

File metadata

  • Download URL: gymnax-0.0.5.tar.gz
  • Size: 52.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.6

File hashes

Hashes for gymnax-0.0.5.tar.gz
Algorithm Hash digest
SHA256 6a8b3aa9f9490c1f79d1f5c00c44909a4a3a7f54bc691ad13056b14371266506
MD5 7a06eabb6b88e47fe0bcdb041912b5b3
BLAKE2b-256 87e98f01f8e6364a006202a01595e8711d9ef8e93300b8cf71909accbc8d4961


File details

Details for the file gymnax-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: gymnax-0.0.5-py3-none-any.whl
  • Size: 78.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.6

File hashes

Hashes for gymnax-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 e11ed79aeb5870a919fad463ef3e77e1348371b6efb56df2120f0b26729bbdbc
MD5 8e08072cc0a601fbd8d0b736e9527b14
BLAKE2b-256 5dc12c987337f6804ee7b522e53f52dfefbc7edbdbc90f9d7ce2a5c7622557a0

