POMDP Arcade Environments on the GPU

POPGym Arcade - GPU-Accelerated MDPs and POMDPs

POPGym Arcade is a GPU-accelerated Atari-style benchmark and suite of analysis tools for reinforcement learning.

Tasks

POPGym Arcade contains pixel-based tasks in the style of the Arcade Learning Environment.

Each environment provides:

  • Three difficulty settings
  • A single observation and action space shared across all environments
  • Fully observable (MDP) and partially observable (POMDP) configurations
  • Fast, easy GPU vectorization via jax
  • Returns standardized to [0, 1] or [-1, 1]

Throughput

Expect roughly 10M frames per second on an RTX 4090. Most of our policies converge in under 60 minutes of training.
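
This throughput comes from batching many environments with jax.vmap and compiling the batched step with jax.jit. As a rough illustration of the pattern (the step function below is a toy stand-in, not a POPGym Arcade environment):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for an environment step: a cheap state update,
# used only to illustrate the vmap + jit batching pattern.
def toy_step(state, action):
    next_state = state + action
    reward = jnp.sin(state)
    return next_state, reward

# Vectorize over a batch of environments, then compile once.
batched_step = jax.jit(jax.vmap(toy_step))

n_envs = 1024
states = jnp.zeros(n_envs)
actions = jnp.ones(n_envs)

# One call now advances all 1024 environments in parallel.
next_states, rewards = batched_step(states, actions)
```

The same shape of code appears with the real environments in the stepping example below, where env.step is vmapped over a batch of PRNG keys and states.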

Baselines

We implement a simple on-policy Q-learning algorithm, PQN. We also implement various memory models:

  • Log-complexity RNNs
  • Classical RNNs
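
For intuition, the temporal-difference update at the heart of Q-learning methods like PQN can be sketched in jax as below. This is only the core loss on a batch of transitions; the network, optimizer, parallel environments, and PQN-specific details are omitted, so treat it as a sketch rather than the baseline implementation:

```python
import jax
import jax.numpy as jnp

def q_learning_loss(q_values, next_q_values, actions, rewards, dones, gamma=0.99):
    # Q(s, a) for the actions actually taken
    chosen_q = jnp.take_along_axis(q_values, actions[:, None], axis=1).squeeze(-1)
    # Bootstrapped target: r + gamma * max_a' Q(s', a'), cut off at episode end
    target = rewards + gamma * (1.0 - dones) * next_q_values.max(axis=1)
    # Stop gradients through the target, as is standard in Q-learning
    return jnp.mean((chosen_q - jax.lax.stop_gradient(target)) ** 2)

# A batch of two transitions with two discrete actions (made-up numbers)
q = jnp.array([[1.0, 2.0], [0.5, 0.0]])
next_q = jnp.array([[0.0, 1.0], [2.0, 0.0]])
actions = jnp.array([1, 0])
rewards = jnp.array([1.0, 0.0])
dones = jnp.array([0.0, 1.0])  # second transition ends the episode

loss = q_learning_loss(q, next_q, actions, rewards, dones)
```

In the real baseline this loss would be differentiated with jax.grad and applied to the Q-network's parameters; the memory models above replace the feedforward Q-network with a recurrent one.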

Getting Started

To install the environments, run

pip install popgym-arcade

If you plan to use our training scripts, also install the baselines:

pip install 'popgym-arcade[baselines]'

Note: If you do not already have jax installed, we install CPU jax by default. For GPU acceleration, run pip install 'jax[cuda12]' after installing popgym-arcade.

Human Play

The play script lets you play the games yourself using the arrow keys and spacebar.

popgym-arcade-play NoisyCartPoleEasy        # play MDP 256 pixel version
popgym-arcade-play BattleShipEasy -p -o 128 # play POMDP 128 pixel version

Creating and Stepping Environments

Our tasks follow the gymnax environment API and are compatible with wrappers and code written for gymnax. The following example demonstrates how to integrate POPGym Arcade into your code.

import popgym_arcade
import jax

# Create POMDP env variant
env, env_params = popgym_arcade.make("BattleShipEasy", partial_obs=True)

# Let's vectorize and compile the env
# Note when you are training a policy, it is better to compile your policy_update rather than the env_step
reset = jax.jit(jax.vmap(env.reset, in_axes=(0, None)))
step = jax.jit(jax.vmap(env.step, in_axes=(0, 0, 0, None)))
    
# Initialize four vectorized environments
n_envs = 4
# Initialize PRNG keys
key = jax.random.key(0)
reset_keys = jax.random.split(key, n_envs)
    
# Reset environments
observation, env_state = reset(reset_keys, env_params)

# Step the POMDP
for t in range(10):
    # Propagate some randomness
    action_key, step_key = jax.random.split(jax.random.key(t))
    action_keys = jax.random.split(action_key, n_envs)
    step_keys = jax.random.split(step_key, n_envs)
    # Pick actions at random
    actions = jax.vmap(env.action_space(env_params).sample)(action_keys)
    # Step the env to the next state
    # No need to reset after initial reset, gymnax automatically resets when done
    observation, env_state, reward, done, info = step(step_keys, env_state, actions, env_params)

# POMDP and MDP variants share states
# We can plug the POMDP states into the MDP and continue playing
mdp, mdp_params = popgym_arcade.make("BattleShipEasy", partial_obs=False)
mdp_reset = jax.jit(jax.vmap(mdp.reset, in_axes=(0, None)))
mdp_step = jax.jit(jax.vmap(mdp.step, in_axes=(0, 0, 0, None)))

action_keys = jax.random.split(jax.random.key(t + 1), n_envs)
step_keys = jax.random.split(jax.random.key(t + 2), n_envs)
# Sample fresh actions for the MDP step
actions = jax.vmap(mdp.action_space(mdp_params).sample)(action_keys)
# The returned observation is now the fully observable (Markov) one
markov_obs, env_state, reward, done, info = mdp_step(step_keys, env_state, actions, mdp_params)
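
The Python for-loop above is fine for a demo, but for long rollouts the idiomatic jax pattern is to fold the whole rollout into jax.lax.scan so it compiles as a single program. The sketch below uses a toy stand-in step function rather than the real env.step signature, so treat it as the pattern, not library code:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for env.step (assumption: not the real gymnax signature).
def toy_step(state, key):
    action = jax.random.bernoulli(key).astype(jnp.float32)
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward

def rollout(key, init_state, n_steps=10):
    keys = jax.random.split(key, n_steps)

    # scan carries the state forward and stacks the per-step rewards
    def body(state, k):
        next_state, reward = toy_step(state, k)
        return next_state, reward

    final_state, rewards = jax.lax.scan(body, init_state, keys)
    return final_state, rewards

rollout_jit = jax.jit(rollout, static_argnames="n_steps")
final_state, rewards = rollout_jit(jax.random.key(0), jnp.float32(0.0), n_steps=10)
```

With the real environments, the carry would hold env_state and the scanned inputs would be per-step PRNG keys, exactly as in the loop above.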

Memory Introspection Tools

We implement visualization tools to probe which pixels persist in agent memory, and their impact on Q-value predictions. Try the code below or our example script to understand how your agent uses memory.

from popgym_arcade.baselines.model.builder import QNetworkRNN
from popgym_arcade.baselines.utils import get_saliency_maps, vis_fn
import equinox as eqx
import jax

config = {
    # Env string
    "ENV_NAME": "NavigatorEasy",
    # Whether to use full or partial observability
    "PARTIAL": True,
    # Memory model type (see models directory)
    "MEMORY_TYPE": "lru",
    # Evaluation episode seed
    "SEED": 0,
    # Observation size in pixels (128 or 256)
    "OBS_SIZE": 128,
}

# Initialize the random key
rng = jax.random.PRNGKey(config["SEED"])

# Initialize the model
network = QNetworkRNN(rng, rnn_type=config["MEMORY_TYPE"], obs_size=config["OBS_SIZE"])
# Load the model
model = eqx.tree_deserialise_leaves("PATH_TO_YOUR_MODEL_WEIGHTS.pkl", network)
# Compute the saliency maps
grads, obs_seq, grad_accumulator = get_saliency_maps(rng, model, config)
# Visualize the saliency maps
# If you have latex installed, set use_latex=True
vis_fn(grads, obs_seq, config, use_latex=False)
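
Under the hood, a saliency map of this kind is the gradient of a scalar model output with respect to the input pixels: pixels with large gradient magnitude are the ones the output depends on. A minimal self-contained sketch (the linear "model" below is a placeholder, not QNetworkRNN):

```python
import jax
import jax.numpy as jnp

# Placeholder scalar "model" over an image (assumption, not QNetworkRNN):
# a weighted sum whose weights grow from 0 to 1 across the image.
def score(image):
    weights = jnp.linspace(0.0, 1.0, image.size).reshape(image.shape)
    return jnp.sum(weights * image)

image = jnp.ones((4, 4))

# Saliency: d(score)/d(pixel). Here the gradient recovers the weights,
# so later pixels are "more salient" than earlier ones.
saliency = jax.grad(score)(image)
```

get_saliency_maps applies the same idea through the recurrent Q-network, differentiating Q-value outputs with respect to the observation sequence.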

Other Useful Libraries

  • gymnax - The jax-capable gymnasium API
  • popgym - The original collection of POMDPs, implemented in numpy
  • popjaxrl - A jax version of popgym
  • popjym - A more readable version of popjaxrl environments that served as a basis for our work

Citation

@article{wang2025popgym,
  title={POPGym Arcade: Parallel Pixelated POMDPs},
  author={Wang, Zekang and He, Zhe and Zhang, Borong and Toledo, Edan and Morad, Steven},
  journal={arXiv preprint arXiv:2503.01450},
  year={2025}
}
