POMDP Arcade Environments on the GPU
POPGym Arcade - GPU-Accelerated POMDPs
POPGym Arcade contains 7 pixel-based POMDPs in the style of the Arcade Learning Environment. Each environment provides:
- 3 Difficulty settings
- Common observation and action space shared across all envs
- Fully observable and partially observable configurations
- Fast and easy GPU vectorization using jax.vmap and jax.jit
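As a minimal sketch of what vectorizing with jax.vmap and jax.jit looks like, the snippet below batches a toy step function over eight environments. Note that `toy_step` is a hypothetical stand-in written for this illustration and is not part of popgym_arcade.

```python
import jax
import jax.numpy as jnp


def toy_step(key, position):
    """Hypothetical single-environment step: move one cell left or right."""
    go_right = jax.random.bernoulli(key)
    new_position = position + jnp.where(go_right, 1, -1)
    # Toy reward: 1.0 when the agent ends up right of the origin
    reward = (new_position > 0).astype(jnp.float32)
    return new_position, reward


# vmap maps the step over a leading batch axis; jit compiles the batched function
batched_step = jax.jit(jax.vmap(toy_step, in_axes=(0, 0)))

n_envs = 8
keys = jax.random.split(jax.random.key(0), n_envs)
positions = jnp.zeros(n_envs, dtype=jnp.int32)
positions, rewards = batched_step(keys, positions)
```

Each environment receives its own PRNG key and its own state, which is the same pattern the real environments follow with `in_axes`.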
Gradient Visualization
We also provide tools to visualize how policies use memory.
See below for further instructions.
Throughput
You can expect millions of frames per second on a consumer-grade GPU. With obs_size=128, most policies converge within 30-60 minutes of training.
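To check throughput on your own hardware, a timing helper along the following lines may be useful. This is a sketch under stated assumptions: `frames_per_second` is a hypothetical helper, not part of popgym_arcade, and "frames" is taken to mean the number of environments times the number of steps.

```python
import time


def frames_per_second(step_fn, n_envs, n_steps, sync=lambda out: out):
    """Time n_steps calls of a batched step function and return frames/sec.

    step_fn: zero-argument callable that advances all n_envs by one frame.
    sync: called on an output to wait for asynchronous work to finish
          (for JAX, pass jax.block_until_ready).
    """
    out = step_fn()  # warm-up call so compilation time is not measured
    sync(out)
    start = time.perf_counter()
    for _ in range(n_steps):
        out = step_fn()
    sync(out)  # wait for the last step before stopping the clock
    elapsed = max(time.perf_counter() - start, 1e-9)
    return n_envs * n_steps / elapsed


# Usage with a dummy step function; replace it with a closure over your env step
fps = frames_per_second(lambda: 0, n_envs=4, n_steps=1000)
```

JAX dispatches work asynchronously, so without the `sync` call the measurement would mostly capture dispatch time rather than compute time.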
Getting Started
Installation
To install the environments, run
pip install git+https://www.github.com/bolt-research/popgym-arcade
Or from source
git clone https://www.github.com/bolt-research/popgym-arcade
cd popgym-arcade
pip install -e .
If you plan to use our training scripts, install the baselines as well
pip install 'popgym_arcade[baselines] @ git+https://www.github.com/bolt-research/popgym-arcade.git'
Or from source
git clone https://www.github.com/bolt-research/popgym-arcade
cd popgym-arcade
pip install -e '.[baselines]'
Creating and Stepping Environments
import popgym_arcade
import jax
# Create both POMDP and MDP env variants
pomdp, pomdp_params = popgym_arcade.make("BattleShipEasy", partial_obs=True)
mdp, mdp_params = popgym_arcade.make("BattleShipEasy", partial_obs=False)
# Let's vectorize and compile the envs
# Note when you are training a policy, it is better to compile your policy_update rather than the env_step
pomdp_reset = jax.jit(jax.vmap(pomdp.reset, in_axes=(0, None)))
pomdp_step = jax.jit(jax.vmap(pomdp.step, in_axes=(0, 0, 0, None)))
mdp_reset = jax.jit(jax.vmap(mdp.reset, in_axes=(0, None)))
mdp_step = jax.jit(jax.vmap(mdp.step, in_axes=(0, 0, 0, None)))
# Initialize four vectorized environments
n_envs = 4
# Initialize PRNG keys
key = jax.random.key(0)
reset_keys = jax.random.split(key, n_envs)
# Reset environments
observation, env_state = pomdp_reset(reset_keys, pomdp_params)
# Step the POMDPs
for t in range(10):
    # Propagate some randomness
    action_key, step_key = jax.random.split(jax.random.key(t))
    action_keys = jax.random.split(action_key, n_envs)
    step_keys = jax.random.split(step_key, n_envs)
    # Pick actions at random
    actions = jax.vmap(pomdp.action_space(pomdp_params).sample)(action_keys)
    # Step the env to the next state
    # No need to reset, gymnax automatically resets when done
    observation, env_state, reward, done, info = pomdp_step(step_keys, env_state, actions, pomdp_params)
# POMDP and MDP variants share states
# We can plug the POMDP states into the MDP and continue playing
action_keys = jax.random.split(jax.random.key(t + 1), n_envs)
step_keys = jax.random.split(jax.random.key(t + 2), n_envs)
markov_state, env_state, reward, done, info = mdp_step(step_keys, env_state, actions, mdp_params)
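The comment above about compiling the policy update rather than the env step can be sketched as follows: the environment step is called inside one larger jitted function, so the whole rollout compiles once. Everything here is a toy assumption for illustration. `toy_env_step`, the linear "policy", and `mean_return` are hypothetical and do not use the popgym_arcade API.

```python
import jax
import jax.numpy as jnp

N_ENVS, N_STEPS = 4, 16


def toy_env_step(key, state, action):
    """Hypothetical stand-in for env.step: returns (obs, next_state, reward)."""
    next_state = state + action + 0.1 * jax.random.normal(key)
    reward = -jnp.abs(next_state)
    return next_state, next_state, reward


@jax.jit  # compile the outer function once, env steps included
def mean_return(key, gain):
    def body(state, step_key):
        env_keys = jax.random.split(step_key, N_ENVS)
        action = -gain * state  # toy linear "policy"
        _, next_state, reward = jax.vmap(toy_env_step)(env_keys, state, action)
        return next_state, reward

    step_keys = jax.random.split(key, N_STEPS)
    # scan unrolls the time loop inside the compiled function
    _, rewards = jax.lax.scan(body, jnp.ones(N_ENVS), step_keys)
    return rewards.sum(axis=0).mean()  # average episode return over envs


value = mean_return(jax.random.PRNGKey(0), 0.5)
```

In a real training loop, the gradient computation and optimizer step would live inside the same jitted function as the rollout.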
Human Play
To best understand the environments, you should try playing them yourself. You can easily integrate popgym-arcade with pygame.
First, you'll need to install pygame
pip install pygame
Try the play script to play the games yourself! All games accept arrow key input and spacebar.
python play.py
Memory Introspection Tools
We implement visualization tools to probe which pixels persist in agent memory and how they affect Q-value predictions. Try the code below or the vis example to visualize the memory your agent uses.
from popgym_arcade.baselines.model.builder import QNetworkRNN
from popgym_arcade.baselines.utils import get_saliency_maps, vis_fn
import equinox as eqx
import jax
config = {
    "ENV_NAME": "NavigatorEasy",
    "PARTIAL": True,
    "MEMORY_TYPE": "lru",
    "SEED": 0,
    "OBS_SIZE": 128,
}
# Initialize the random key
rng = jax.random.PRNGKey(config["SEED"])
# Initialize the model
network = QNetworkRNN(rng, rnn_type=config["MEMORY_TYPE"], obs_size=config["OBS_SIZE"])
# Load the model
model = eqx.tree_deserialise_leaves("PATH_TO_YOUR_MODEL_WEIGHTS.pkl", network)
# Compute the saliency maps
grads, obs_seq, grad_accumulator = get_saliency_maps(rng, model, config)
# Visualize the saliency maps
# If you have latex installed, set use_latex=True
vis_fn(grads, obs_seq, config, use_latex=False)
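For intuition about what a saliency map over memory measures, here is a toy sketch that differentiates a final Q-value with respect to every pixel of every past frame. It assumes a hypothetical leaky-sum recurrent model, `toy_q_value`, and is not a description of how `get_saliency_maps` is implemented.

```python
import jax
import jax.numpy as jnp


def toy_q_value(obs_seq, decay=0.5):
    """Hypothetical recurrent Q-function: a leaky sum over frame pixels."""
    def body(hidden, obs):
        # Older frames are discounted by `decay` at every step
        hidden = decay * hidden + obs.sum()
        return hidden, None

    hidden, _ = jax.lax.scan(body, jnp.zeros(()), obs_seq)
    return hidden  # Q-value predicted after the final frame


# Sequence of 5 frames, each with 4x4 "pixels"
obs_seq = jnp.ones((5, 4, 4))
# Saliency: magnitude of dQ/d(pixel) for every pixel in every past frame
saliency = jnp.abs(jax.grad(toy_q_value)(obs_seq))
# Total influence of each frame on the final Q-value
per_frame = saliency.sum(axis=(1, 2))
```

In this toy model the influence of a frame halves with each step into the past, so `per_frame` grows toward the most recent frame. A trained memory model would instead show which specific pixels it chose to retain.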
Other Useful Libraries
- gymnax - The (deprecated) jax-capable gymnasium API
- stable-gymnax - A maintained and patched version of gymnax
- popgym - The original collection of POMDPs, implemented in numpy
- popjaxrl - A jax version of popgym
- popjym - A more readable version of popjaxrl environments that served as a basis for our work
Citation
@article{wang2025popgym,
title={POPGym Arcade: Parallel Pixelated POMDPs},
author={Wang, Zekang and He, Zhe and Toledo, Edan and Morad, Steven},
journal={arXiv preprint arXiv:2503.01450},
year={2025}
}
File details
Details for the file popgym_arcade-0.0.1.tar.gz.
File metadata
- Download URL: popgym_arcade-0.0.1.tar.gz
- Upload date:
- Size: 77.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9430c3802c537d2da8f8e26d93c52a3178d091fdf9114ce6c9ae44c7057532a5 |
| MD5 | 84ec6496ab7f0432ab8e21a7f8498075 |
| BLAKE2b-256 | edf6366451e38318212ab09d9fef800a8e614de6634f9003158cbb8ba854c8b6 |
Provenance
The following attestation bundles were made for popgym_arcade-0.0.1.tar.gz:
Publisher: python-publish.yml on bolt-research/popgym-arcade
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: popgym_arcade-0.0.1.tar.gz
- Subject digest: 9430c3802c537d2da8f8e26d93c52a3178d091fdf9114ce6c9ae44c7057532a5
- Sigstore transparency entry: 236067548
- Sigstore integration time:
- Permalink: bolt-research/popgym-arcade@d9a4e87223813a0173566844ce387d18e30ecf46
- Branch / Tag: refs/tags/0.0.1
- Owner: https://github.com/bolt-research
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@d9a4e87223813a0173566844ce387d18e30ecf46
- Trigger Event: release
File details
Details for the file popgym_arcade-0.0.1-py3-none-any.whl.
File metadata
- Download URL: popgym_arcade-0.0.1-py3-none-any.whl
- Upload date:
- Size: 113.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d09fb8785ef6a2dc37469781b77ab1b8e583fb959aae42aad04515c05874ce14 |
| MD5 | bc4e5efc34fd8c5c213650cdc202f9da |
| BLAKE2b-256 | b051d2dbd0a785a768b73d0eece12ae9cba949cf35de4b4511bed799d21fb498 |
Provenance
The following attestation bundles were made for popgym_arcade-0.0.1-py3-none-any.whl:
Publisher: python-publish.yml on bolt-research/popgym-arcade
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: popgym_arcade-0.0.1-py3-none-any.whl
- Subject digest: d09fb8785ef6a2dc37469781b77ab1b8e583fb959aae42aad04515c05874ce14
- Sigstore transparency entry: 236067552
- Sigstore integration time:
- Permalink: bolt-research/popgym-arcade@d9a4e87223813a0173566844ce387d18e30ecf46
- Branch / Tag: refs/tags/0.0.1
- Owner: https://github.com/bolt-research
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@d9a4e87223813a0173566844ce387d18e30ecf46
- Trigger Event: release