(Stable) JAX implementations of OpenAI's gym environments


Reinforcement Learning Environments in JAX 🌍

This is a fork of gymnax, created because the official version is no longer maintained. It works with the latest version of JAX and includes some additional bug fixes.

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of different environments including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. evosax). We provide training & checkpoints for both PPO & ES in gymnax-blines. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

key = jax.random.key(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

Implemented Accelerated Environments 🏎️

Environment Name | Reference | Source | 🤖 Ckpt (Return) | Secs/1M 🦶 A100 (2k 🌎)
Acrobot-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -80) | 0.07
Pendulum-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -130) | 0.07
CartPole-v1 | Brockman et al. (2016) | Click | PPO, ES (R: 500) | 0.05
MountainCar-v0 | Brockman et al. (2016) | Click | PPO, ES (R: -118) | 0.07
MountainCarContinuous-v0 | Brockman et al. (2016) | Click | PPO, ES (R: 92) | 0.09
Asterix-MinAtar | Young & Tian (2019) | Click | PPO (R: 15) | 0.92
Breakout-MinAtar | Young & Tian (2019) | Click | PPO (R: 28) | 0.19
Freeway-MinAtar | Young & Tian (2019) | Click | PPO (R: 58) | 0.87
SpaceInvaders-MinAtar | Young & Tian (2019) | Click | PPO (R: 131) | 0.33
Catch-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.15
DeepSea-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0) | 0.22
MemoryChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0.1) | 0.13
UmbrellaChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.08
DiscountingChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1.1) | 0.06
MNISTBandit-bsuite | Osband et al. (2019) | Click | - | -
SimpleBandit-bsuite | Osband et al. (2019) | Click | - | -
FourRooms-misc | Sutton et al. (1999) | Click | PPO, ES (R: 1) | 0.07
MetaMaze-misc | Micconi et al. (2020) | Click | ES (R: 32) | 0.09
PointRobot-misc | Dorfman et al. (2021) | Click | ES (R: 10) | 0.08
BernoulliBandit-misc | Wang et al. (2017) | Click | ES (R: 90) | 0.08
GaussianBandit-misc | Lange & Sprekeler (2022) | Click | ES (R: 0) | 0.07
Reacher-misc | Lenton et al. (2021) | Click | |
Swimmer-misc | Lenton et al. (2021) | Click | |
Pong-misc | Kirsch (2018) | Click | |

* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using jit-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the gymnax-blines documentation.
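
For readers who want to reproduce this kind of measurement, the following is a minimal timing sketch, not the benchmark script behind the table. It assumes a hypothetical toy_step function in place of a real gymnax environment and much smaller NUM_ENVS and NUM_STEPS values. The pattern it shows is the part that matters: run one warm-up call so compilation is excluded, then call block_until_ready() so that asynchronous dispatch does not hide the real runtime.

```python
import time

import jax
import jax.numpy as jnp

NUM_ENVS = 64    # illustrative; the table above uses 2000 workers
NUM_STEPS = 100  # illustrative episode length


def toy_step(state, key):
    """Hypothetical stand-in for an environment transition (not gymnax)."""
    action = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return 0.99 * state + action


def episode(key):
    """Run NUM_STEPS random-policy transitions with lax.scan."""

    def body(carry, _):
        state, key = carry
        key, key_step = jax.random.split(key)
        return (toy_step(state, key_step), key), None

    (final_state, _), _ = jax.lax.scan(
        body, (jnp.float32(0.0), key), None, length=NUM_STEPS
    )
    return final_state


batched_episode = jax.jit(jax.vmap(episode))
keys = jax.random.split(jax.random.key(0), NUM_ENVS)

# Warm-up call: triggers compilation so it is not included in the timing.
batched_episode(keys).block_until_ready()

start = time.perf_counter()
out = batched_episode(keys).block_until_ready()
elapsed = time.perf_counter() - start

transitions = NUM_ENVS * NUM_STEPS
print(f"{elapsed / transitions * 1e6:.4f} s per 1M transitions (toy step)")
```

The printed number only describes the toy step on your own hardware and is not comparable to the figures in the table.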

Installation ⏳

The latest gymnax release can directly be installed from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To use JAX on your accelerators, follow the installation instructions in the JAX documentation.

Examples 📖

Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):

    # Jit-accelerated step transition
    jit_step = jax.jit(env.step)
    
    # map (vmap/pmap) across random keys for batch rollouts
    reset_key = jax.vmap(env.reset, in_axes=(0, None))
    step_key = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    
    # map (vmap/pmap) across env parameters (e.g. for meta-learning)
    reset_params = jax.vmap(env.reset, in_axes=(None, 0))
    step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
    

    For speed comparisons with standard vectorized NumPy environments check out gymnax-blines.
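
The in_axes patterns above can be tried without gymnax installed. The sketch below uses a hypothetical toy environment (toy_reset and toy_step are illustrative stand-ins, not gymnax functions) that follows the same (key, state, action, params) calling convention. It maps the functions first across random keys with shared parameters, then across parameters with a shared key.

```python
import jax
import jax.numpy as jnp


def toy_reset(key, params):
    """Hypothetical reset: sample a 1-D starting position."""
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return pos, pos  # (obs, state)


def toy_step(key, state, action, params):
    """Hypothetical step: move by action * speed plus a little noise."""
    noise = 0.01 * jax.random.normal(key, ())
    new_state = state + action * params["speed"] + noise
    reward = -jnp.abs(new_state)
    done = jnp.abs(new_state) > 2.0
    return new_state, new_state, reward, done, {}


params = {"speed": jnp.float32(0.1)}
keys = jax.random.split(jax.random.key(0), 8)
actions = jnp.ones(8)

# Batch over keys; the parameters (in_axes=None) are shared by all 8 copies.
batch_reset = jax.vmap(toy_reset, in_axes=(0, None))
batch_step = jax.jit(jax.vmap(toy_step, in_axes=(0, 0, 0, None)))
obs, state = batch_reset(keys, params)
next_obs, next_state, reward, done, _ = batch_step(keys, state, actions, params)

# Batch over parameters; the key (in_axes=None) is shared, so every copy
# sees the same noise but a different speed.
speeds = {"speed": jnp.linspace(0.1, 0.8, 8)}
param_step = jax.vmap(toy_step, in_axes=(None, 0, 0, 0))
p_obs, p_state, p_reward, p_done, _ = param_step(keys[0], state, actions, speeds)
```

An in_axes entry of None broadcasts that argument to every batch member, while 0 maps over the leading axis of every leaf in the argument, which is why the parameter dictionary needs a leading batch dimension in the second case.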

  • Scan through entire episode rollouts: You can also lax.scan through entire reset, step episode loops for fast compilation:

    def rollout(key_input, policy_params, env_params, steps_in_episode):
        """Rollout a jitted gymnax episode with lax.scan."""
        # Reset the environment
        key_reset, key_episode = jax.random.split(key_input)
        obs, state = env.reset(key_reset, env_params)
    
        def policy_step(state_input, tmp):
            """lax.scan compatible step transition in jax env."""
            obs, state, policy_params, key = state_input
            key, key_step, key_net = jax.random.split(key, 3)
            # `model` is the policy network and is assumed to be defined elsewhere.
            action = model.apply(policy_params, obs)
            next_obs, next_state, reward, done, _ = env.step(
                key_step, state, action, env_params
            )
            carry = [next_obs, next_state, policy_params, key]
            return carry, [obs, action, reward, next_obs, done]
    
        # Scan over episode step loop
        _, scan_out = jax.lax.scan(
            policy_step,
            [obs, state, policy_params, key_episode],
            (),
            steps_in_episode
        )
        # Return the per-step trajectory collected by the scan
        obs, action, reward, next_obs, done = scan_out
        return obs, action, reward, next_obs, done
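
Because lax.scan always runs for a fixed number of steps, the returned reward and done arrays can extend past the end of the first episode. One way to turn them into a single episode return is to mask out every step after the first termination. The sketch below illustrates that idea on a hypothetical toy_step that terminates every third step; masked_return and toy_step are illustrative names and not part of gymnax.

```python
import jax
import jax.numpy as jnp


def toy_step(state, action):
    """Hypothetical transition: count steps and terminate on every third one."""
    next_state = state + 1
    reward = jnp.float32(1.0)
    done = next_state % 3 == 0
    return next_state, reward, done


def rollout(num_steps):
    """Collect per-step rewards and done flags with lax.scan."""

    def scan_step(state, _):
        next_state, reward, done = toy_step(state, 0)
        return next_state, (reward, done)

    _, (reward, done) = jax.lax.scan(
        scan_step, jnp.int32(0), None, length=num_steps
    )
    return reward, done


def masked_return(reward, done):
    """Sum rewards up to and including the first step where done is True."""
    done = done.astype(jnp.float32)
    # Number of terminations strictly before each step.
    dones_before = jnp.cumsum(done) - done
    mask = (dones_before == 0).astype(reward.dtype)
    return jnp.sum(reward * mask)


# num_steps fixes the scan length, so it is marked static for jit.
reward, done = jax.jit(rollout, static_argnums=0)(7)
print(masked_return(reward, done))  # 3.0
```

The same masked_return can be wrapped in jax.vmap to score a batch of rollouts at once.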
    
  • Built-in visualization tools: You can also smoothly generate GIF animations using the Visualizer tool, which covers all classic_control, MinAtar and most misc environments:

    import jax.numpy as jnp
    from gymnax.visualize import Visualizer
    
    state_seq, reward_seq = [], []
    key, key_reset = jax.random.split(key)
    obs, env_state = env.reset(key_reset, env_params)
    while True:
        state_seq.append(env_state)
        key, key_act, key_step = jax.random.split(key, 3)
        action = env.action_space(env_params).sample(key_act)
        next_obs, next_env_state, reward, done, info = env.step(
            key_step, env_state, action, env_params
        )
        reward_seq.append(reward)
        if done:
            break
        else:
            obs = next_obs
            env_state = next_env_state
    
    cum_rewards = jnp.cumsum(jnp.array(reward_seq))
    vis = Visualizer(env, env_params, state_seq, cum_rewards)
    vis.animate("docs/anim.gif")
    
  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.

  • Simple batch agent evaluation: Work-in-progress.

    from gymnax.experimental import RolloutWrapper
    
    # Define rollout manager for pendulum env
    manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")
    
    # Simple single episode rollout for policy
    obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(key, policy_params)
    
    # Multiple rollouts for same network (different key, e.g. eval)
    key_batch = jax.random.split(key, 10)
    obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
        key_batch, policy_params
    )
    
    # Multiple rollouts for different networks + key (e.g. for ES)
    batch_params = jax.tree.map(  # Stack parameters or use different
        lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape), policy_params
    )
    obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
        key_batch, batch_params
    )
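
The parameter stacking in the last step can be checked in isolation. The sketch below is an illustration under stated assumptions: policy is a hypothetical linear policy rather than anything shipped with gymnax, and jnp.broadcast_to is used in place of the tile/reshape combination above to add the leading population axis. It then evaluates every population member on the same observation with jax.vmap.

```python
import jax
import jax.numpy as jnp


def policy(params, obs):
    """Hypothetical linear policy used only for illustration."""
    return jnp.tanh(obs @ params["w"] + params["b"])


policy_params = {"w": jnp.full((3, 2), 0.5), "b": jnp.zeros((2,))}
pop_size = 5

# Give every leaf of the parameter tree a leading axis of size pop_size.
batch_params = jax.tree.map(
    lambda x: jnp.broadcast_to(x, (pop_size, *x.shape)), policy_params
)

# Perturb each member differently, as an evolution strategy would.
key = jax.random.key(0)
noise = 0.1 * jax.random.normal(key, batch_params["w"].shape)
batch_params = {**batch_params, "w": batch_params["w"] + noise}

# Map over the population axis of the parameters; the observation is shared.
obs = jnp.ones((3,))
actions = jax.vmap(policy, in_axes=(0, None))(batch_params, obs)
```

Here actions has one row per population member, which mirrors the extra leading axis that a population-level rollout adds to its outputs.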
    

Resources & Other Great Tools 📝

  • 💻 Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
  • 💻 envpool: Vectorized parallel environment execution engine.
  • 💻 Jumanji: A suite of diverse and challenging RL environments in JAX.
  • 💻 Pgx: JAX-based classic board game environments.

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can install the dev and test dependencies with python -m pip install -e .[dev,test] in the root of the repository, and then run the test suite via python -m pytest -vv --all. When running the test suite, it is strongly encouraged, but not required, to treat warnings as errors with python -m pytest -vv -W error --tb=short --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.

