Accelerated CHIP-8 Arcade RL Environments for JAX
OCTAX: Accelerated CHIP-8 Arcade Environments for Reinforcement Learning in JAX
High-performance CHIP-8 arcade game environments for reinforcement learning research
📄 Preprint is available at: https://arxiv.org/abs/2510.01764
OCTAX provides a JAX-based suite of classic arcade game environments implemented through CHIP-8 emulation. It offers substantial speedups over traditional CPU emulators (up to 14× over EnvPool at 8192 parallel environments) while maintaining perfect fidelity to the original game mechanics, which makes it well suited to large-scale reinforcement learning experiments.
Sample of 20+ classic arcade games available in OCTAX
Why OCTAX?
Modern RL research demands extensive experimentation with thousands of parallel environments and comprehensive hyperparameter sweeps. Traditional arcade emulators remain CPU-bound, creating a computational bottleneck. OCTAX addresses this with:
- 🚀 GPU Acceleration: End-to-end JAX implementation runs thousands of game instances in parallel
- ⚡ Massive Speedups: Up to 14× faster than CPU-based EnvPool at 8192 parallel environments
- 🎮 Authentic Games: Perfect fidelity to original CHIP-8 mechanics across 20+ games
- 🔧 Easy Integration: Gymnax-compatible interface for popular JAX-based RL frameworks
- 📊 Research-Ready: Spans puzzle, action, strategy, exploration, and shooter genres
Key Features
End-to-End GPU Acceleration
- Fully vectorized CHIP-8 emulation in JAX
- JIT-compiled for maximum performance
- Scales from single environments to 8192+ parallel instances
- Eliminates CPU-GPU transfer bottlenecks
Diverse Game Portfolio
- 20+ Games spanning multiple genres:
  - Puzzle: Tetris, Blinky (Pac-Man), Worm (Snake)
  - Action: Brix (Breakout), Pong, Squash, Wipe-Off
  - Strategy: Missile Command, Tank Battle, UFO
  - Exploration: Cavern (7 levels), Space Flight (10 levels)
  - Shooter: Airplane, Deep8, Shooting Stars
Research-Friendly Design
- Gymnax-compatible interface for seamless JAX integration
- Customizable reward functions and termination conditions
- Frame stacking and observation preprocessing
- Multiple color schemes for visualization
- Modular architecture for easy extension
Built for Scale
- Run experiments in hours that used to take days
- Run comprehensive hyperparameter sweeps feasibly
- Achieve statistical validity with hundreds of seeds
- Perfect for curriculum learning and meta-RL research
Installation
# Basic
pip install octax
# With visualization
pip install "octax[gui]"
# With RL training
pip install "octax[training]"
# Everything
pip install "octax[all]"
For GPU acceleration (highly recommended):
pip install --upgrade "jax[cuda12]"
# Optional: install the dependencies needed to run the train.py script
pip install -r requirements_training.txt
From source:
git clone https://github.com/riiswa/octax.git
cd octax
pip install -e .
Quick Start
Basic Environment Usage
import jax
import jax.numpy as jnp
from octax.environments import create_environment

# Create environment
env, metadata = create_environment("brix")
print(f"Playing: {metadata['title']}")

# Simple random policy
@jax.jit
def rollout(rng):
    # Use separate keys for the reset and for action sampling
    rng, reset_rng = jax.random.split(rng)
    state, obs, info = env.reset(reset_rng)

    def step(carry, _):
        rng, state, obs = carry
        rng, action_rng = jax.random.split(rng)
        action = jax.random.randint(action_rng, (), 0, env.num_actions)
        next_state, next_obs, reward, terminated, truncated, info = env.step(state, action)
        return (rng, next_state, next_obs), reward

    # Run 1000 steps inside a single compiled loop
    final_carry, rewards = jax.lax.scan(step, (rng, state, obs), length=1000)
    return jnp.sum(rewards)

# Run a 1000-step rollout
rng = jax.random.PRNGKey(0)
total_reward = rollout(rng)
print(f"Total reward: {total_reward}")
Vectorized Training (64 Parallel Environments)
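# Run 64 environments in parallel (reuses `rollout` and `rng` from above)
from functools import partial

# num_envs determines an array shape, so it must be static under jax.jit
@partial(jax.jit, static_argnums=1)
def vectorized_rollout(rng, num_envs=64):
    rngs = jax.random.split(rng, num_envs)
    return jax.vmap(rollout)(rngs)

rewards = vectorized_rollout(rng, 64)
print(f"Mean reward: {jnp.mean(rewards):.2f} ± {jnp.std(rewards):.2f}")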
Gymnax Integration
import jax
from octax.environments import create_environment
from octax.wrappers import OctaxGymnaxWrapper

# Create Gymnax-compatible environment
env, metadata = create_environment("brix")
gymnax_env = OctaxGymnaxWrapper(env)
env_params = gymnax_env.default_params

# Use with any Gymnax-compatible algorithm
rng = jax.random.PRNGKey(0)
rng, reset_rng = jax.random.split(rng)
obs, state = gymnax_env.reset(reset_rng, env_params)
for _ in range(100):
    rng, rng_action, rng_step = jax.random.split(rng, 3)
    action = gymnax_env.action_space(env_params).sample(rng_action)
    obs, state, reward, done, info = gymnax_env.step(
        rng_step, state, action, env_params
    )
Observation Space: The agent receives the raw CHIP-8 display as a (frame_skip, 32, 64) boolean array, where frame_skip (default: 4) provides temporal information. Each frame is a 32×64 monochrome image capturing the complete visual state—exactly what a human player would see.
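To make the observation layout concrete, the pure-Python sketch below assembles a (num_frames, 32, 64) stack of boolean frames with a rolling buffer that keeps the most recent frames, oldest first. This is a generic frame-stacking sketch: the FrameStack class and its helpers are hypothetical, and whether OCTAX fills its frame_skip axis exactly this way is an assumption (OCTAX builds the array in JAX).

```python
from collections import deque

HEIGHT, WIDTH = 32, 64  # CHIP-8 display resolution (rows, columns)


def blank_frame():
    """Return an all-off 32x64 monochrome frame as nested lists of bools."""
    return [[False] * WIDTH for _ in range(HEIGHT)]


class FrameStack:
    """Hypothetical helper: keeps the `num_frames` most recent display frames."""

    def __init__(self, num_frames=4):
        # A deque with maxlen drops the oldest frame automatically on append
        self.frames = deque(
            [blank_frame() for _ in range(num_frames)], maxlen=num_frames
        )

    def push(self, frame):
        """Append the newest frame and return the stacked observation."""
        self.frames.append(frame)
        return list(self.frames)  # shape (num_frames, 32, 64), oldest first


def shape_of(obs):
    """Shape of a rectangular nested list, e.g. (4, 32, 64)."""
    return (len(obs), len(obs[0]), len(obs[0][0]))


stack = FrameStack(num_frames=4)
frame = blank_frame()
frame[0][0] = True  # light the top-left pixel
obs = stack.push(frame)
print(shape_of(obs))  # (4, 32, 64)
print(obs[-1][0][0])  # True: the newest frame is last
```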
Action Space: A discrete space where actions map to game-specific CHIP-8 keys. For example, Pong uses [1, 4] (up/down), Brix uses [4, 6] (left/right), and Tetris uses [4, 5, 6, 7] (rotate/move). An additional no-op action is always available. Games automatically configure their relevant action subsets, eliminating irrelevant keys from the action space.
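The mapping from a discrete action index to CHIP-8 keys can be sketched as follows, using the key sets listed above. The helper names are hypothetical, and placing the no-op at the last index is an assumption; OCTAX may order its actions differently.

```python
NUM_KEYS = 16  # CHIP-8 hex keypad: keys 0x0 through 0xF


def num_actions(action_set):
    """One action per relevant key, plus one extra no-op action."""
    return len(action_set) + 1


def keypad_state(action, action_set):
    """Map a discrete action index to a 16-entry pressed/released keypad.

    Assumption: indices 0..len(action_set)-1 press the corresponding key,
    and the final index is the no-op (no key pressed).
    """
    if not 0 <= action < num_actions(action_set):
        raise ValueError(f"action {action} out of range")
    keys = [False] * NUM_KEYS
    if action < len(action_set):  # the no-op leaves every key released
        keys[action_set[action]] = True
    return keys


BRIX_ACTIONS = [4, 6]  # left/right, as listed above
print(num_actions(BRIX_ACTIONS))  # 3
print(keypad_state(1, BRIX_ACTIONS).index(True))  # 6
print(any(keypad_state(2, BRIX_ACTIONS)))  # False (no-op)
```

Restricting each game to its relevant keys keeps the action space small: Brix exposes 3 actions instead of 17.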
Available Games
| Category | Games | Required Capabilities |
|---|---|---|
| Puzzle | Tetris, Blinky, Worm | Long-horizon planning, spatial reasoning |
| Action | Brix, Pong, Squash, Vertical Brix, Wipe Off, Filter | Timing, prediction, reactive control |
| Strategy | Missile Command, Rocket, Submarine, Tank Battle, UFO | Resource management, tactical decisions |
| Exploration | Cavern (7 levels), Flight Runner, Space Flight (10 levels), Spacejam! | Spatial exploration, continuous navigation |
| Shooter | Airplane, Deep8, Shooting Stars | Simple reaction, basic timing |
All environments support:
- Customizable frame skip and action repeat
- Configurable episode lengths
- Built-in score tracking
- Multiple rendering modes
- Frame stacking for temporal information
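Frame skip / action repeat usually follows the pattern popularized by the Arcade Learning Environment: hold one action for k consecutive emulator frames, sum the rewards, and stop early on termination. The sketch below is a generic pure-Python illustration with a hypothetical single-frame `step_fn` and toy game, not OCTAX's internal implementation.

```python
def repeat_action(step_fn, state, action, repeat=4):
    """Apply `action` for up to `repeat` frames, summing rewards.

    `step_fn(state, action) -> (state, reward, done)` is a hypothetical
    single-frame transition; the loop stops early once `done` is True.
    """
    total = 0.0
    done = False
    for _ in range(repeat):
        state, reward, done = step_fn(state, action)
        total += reward
        if done:
            break
    return state, total, done


def toy_step(state, action):
    """Toy game: counter moves by `action`; reward 1 per frame; ends at 10."""
    new_state = state + action
    return new_state, 1.0, new_state >= 10


print(repeat_action(toy_step, 0, 2))  # (8, 4.0, False)
print(repeat_action(toy_step, 8, 2))  # (10, 1.0, True)
```

Under `jax.jit` a Python `break` cannot depend on traced values, so a JAX version would typically use masking or `jax.lax.while_loop` instead.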
Performance
OCTAX achieves substantial speedups over traditional CPU-based environments through JAX vectorization:
OCTAX vs EnvPool performance scaling across parallelization levels (RTX 3090)
Key Results:
- 14× faster than EnvPool at 8192 parallel environments
- 350,000 steps/second on a single RTX 3090
- Near-linear scaling up to GPU memory limits
- Sub-second compilation for fast iteration
This enables:
- Comprehensive experiments: Run 100+ hyperparameter configurations overnight
- Statistical rigor: Train with 50+ random seeds for reliable results
- Rapid prototyping: Iterate on algorithms with immediate feedback
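As a back-of-the-envelope check of what the 350,000 steps/second figure means, the sketch below converts a step budget into wall-clock time. It counts environment stepping only, so it is a lower bound: real training also spends time on network updates, and throughput depends on the number of parallel environments.

```python
STEPS_PER_SECOND = 350_000  # reported RTX 3090 throughput (see Key Results)


def env_seconds(total_steps, steps_per_second=STEPS_PER_SECOND):
    """Wall-clock seconds spent purely on environment steps."""
    return total_steps / steps_per_second


one_run = 5_000_000    # timesteps per run, as in the training results
sweep = 100 * one_run  # e.g. 100 hyperparameter configurations
print(round(env_seconds(one_run), 1))     # 14.3 (seconds)
print(round(env_seconds(sweep) / 60, 1))  # 23.8 (minutes)
```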
Training Results
PPO and PQN training across 16 diverse games (5M timesteps, 12 seeds each):
PPO and PQN learning curves showing diverse challenges across game genres
Project Structure
octax/
├── octax/ # Core package
│ ├── emulator.py # CHIP-8 emulator implementation
│ ├── env.py # RL environment wrapper
│ ├── environments/ # Game-specific configurations
│ ├── instructions/ # CHIP-8 instruction handlers
│ ├── rendering.py # Visualization utilities
│ └── wrappers.py # Gymnax compatibility wrapper
├── examples/ # Usage examples
│ ├── example.py # Basic rollout
│ ├── training_on_octax.py # PPO training demo
│ └── rendering_demo.py # Visualization examples
├── tests/ # Comprehensive test suite
└── tutorial/ # In-depth tutorials
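As background for the instructions/ package: every CHIP-8 instruction is a 16-bit big-endian word whose nibbles select the operation and its operands. The pure-Python sketch below fetches and decodes the conventional fields (X, Y, N, NN, NNN); the function and field names are illustrative and not taken from OCTAX's source.

```python
def fetch(memory, pc):
    """Read a big-endian 16-bit opcode from two consecutive memory bytes."""
    return (memory[pc] << 8) | memory[pc + 1]


def decode(opcode):
    """Split a CHIP-8 opcode into its conventional fields."""
    return {
        "op": (opcode & 0xF000) >> 12,  # high nibble selects the instruction family
        "x": (opcode & 0x0F00) >> 8,    # register index VX
        "y": (opcode & 0x00F0) >> 4,    # register index VY
        "n": opcode & 0x000F,           # 4-bit immediate (e.g. sprite height)
        "nn": opcode & 0x00FF,          # 8-bit immediate
        "nnn": opcode & 0x0FFF,         # 12-bit address
    }


memory = bytearray(4096)            # CHIP-8 has 4 KB of RAM
memory[0x200:0x202] = b"\xD1\x25"   # programs load at 0x200; DXYN draws a sprite
op = fetch(memory, 0x200)
print(hex(op))  # 0xd125
fields = decode(op)
print(fields["op"], fields["x"], fields["y"], fields["n"])  # 13 1 2 5
```

In a JAX emulator, dispatching on the `op` nibble would typically go through something like `jax.lax.switch`; whether OCTAX's handlers are organized that way is not stated here.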
Contributing
We welcome contributions! Ways to contribute:
Adding New Games
- Find a CHIP-8 ROM: many public-domain games are available.
- Analyze the game: use our interactive emulator (main.py) to identify:
  - Score registers (look for BCD operations marked with 🎯)
  - Termination conditions (game-over states)
  - Required controls (which keys are used)
- Create an environment file: see octax/environments/ for examples.
- Test and submit: ensure that score tracking and termination work correctly.
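The analysis step points at BCD operations because CHIP-8's FX33 instruction stores the binary-coded decimal digits of register VX (hundreds, tens, ones) at memory addresses I, I+1 and I+2, which games typically do right before drawing a score. A register fed into FX33 is therefore a good score candidate. A minimal sketch of the instruction's semantics, with hypothetical helper names:

```python
def is_bcd(opcode):
    """True for FX33 (store BCD of VX), the usual score-display instruction."""
    return (opcode & 0xF0FF) == 0xF033


def store_bcd(memory, i, value):
    """FX33 semantics: write hundreds, tens, ones digits of `value` at I..I+2."""
    memory[i] = value // 100
    memory[i + 1] = (value // 10) % 10
    memory[i + 2] = value % 10


V = [0] * 16       # registers V0..VF
V[5] = 137         # e.g. a score held in V5
opcode = 0xF533    # FX33 with X = 5
memory = bytearray(4096)
x = (opcode & 0x0F00) >> 8  # register index encoded in the opcode
if is_bcd(opcode):
    store_bcd(memory, 0x300, V[x])
print(list(memory[0x300:0x303]))  # [1, 3, 7]
print(x)  # 5: the register holding the score
```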
# octax/environments/my_game.py
from octax import EmulatorState

rom_file = "my_game.ch8"

def score_fn(state: EmulatorState) -> float:
    return state.V[5]  # Score in register V5

def terminated_fn(state: EmulatorState) -> bool:
    return state.V[12] == 0  # Game over when V12 reaches 0

action_set = [4, 6]  # Left/Right controls
startup_instructions = 500  # Skip menu screens
metadata = {
    "title": "My Game",
    "description": "A classic arcade game",
    "release": "2024",
    "authors": ["Author Name"]
}
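Given a score_fn like the one in the environment file above, a common way to derive a per-step reward is the change in score between consecutive states. The sketch below assumes that convention (OCTAX's exact reward computation is not spelled out here) and uses a hypothetical FakeState stand-in instead of EmulatorState so it runs without OCTAX installed.

```python
from dataclasses import dataclass, field


@dataclass
class FakeState:
    """Hypothetical stand-in for EmulatorState: just the 16 V registers."""
    V: list = field(default_factory=lambda: [0] * 16)


def make_state(v5, v12):
    """Build a state with the score register (V5) and lives register (V12) set."""
    regs = [0] * 16
    regs[5], regs[12] = v5, v12
    return FakeState(V=regs)


def score_fn(state):
    return state.V[5]  # score in register V5, as in the example above


def terminated_fn(state):
    return state.V[12] == 0  # game over when V12 reaches 0


def reward_and_done(prev_state, next_state):
    """Assumed convention: reward = score delta; done from terminated_fn."""
    reward = score_fn(next_state) - score_fn(prev_state)
    return reward, terminated_fn(next_state)


print(reward_and_done(make_state(10, 3), make_state(15, 0)))  # (5, True)
```

Using the score delta rather than the raw score keeps rewards sparse and makes the episode return equal to the final score gain.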
Other Contributions
- Bug fixes and performance improvements
- Documentation enhancements
- Additional examples and tutorials
- New features (Super-CHIP8 support, etc.)
Please open an issue to discuss major changes before implementing.
Citation
If you use OCTAX in your research, please cite our paper:
@misc{radji2025octax,
title={Octax: Accelerated CHIP-8 Arcade Environments for Reinforcement Learning in JAX},
author={Waris Radji and Thomas Michel and Hector Piteau},
year={2025},
eprint={2510.01764},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
License
OCTAX is released under the MIT License. See LICENSE for details.
Acknowledgments
- CHIP-8 games from the CHIP-8 Database
- Inspired by the Arcade Learning Environment
- Built with JAX, Flax, and Optax
Made with ❤️ for the RL research community
File details
Details for the file octax-0.1.1.tar.gz.
File metadata
- Download URL: octax-0.1.1.tar.gz
- Upload date:
- Size: 64.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 97a5252770746afbbef19a7bde5103ebc6293e56f21d01dc0a74e41206a2c33b |
| MD5 | a8fce41e558eef60b5a6042b6796983d |
| BLAKE2b-256 | 29936fdc6a99a154a53744d7bd6925e06275388a1b83ee26a36c170c7462c6e2 |
File details
Details for the file octax-0.1.1-py3-none-any.whl.
File metadata
- Download URL: octax-0.1.1-py3-none-any.whl
- Upload date:
- Size: 83.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 77a3606921763ad022be7969c1a2006e88ea5fc2105ba875bfd5d7170ce8009b |
| MD5 | 203a287bfdc20b4b299d938b1fde668f |
| BLAKE2b-256 | 2b09558db9c58b3d0f0404bcf5dae53e1a14aa0839d01d5166196198e2caec2f |