POMDP Arcade Environments on the GPU
POPGym Arcade - GPU-Accelerated MDPs and POMDPs
POPGym Arcade is a GPU-accelerated Atari-style benchmark and suite of analysis tools for reinforcement learning.
Tasks
POPGym Arcade contains pixel-based tasks in the style of the Arcade Learning Environment.
Each environment provides:
- Three difficulty settings
- One observation and action space shared across all envs
- Fully observable and partially observable configurations
- Fast and easy GPU vectorization using jax
- Standardized returns in [0, 1] or [-1, 1]
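The vectorization pattern behind this list can be sketched with `jax.vmap`. This is a minimal illustration under stated assumptions: `toy_step` is a hypothetical, purely functional grid-walk step written for this example. It is not a POPGym Arcade environment, but any pure step function can be batched the same way.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy environment (not part of popgym_arcade): the state is a 2D
# integer position and each of the 4 actions moves it one unit along x or y.
def toy_step(state: jax.Array, action: jax.Array) -> tuple[jax.Array, jax.Array]:
    moves = jnp.array([[1, 0], [-1, 0], [0, 1], [0, -1]], dtype=jnp.int32)
    next_state = state + moves[action]
    # Reward 1.0 when the agent lands on the origin, otherwise 0.0
    reward = jnp.where(jnp.all(next_state == 0), 1.0, 0.0)
    return next_state, reward


# vmap maps the single-env step over a leading batch axis; jit compiles it once
batched_step = jax.jit(jax.vmap(toy_step))

states = jnp.array([[1, 0], [0, 2], [5, 5]], dtype=jnp.int32)  # 3 parallel envs
actions = jnp.array([1, 3, 0], dtype=jnp.int32)
next_states, rewards = batched_step(states, actions)
print(next_states, rewards)
```

Because the step function is pure (state in, state out), batching needs no changes to the environment code itself; `vmap` adds the batch axis.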
Throughput
Expect ~10M frames per second on an RTX 4090. Most of our policies converge in under 60 minutes of training.
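To get a rough throughput number on your own hardware, a small timing harness like the one below can help. It is a sketch, not the benchmark used for the figure above: `frames_per_second` is a hypothetical helper, and it assumes a jitted, batched step function of the form `state -> state`. The toy step here stands in for a real environment.

```python
import time

import jax
import jax.numpy as jnp


def frames_per_second(step_fn, state, n_steps: int = 100) -> float:
    """Rough frames-per-second estimate for a batched, jitted step function.

    Assumes (hypothetically) that step_fn maps state -> state and that every
    leaf of `state` has a leading batch axis of size n_envs.
    """
    n_envs = jax.tree_util.tree_leaves(state)[0].shape[0]
    # Warm-up call so that compilation time is not counted
    state = jax.block_until_ready(step_fn(state))
    start = time.perf_counter()
    for _ in range(n_steps):
        state = step_fn(state)
    # JAX dispatches work asynchronously; wait for the device before stopping the clock
    jax.block_until_ready(state)
    elapsed = time.perf_counter() - start
    return n_envs * n_steps / elapsed


# Toy stand-in for a batched env step: 1024 envs with 64x64 "frames"
toy_step = jax.jit(lambda frames: 0.99 * frames + 1.0)
fps = frames_per_second(toy_step, jnp.zeros((1024, 64, 64)), n_steps=50)
print(f"~{fps:.2e} frames/s")
```

Each loop iteration is a separate Python dispatch, so this tends to underestimate what a fully compiled rollout (for example, steps rolled inside `jax.lax.scan`) can reach.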
Baselines
We provide a single training script for all algorithms and memory models. The memax library provides 18 different memory models for use in our script.
RL Algorithms
Getting Started
To install the environments, run
pip install popgym-arcade
If you plan to use our training scripts, install the baselines as well. If you want to play the games yourself, also use the human flag.
pip install 'popgym-arcade[baselines,human]'
Note: If you do not already have jax installed, we install CPU jax by default. For GPU acceleration, run pip install 'jax[cuda12]' after installing popgym-arcade.
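After installing, you can check which backend jax picked up. `jax.default_backend()` and `jax.devices()` are standard jax calls; if the CUDA build was installed correctly and a GPU is visible, the backend should report `gpu`, otherwise it falls back to `cpu`.

```python
import jax

# Expected to be "gpu" with a working CUDA jax build, "cpu" with the default install
backend = jax.default_backend()
devices = jax.devices()
print("JAX backend:", backend)
print("Devices:", devices)
```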
Human Play
The play script installed with pip install 'popgym-arcade[human]' lets you play the games yourself using the arrow keys and spacebar.
popgym-arcade-play NoisyCartPoleEasy # play MDP 256 pixel version
popgym-arcade-play BattleShipEasy -p -o 128 # play POMDP 128 pixel version
Creating and Stepping Environments
Our tasks are gymnax environments, so they work with wrappers and code designed for gymnax. The following example demonstrates how to integrate POPGym Arcade into your code.
import popgym_arcade
import jax
# Create the POMDP variant of the env
env, env_params = popgym_arcade.make("BattleShipEasy", partial_obs=True)
# Vectorize and compile the env
# Note: when training a policy, it is better to compile your whole policy_update rather than just the env step
reset = jax.jit(jax.vmap(env.reset, in_axes=(0, None)))
step = jax.jit(jax.vmap(env.step, in_axes=(0, 0, 0, None)))
# Run four vectorized environments
n_envs = 4
# Initialize PRNG keys
key = jax.random.key(0)
reset_keys = jax.random.split(key, n_envs)
# Reset environments
observation, env_state = reset(reset_keys, env_params)
# Step the POMDP
for t in range(10):
    # Propagate some randomness
    action_key, step_key = jax.random.split(jax.random.key(t))
    action_keys = jax.random.split(action_key, n_envs)
    step_keys = jax.random.split(step_key, n_envs)
    # Pick actions at random
    actions = jax.vmap(env.action_space(env_params).sample)(action_keys)
    # Step the env to the next state
    # No need to reset after the initial reset: gymnax automatically resets when done
    observation, env_state, reward, done, info = step(step_keys, env_state, actions, env_params)
# POMDP and MDP variants share states,
# so we can plug the POMDP states into the MDP and continue playing
mdp, mdp_params = popgym_arcade.make("BattleShipEasy", partial_obs=False)
mdp_reset = jax.jit(jax.vmap(mdp.reset, in_axes=(0, None)))
mdp_step = jax.jit(jax.vmap(mdp.step, in_axes=(0, 0, 0, None)))
action_keys = jax.random.split(jax.random.key(t + 1), n_envs)
step_keys = jax.random.split(jax.random.key(t + 2), n_envs)
# Sample fresh random actions for the MDP step
actions = jax.vmap(mdp.action_space(mdp_params).sample)(action_keys)
# In the MDP, the observation fully reveals the underlying state
markov_state, env_state, reward, done, info = mdp_step(step_keys, env_state, actions, mdp_params)
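The comment in the example above suggests compiling a larger unit of work than a single env step. One common way to do that in JAX is to roll the environment forward inside `jax.lax.scan`, so the whole rollout becomes one compiled function. The sketch below is a hypothetical illustration: `ToyCounterEnv` and `make_rollout` are not part of popgym_arcade. The toy env only mimics the gymnax-style `reset(key, params)` and `step(key, state, action, params)` signatures used above; with POPGym Arcade you would, under that same assumption, pass the `env` and `env_params` returned by `popgym_arcade.make` and a policy that accepts image observations.

```python
import jax
import jax.numpy as jnp


class ToyCounterEnv:
    """Hypothetical stand-in (not a POPGym Arcade env) with gymnax-style calls:
    reset(key, params) -> (obs, state)
    step(key, state, action, params) -> (obs, state, reward, done, info)
    """

    def reset(self, key, params):
        state = jnp.int32(0)
        return state.astype(jnp.float32), state

    def step(self, key, state, action, params):
        state = state + action
        done = state >= params["max_count"]
        # Auto-reset when done, mimicking the gymnax convention
        state = jnp.where(done, jnp.int32(0), state)
        reward = done.astype(jnp.float32)
        return state.astype(jnp.float32), state, reward, done, {}


def make_rollout(env, params, policy, n_steps):
    """Build a jitted, vmapped rollout that runs n_steps inside jax.lax.scan."""

    def rollout(key, obs, env_state):
        def body(carry, step_key):
            obs, env_state = carry
            action_key, env_key = jax.random.split(step_key)
            action = policy(action_key, obs)
            obs, env_state, reward, done, _ = env.step(env_key, env_state, action, params)
            return (obs, env_state), reward

        step_keys = jax.random.split(key, n_steps)
        (obs, env_state), rewards = jax.lax.scan(body, (obs, env_state), step_keys)
        return obs, env_state, rewards

    return jax.jit(jax.vmap(rollout))


env, params = ToyCounterEnv(), {"max_count": 3}
n_envs = 4
keys = jax.random.split(jax.random.key(0), n_envs)
obs, env_state = jax.vmap(env.reset, in_axes=(0, None))(keys, params)

# A trivial deterministic policy that always picks action 1
always_one = lambda key, obs: jnp.int32(1)
rollout = make_rollout(env, params, always_one, n_steps=6)
obs, env_state, rewards = rollout(keys, obs, env_state)
print(rewards)
```

Placing the policy inside the scanned body is what lets the policy forward pass and the env steps compile together; a training update would typically wrap a rollout like this plus a loss computation in a single `jax.jit`.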
Memory Introspection Tools
We implement visualization tools to probe which pixels persist in agent memory and how they affect Q-value predictions. Try the code below or our example script to understand how your agent uses memory.
from popgym_arcade.baselines.model.builder import QNetworkRNN
from popgym_arcade.baselines.utils import get_saliency_maps, vis_fn
import equinox as eqx
import jax
config = {
# Env string
"ENV_NAME": "NavigatorEasy",
# Whether to use full or partial observability
"PARTIAL": True,
# Memory model type (see models directory)
"MEMORY_TYPE": "lru",
# Evaluation episode seed
"SEED": 0,
# Observation size in pixels (128 or 256)
"OBS_SIZE": 128,
}
# Initialize the random key
rng = jax.random.PRNGKey(config["SEED"])
# Initialize the model
network = QNetworkRNN(rng, rnn_type=config["MEMORY_TYPE"], obs_size=config["OBS_SIZE"])
# Load the model
model = eqx.tree_deserialise_leaves("PATH_TO_YOUR_MODEL_WEIGHTS.pkl", network)
# Compute the saliency maps
grads, obs_seq, grad_accumulator = get_saliency_maps(rng, model, config)
# Visualize the saliency maps
# If you have LaTeX installed, set use_latex=True
vis_fn(grads, obs_seq, config, use_latex=False)
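The internals of `get_saliency_maps` are not described here, so the sketch below only illustrates the general idea of gradient-based saliency for a recurrent agent, under stated assumptions. It uses a hypothetical toy tanh RNN (`init_params`, `final_q_values` and `saliency` are illustrative names, not the library's API) in place of `QNetworkRNN`, and takes the gradient of the final step's greedy Q-value with respect to every observation in the sequence. Large magnitudes on early observations suggest the memory still carries information from those frames.

```python
import jax
import jax.numpy as jnp


def init_params(key, obs_dim, hidden_dim, n_actions):
    """Hypothetical toy RNN weights (not the QNetworkRNN architecture)."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "U": jax.random.normal(k1, (hidden_dim, obs_dim)) * 0.1,
        "W": jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.1,
        "V": jax.random.normal(k3, (n_actions, hidden_dim)) * 0.1,
    }


def final_q_values(params, obs_seq):
    """Run a tanh RNN over obs_seq [T, obs_dim]; return Q-values at the last step."""

    def cell(h, x):
        h = jnp.tanh(params["W"] @ h + params["U"] @ x)
        return h, None

    h0 = jnp.zeros(params["W"].shape[0])
    h_last, _ = jax.lax.scan(cell, h0, obs_seq)
    return params["V"] @ h_last


def saliency(params, obs_seq):
    """|d max_a Q(h_T, a) / d o_t| for every observation o_t in the sequence."""
    grads = jax.grad(lambda o: jnp.max(final_q_values(params, o)))(obs_seq)
    return jnp.abs(grads)


params = init_params(jax.random.key(0), obs_dim=16, hidden_dim=32, n_actions=4)
obs_seq = jax.random.uniform(jax.random.key(1), (10, 16))  # 10 timesteps
sal = saliency(params, obs_seq)
# Total influence of each timestep on the final greedy Q-value
per_step = sal.sum(axis=1)
print(per_step)
```

For pixel observations, the same gradient has the shape of each frame, which is what makes it possible to render per-pixel saliency maps over time.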
Other Useful Libraries
- stable-gymnax: A (stable) jax-capable gymnasium API
- memax: Recurrent models for jax
- popgym: The original collection of POMDPs, implemented in numpy
- popjaxrl: A jax version of popgym
- popjym: A more readable version of popjaxrl environments that served as a basis for our work
Citation
@article{wang2025popgym,
title={Investigating Memory in RL with POPGym Arcade},
author={Wang, Zekang and He, Zhe and Zhang, Borong and Toledo, Edan and Morad, Steven},
journal={arXiv preprint arXiv:2503.01450},
year={2025}
}