
POMDP Arcade Environments on the GPU


POPGym Arcade - GPU-Accelerated POMDPs



POPGym Arcade contains 7 pixel-based POMDPs in the style of the Arcade Learning Environment. Each environment provides:

  • 3 difficulty settings
  • Common observation and action space shared across all envs
  • Fully observable and partially observable configurations
  • Fast and easy GPU vectorization using jax.vmap and jax.jit
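The last two features can be illustrated without the library itself. The sketch below uses a hypothetical toy grid world to show one underlying state rendered both as a fully observable frame and as a partially observable, masked frame, with every function batched by jax.vmap and compiled by jax.jit. The names toy_reset, toy_step, render_full, render_partial, GRID and VIEW are invented for illustration and are not part of popgym_arcade. The sketch assumes a JAX version that provides jax.random.key.

```python
import jax
import jax.numpy as jnp

GRID = 8  # hypothetical board size for this toy
VIEW = 1  # hypothetical view radius of the partially observable variant
MOVES = jnp.array([[-1, 0], [1, 0], [0, -1], [0, 1]])  # up, down, left, right


def toy_reset(key):
    # The whole state is the agent's (row, col) position on an empty grid.
    return jax.random.randint(key, (2,), 0, GRID)


def toy_step(state, action):
    # Move the agent and clip it to the board.
    return jnp.clip(state + MOVES[action], 0, GRID - 1)


def render_full(state):
    # MDP view: the entire board with the agent's cell set to 1.
    return jnp.zeros((GRID, GRID)).at[state[0], state[1]].set(1.0)


def render_partial(state):
    # POMDP view of the SAME state: cells outside the view radius are masked to -1.
    rows, cols = jnp.meshgrid(jnp.arange(GRID), jnp.arange(GRID), indexing="ij")
    visible = (jnp.abs(rows - state[0]) <= VIEW) & (jnp.abs(cols - state[1]) <= VIEW)
    return jnp.where(visible, render_full(state), -1.0)


n_envs = 4
keys = jax.random.split(jax.random.key(0), n_envs)
states = jax.jit(jax.vmap(toy_reset))(keys)               # (n_envs, 2)
actions = jnp.arange(n_envs) % 4                           # one action per env
states = jax.jit(jax.vmap(toy_step))(states, actions)      # batched step
full_obs = jax.jit(jax.vmap(render_full))(states)          # (n_envs, GRID, GRID)
partial_obs = jax.jit(jax.vmap(render_partial))(states)    # same shape, masked
```

Because both views are pure functions of the same state, switching between them never touches the dynamics, which mirrors how the POMDP and MDP variants in Creating and Stepping Environments share states.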

Gradient Visualization

We also provide tools to visualize how policies use memory.

See Memory Introspection Tools below for instructions.

Throughput

You can expect millions of frames per second on a consumer-grade GPU. With obs_size=128, most policies converge within 30-60 minutes of training.
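A frames-per-second figure is the number of environments times the number of steps, divided by wall-clock time. The helper below is a hedged sketch for checking that number on your own hardware; measure_fps and the toy step function are hypothetical, not the authors' benchmark, and it assumes the step function takes and returns a single batched state array. Compilation time is excluded because it is paid once, on the warm-up call.

```python
import time

import jax
import jax.numpy as jnp


def measure_fps(step_fn, state, n_envs, n_steps):
    """Estimate frames per second as n_envs * n_steps / wall-clock seconds.

    Assumes step_fn maps a batched state array to the next batched state array
    (i.e. it is already vmapped over n_envs environments).
    """
    def body(s, _):
        return step_fn(s), None

    # Fold all n_steps into one compiled call so Python overhead is not timed.
    run = jax.jit(lambda s: jax.lax.scan(body, s, None, length=n_steps)[0])
    run(state).block_until_ready()  # first call: compile and warm up
    start = time.perf_counter()
    run(state).block_until_ready()  # second call: time the compiled rollout only
    elapsed = time.perf_counter() - start
    return n_envs * n_steps / elapsed


# Hypothetical stand-in for a vmapped environment step.
n_envs, n_steps = 1024, 100
state = jnp.zeros((n_envs, 2), dtype=jnp.int32)
fps = measure_fps(lambda s: (s + 1) % 8, state, n_envs, n_steps)
print(f"{fps:,.0f} frames per second")
```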

Getting Started

Installation

To install the environments, run

pip install git+https://www.github.com/bolt-research/popgym-arcade

Or from source

git clone https://www.github.com/bolt-research/popgym-arcade
cd popgym-arcade
pip install -e .

If you plan to use our training scripts, install the baselines as well

pip install "popgym_arcade[baselines] @ git+https://www.github.com/bolt-research/popgym-arcade.git"

Or from source

git clone https://www.github.com/bolt-research/popgym-arcade
cd popgym-arcade
pip install -e '.[baselines]'

Creating and Stepping Environments

import popgym_arcade
import jax

# Create both POMDP and MDP env variants
pomdp, pomdp_params = popgym_arcade.make("BattleShipEasy", partial_obs=True)
mdp, mdp_params = popgym_arcade.make("BattleShipEasy", partial_obs=False)

# Vectorize and compile the reset and step functions
# Note: when training a policy, it is usually better to jit your whole policy_update (which calls env.step) than to jit env.step on its own
pomdp_reset = jax.jit(jax.vmap(pomdp.reset, in_axes=(0, None)))
pomdp_step = jax.jit(jax.vmap(pomdp.step, in_axes=(0, 0, 0, None)))
mdp_reset = jax.jit(jax.vmap(mdp.reset, in_axes=(0, None)))
mdp_step = jax.jit(jax.vmap(mdp.step, in_axes=(0, 0, 0, None)))
    
# Initialize four vectorized environments
n_envs = 4
# Initialize PRNG keys
key = jax.random.key(0)
reset_keys = jax.random.split(key, n_envs)
    
# Reset environments
observation, env_state = pomdp_reset(reset_keys, pomdp_params)

# Step the POMDPs
for t in range(10):
    # Propagate some randomness
    action_key, step_key = jax.random.split(jax.random.key(t))
    action_keys = jax.random.split(action_key, n_envs)
    step_keys = jax.random.split(step_key, n_envs)
    # Pick actions at random
    actions = jax.vmap(pomdp.action_space(pomdp_params).sample)(action_keys)
    # Step the env to the next state
    # No need to reset, gymnax automatically resets when done
    observation, env_state, reward, done, info = pomdp_step(step_keys, env_state, actions, pomdp_params)

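# POMDP and MDP variants share the same underlying state,
# so we can plug the POMDP state into the MDP and continue playing
action_keys = jax.random.split(jax.random.key(t + 1), n_envs)
step_keys = jax.random.split(jax.random.key(t + 2), n_envs)
# Sample fresh actions from the MDP action space (identical to the POMDP's)
actions = jax.vmap(mdp.action_space(mdp_params).sample)(action_keys)
# markov_state is the fully observable observation of the shared state
markov_state, env_state, reward, done, info = mdp_step(step_keys, env_state, actions, mdp_params)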
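The Python for loop above dispatches one compiled call per step. In line with the note about compiling the policy update rather than the env step, a common JAX pattern is to fold the whole rollout into jax.lax.scan so it compiles once. The sketch below assumes only the step and sample signatures used in the example above; make_rollout, toy_step and toy_sample are hypothetical names, and the toy environment merely adds each action to the state.

```python
import jax
import jax.numpy as jnp


def make_rollout(step, sample_action, n_envs, n_steps):
    """Return a jit-compiled function that runs n_steps batched env steps.

    Assumed signatures (both already vmapped over n_envs environments):
      step(keys, state, actions) -> (obs, state, reward, done, info)
      sample_action(keys) -> actions
    """
    def body(carry, _):
        key, state = carry
        key, act_key, step_key = jax.random.split(key, 3)
        actions = sample_action(jax.random.split(act_key, n_envs))
        _, state, reward, done, _ = step(jax.random.split(step_key, n_envs), state, actions)
        return (key, state), (reward, done)

    @jax.jit
    def rollout(key, state):
        (_, state), (rewards, dones) = jax.lax.scan(body, (key, state), None, length=n_steps)
        return state, rewards, dones  # rewards, dones: (n_steps, n_envs)

    return rollout


# Toy stand-ins (hypothetical) so the sketch runs without popgym_arcade.
def toy_step(keys, state, actions):
    new_state = state + actions
    reward = actions.astype(jnp.float32)
    done = jnp.zeros(actions.shape, dtype=bool)
    return new_state, new_state, reward, done, {}


def toy_sample(keys):
    return jax.vmap(lambda k: jax.random.randint(k, (), 0, 5))(keys)


rollout = make_rollout(toy_step, toy_sample, n_envs=4, n_steps=10)
final_state, rewards, dones = rollout(jax.random.key(1), jnp.zeros(4, dtype=jnp.int32))
```

To use it with the environments above, one option is step=lambda k, s, a: pomdp_step(k, s, a, pomdp_params) and sample_action=jax.vmap(pomdp.action_space(pomdp_params).sample), starting from the env_state returned by pomdp_reset. Since the environments reset automatically when done, the scan body needs no reset branch.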

Human Play

To best understand the environments, try playing them yourself. popgym-arcade integrates easily with pygame.

First, you'll need to install pygame

pip install pygame

Try the play script to play the games yourself! All games are controlled with the arrow keys and the spacebar.

python play.py

Memory Introspection Tools

We provide visualization tools that probe which pixels persist in agent memory and how they affect Q-value predictions. Try the code below, or the vis example, to visualize the memory your agent uses.

from popgym_arcade.baselines.model.builder import QNetworkRNN
from popgym_arcade.baselines.utils import get_saliency_maps, vis_fn
import equinox as eqx
import jax

config = {
    "ENV_NAME": "NavigatorEasy",
    "PARTIAL": True,
    "MEMORY_TYPE": "lru",
    "SEED": 0,
    "OBS_SIZE": 128,
}

# Initialize the random key
rng = jax.random.PRNGKey(config["SEED"])

# Initialize the model
network = QNetworkRNN(rng, rnn_type=config["MEMORY_TYPE"], obs_size=config["OBS_SIZE"])
# Load the model
model = eqx.tree_deserialise_leaves("PATH_TO_YOUR_MODEL_WEIGHTS.pkl", network)
# Compute the saliency maps
grads, obs_seq, grad_accumulator = get_saliency_maps(rng, model, config)
# Visualize the saliency maps
# If you have LaTeX installed, set use_latex=True
vis_fn(grads, obs_seq, config, use_latex=False)
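The internals of get_saliency_maps are not shown here, but gradient-based saliency in general can be sketched on a toy model: differentiate the Q-value predicted after the last frame with respect to every pixel of every earlier frame. A large gradient on an old frame means information from it still persists in memory and moves the Q-value. toy_q_value and its linear, fixed-decay memory are hypothetical; this is not QNetworkRNN or the library's implementation.

```python
import jax
import jax.numpy as jnp


def toy_q_value(params, obs_seq):
    """Q-value after the last frame of a tiny linear recurrent model (hypothetical)."""
    decay, w_in, w_out = params

    def cell(h, frame):
        # Memory update: old memory fades by `decay`, the new frame is written in.
        return decay * h + w_in @ frame.ravel(), None

    h_last, _ = jax.lax.scan(cell, jnp.zeros(w_in.shape[0]), obs_seq)
    return w_out @ h_last


k_obs, k_in, k_out = jax.random.split(jax.random.key(0), 3)
T, H, W, D = 5, 4, 4, 8  # frames, height, width, memory size
obs_seq = jax.random.uniform(k_obs, (T, H, W))
params = (0.5, jax.random.normal(k_in, (D, H * W)), jax.random.normal(k_out, (D,)))

# Saliency map: |dQ / d pixel| for every pixel of every past frame.
saliency = jnp.abs(jax.grad(toy_q_value, argnums=1)(params, obs_seq))  # (T, H, W)
per_frame = saliency.sum(axis=(1, 2))  # total influence of each frame on Q
```

With decay = 0.5, each step further into the past halves a frame's saliency, so older frames fade from the map. A learned memory model can instead keep specific pixels salient over long horizons, which is the kind of behaviour such a visualization can reveal.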

Other Useful Libraries

  • gymnax - The (deprecated) JAX-capable Gymnasium API
  • stable-gymnax - A maintained and patched version of gymnax
  • popgym - The original collection of POMDPs, implemented in NumPy
  • popjaxrl - A JAX version of popgym
  • popjym - A more readable version of the popjaxrl environments that served as the basis for our work

Citation

@article{wang2025popgym,
  title={POPGym Arcade: Parallel Pixelated POMDPs},
  author={Wang, Zekang and He, Zhe and Toledo, Edan and Morad, Steven},
  journal={arXiv preprint arXiv:2503.01450},
  year={2025}
}



Download files


Source Distribution

popgym_arcade-0.0.1.tar.gz (77.4 kB)

Built Distribution


popgym_arcade-0.0.1-py3-none-any.whl (113.5 kB)

File details

Details for the file popgym_arcade-0.0.1.tar.gz.

File metadata

  • Download URL: popgym_arcade-0.0.1.tar.gz
  • Size: 77.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for popgym_arcade-0.0.1.tar.gz
Algorithm Hash digest
SHA256 9430c3802c537d2da8f8e26d93c52a3178d091fdf9114ce6c9ae44c7057532a5
MD5 84ec6496ab7f0432ab8e21a7f8498075
BLAKE2b-256 edf6366451e38318212ab09d9fef800a8e614de6634f9003158cbb8ba854c8b6


Provenance

The following attestation bundles were made for popgym_arcade-0.0.1.tar.gz:

Publisher: python-publish.yml on bolt-research/popgym-arcade


File details

Details for the file popgym_arcade-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: popgym_arcade-0.0.1-py3-none-any.whl
  • Size: 113.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for popgym_arcade-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 d09fb8785ef6a2dc37469781b77ab1b8e583fb959aae42aad04515c05874ce14
MD5 bc4e5efc34fd8c5c213650cdc202f9da
BLAKE2b-256 b051d2dbd0a785a768b73d0eece12ae9cba949cf35de4b4511bed799d21fb498


Provenance

The following attestation bundles were made for popgym_arcade-0.0.1-py3-none-any.whl:

Publisher: python-publish.yml on bolt-research/popgym-arcade

