
Accelerated CHIP-8 Arcade RL Environments for JAX


OCTAX: Accelerated CHIP-8 Arcade Environments for Reinforcement Learning in JAX


High-performance CHIP-8 arcade game environments for reinforcement learning research


📄 Preprint is available at: https://arxiv.org/abs/2510.01764


OCTAX provides a JAX-based suite of classic arcade game environments implemented through CHIP-8 emulation. It offers orders-of-magnitude speedups over traditional CPU emulators while maintaining perfect fidelity to original game mechanics, making it ideal for large-scale reinforcement learning experimentation.

Figure: Sample of 20+ classic arcade games available in OCTAX

Why OCTAX?

Modern RL research demands extensive experimentation with thousands of parallel environments and comprehensive hyperparameter sweeps. Traditional arcade emulators remain CPU-bound, creating a computational bottleneck. OCTAX solves this by:

  • 🚀 GPU Acceleration: End-to-end JAX implementation runs thousands of game instances in parallel
  • ⚡ Massive Speedups: 14× faster than CPU-based alternatives at high parallelization
  • 🎮 Authentic Games: Perfect fidelity to original CHIP-8 mechanics across 20+ games
  • 🔧 Easy Integration: Compatible with Gymnasium and popular RL frameworks
  • 📊 Research-Ready: Spans puzzle, action, strategy, and exploration genres

Key Features

End-to-End GPU Acceleration

  • Fully vectorized CHIP-8 emulation in JAX
  • JIT-compiled for maximum performance
  • Scales from single environments to 8192+ parallel instances
  • Eliminates CPU-GPU transfer bottlenecks

Diverse Game Portfolio

  • 20+ games spanning multiple genres:

    • Puzzle: Tetris, Blinky (Pac-Man), Worm (Snake)
    • Action: Brix (Breakout), Pong, Squash, Wipe-Off
    • Strategy: Missile Command, Tank Battle, UFO
    • Exploration: Cavern (7 levels), Space Flight (10 levels)
    • Shooter: Airplane, Deep8, Shooting Stars

Research-Friendly Design

  • Gymnax-compatible interface for seamless JAX integration
  • Customizable reward functions and termination conditions
  • Frame stacking and observation preprocessing
  • Multiple color schemes for visualization
  • Modular architecture for easy extension

Built for Scale

  • Run experiments in hours that previously took days
  • Run comprehensive hyperparameter sweeps feasibly
  • Achieve statistical validity with hundreds of seeds
  • Perfect for curriculum learning and meta-RL research

Installation

# Basic
pip install octax

# With visualization
pip install "octax[gui]"

# With RL training
pip install "octax[training]"

# Everything
pip install "octax[all]"

For GPU acceleration (highly recommended):

pip install --upgrade "jax[cuda12]"

# Optional: only needed to run the train.py script
pip install -r requirements_training.txt
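
After installing the CUDA-enabled JAX build above, it can be worth confirming that JAX actually sees the GPU, because a CPU-only install falls back to CPU without raising an error. A minimal check using only standard JAX calls:

```python
import jax

# Devices JAX can use. With a working CUDA install this should include
# at least one GPU device rather than only a CPU device.
devices = jax.devices()

# Name of the default backend, e.g. "gpu" or "cpu".
backend = jax.default_backend()

print(f"Backend: {backend}")
print(f"Devices: {devices}")
```

If the backend reports CPU on a machine with a GPU, the JAX installation is the first thing to check.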

From source:

git clone https://github.com/riiswa/octax.git
cd octax
pip install -e .

Quick Start

Basic Environment Usage

import jax
import jax.numpy as jnp
from octax.environments import create_environment

# Create environment
env, metadata = create_environment("brix")
print(f"Playing: {metadata['title']}")

# Simple random policy
@jax.jit
def rollout(rng):
    state, obs, info = env.reset(rng)
    
    def step(carry, _):
        rng, state, obs = carry
        rng, action_rng = jax.random.split(rng)
        action = jax.random.randint(action_rng, (), 0, env.num_actions)
        
        next_state, next_obs, reward, terminated, truncated, info = env.step(state, action)
        return (rng, next_state, next_obs), reward
    
    final_carry, rewards = jax.lax.scan(step, (rng, state, obs), length=1000)
    return jnp.sum(rewards)

# Run episode
rng = jax.random.PRNGKey(0)
total_reward = rollout(rng)
print(f"Total reward: {total_reward}")

Vectorized Training (64 Parallel Environments)

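from functools import partial

# Run 64 environments in parallel.
# num_envs determines an array shape, so it must be a static argument under jit.
@partial(jax.jit, static_argnames="num_envs")
def vectorized_rollout(rng, num_envs=64):
    rngs = jax.random.split(rng, num_envs)
    return jax.vmap(rollout)(rngs)

rewards = vectorized_rollout(rng, 64)
print(f"Mean reward: {jnp.mean(rewards):.2f} ± {jnp.std(rewards):.2f}")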

Gymnax Integration

import jax
from octax.environments import create_environment
from octax.wrappers import OctaxGymnaxWrapper

# Create Gymnax-compatible environment
env, metadata = create_environment("brix")
gymnax_env = OctaxGymnaxWrapper(env)
env_params = gymnax_env.default_params

# Use with any Gymnax-compatible algorithm
rng = jax.random.PRNGKey(0)
obs, state = gymnax_env.reset(rng, env_params)

for _ in range(100):
    rng, rng_action, rng_step = jax.random.split(rng, 3)
    action = gymnax_env.action_space(env_params).sample(rng_action)
    obs, state, reward, done, info = gymnax_env.step(
        rng_step, state, action, env_params
    )

Observation Space: The agent receives the raw CHIP-8 display as a (frame_skip, 32, 64) boolean array, where frame_skip (default: 4) provides temporal information. Each frame is a 32×64 monochrome image capturing the complete visual state—exactly what a human player would see.
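
As an illustration of working with this observation format, the sketch below converts a boolean (frame_skip, 32, 64) array into a float32, channels-last tensor of the kind many convolutional networks expect. The `preprocess` function and the dummy observation are hypothetical and not part of OCTAX; only the shape and dtype come from the description above.

```python
import jax
import jax.numpy as jnp

def preprocess(obs: jnp.ndarray) -> jnp.ndarray:
    """Convert a (frame_skip, 32, 64) boolean observation to float32 HWC."""
    # Booleans become 0.0 / 1.0 so they can feed a neural network.
    x = obs.astype(jnp.float32)
    # Move the frame axis last: (frames, H, W) -> (H, W, frames).
    return jnp.transpose(x, (1, 2, 0))

# Stand-in observation with the documented shape (not produced by OCTAX here).
dummy_obs = jnp.zeros((4, 32, 64), dtype=jnp.bool_).at[0, 10, 20].set(True)
features = preprocess(dummy_obs)
print(features.shape, features.dtype)

# The same function applies to a batch of parallel environments via vmap.
batch = jnp.stack([dummy_obs] * 8)
batch_features = jax.vmap(preprocess)(batch)
print(batch_features.shape)
```

Whether channels-last or channels-first is preferable depends on the network library in use.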

Action Space: A discrete space where actions map to game-specific CHIP-8 keys. For example, Pong uses [1, 4] (up/down), Brix uses [4, 6] (left/right), and Tetris uses [4, 5, 6, 7] (rotate/move). An additional no-op action is always available. Games automatically configure their relevant action subsets, eliminating irrelevant keys from the action space.
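
To make the mapping concrete, here is a minimal sketch of how a discrete action index could be turned into a CHIP-8 keypad state. It assumes, hypothetically, that the no-op is the last index after the game's keys; OCTAX's actual ordering may differ, and `action_to_keypad` is not an OCTAX function.

```python
import jax.numpy as jnp

NUM_KEYS = 16  # CHIP-8 has a 16-key hexadecimal keypad (0x0 to 0xF)

def action_to_keypad(action, action_set):
    """Map a discrete action index to a 16-element boolean keypad state.

    Assumption: indices 0..len(action_set)-1 select a key and the final
    index is the no-op. OCTAX's real ordering may differ.
    """
    keys = jnp.asarray(action_set)
    is_noop = action >= keys.shape[0]
    # Clamp so the lookup stays in bounds even for the no-op index.
    key = keys[jnp.minimum(action, keys.shape[0] - 1)]
    pressed = jnp.arange(NUM_KEYS) == key
    return jnp.where(is_noop, jnp.zeros(NUM_KEYS, dtype=bool), pressed)

brix_keys = [4, 6]  # left/right, as listed above
print(action_to_keypad(0, brix_keys))  # key 4 pressed
print(action_to_keypad(2, brix_keys))  # no-op: nothing pressed
```

Restricting the action space to the keys a game actually reads keeps the policy's output layer small, which is the benefit the paragraph above describes.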

Available Games

| Category | Games | Required Capabilities |
|---|---|---|
| Puzzle | Tetris, Blinky, Worm | Long-horizon planning, spatial reasoning |
| Action | Brix, Pong, Squash, Vertical Brix, Wipe Off, Filter | Timing, prediction, reactive control |
| Strategy | Missile Command, Rocket, Submarine, Tank Battle, UFO | Resource management, tactical decisions |
| Exploration | Cavern (7 levels), Flight Runner, Space Flight (10 levels), Spacejam! | Spatial exploration, continuous navigation |
| Shooter | Airplane, Deep8, Shooting Stars | Simple reaction, basic timing |

All environments support:

  • Customizable frame skip and action repeat
  • Configurable episode lengths
  • Built-in score tracking
  • Multiple rendering modes
  • Frame stacking for temporal information
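
For readers unfamiliar with frame stacking, the sketch below shows the general technique as a fixed-size rolling buffer in JAX. It is a generic illustration, not OCTAX's implementation, and `push_frame` is a hypothetical helper name.

```python
import jax.numpy as jnp

def push_frame(stack: jnp.ndarray, frame: jnp.ndarray) -> jnp.ndarray:
    """Drop the oldest frame and append the newest (fixed-size rolling buffer)."""
    return jnp.concatenate([stack[1:], frame[None]], axis=0)

# Start from an empty stack of four 32x64 monochrome frames.
stack = jnp.zeros((4, 32, 64), dtype=jnp.bool_)
new_frame = jnp.ones((32, 64), dtype=jnp.bool_)

stack = push_frame(stack, new_frame)
print(stack.shape)  # shape is unchanged; only the contents shift
```

Because the output shape never changes, this update is compatible with `jax.jit` and can be carried through `jax.lax.scan`.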

Performance

OCTAX achieves substantial speedups over traditional CPU-based environments through JAX vectorization:

Figure: OCTAX vs EnvPool performance scaling across parallelization levels (RTX 3090)

Key Results:

  • 14× faster than EnvPool at 8192 parallel environments
  • 350,000 steps/second on a single RTX 3090
  • Near-linear scaling up to GPU memory limits
  • Sub-second compilation for fast iteration
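
Throughput figures like these depend on how the measurement is taken. The sketch below shows one common way to time a jitted JAX rollout: run once to compile, then time a second call and wait for the result with `block_until_ready`. The `toy_step` function is a stand-in and not the OCTAX emulator, so the number it prints says nothing about OCTAX itself.

```python
import time

import jax
import jax.numpy as jnp

NUM_ENVS = 1024
NUM_STEPS = 100

def toy_step(states, _):
    # Stand-in for a batched environment step; NOT the OCTAX emulator.
    return (states * 3 + 1) % 251, None

@jax.jit
def run(states):
    final, _ = jax.lax.scan(toy_step, states, None, length=NUM_STEPS)
    return final

states = jnp.arange(NUM_ENVS, dtype=jnp.int32)

# Warm-up call: triggers compilation so it is excluded from the timing.
run(states).block_until_ready()

start = time.perf_counter()
# JAX dispatches asynchronously, so wait for the result before stopping the clock.
out = run(states).block_until_ready()
elapsed = time.perf_counter() - start

steps_per_second = NUM_ENVS * NUM_STEPS / elapsed
print(f"{steps_per_second:,.0f} toy steps/second")
```

Leaving out the warm-up call or the `block_until_ready` call tends to produce misleading numbers, in opposite directions.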

This enables:

  • Comprehensive experiments: Run 100+ hyperparameter configurations overnight
  • Statistical rigor: Train with 50+ random seeds for reliable results
  • Rapid prototyping: Iterate on algorithms with immediate feedback

Training Results

PPO and PQN training across 16 diverse games (5M timesteps, 12 seeds each):

Figure: PPO and PQN learning curves showing diverse challenges across game genres
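
To give a sense of scale, the following back-of-the-envelope calculation estimates the total environment steps in this benchmark. It assumes each of the 16 games is trained with both algorithms and 12 seeds per algorithm, and it reuses the 350,000 steps/second figure quoted above as if it applied throughout. The result covers environment stepping only and ignores network updates, so real training time is higher.

```python
games = 16
algorithms = 2               # PPO and PQN
seeds = 12
timesteps_per_run = 5_000_000
steps_per_second = 350_000   # quoted throughput on a single RTX 3090

runs = games * algorithms * seeds
total_steps = runs * timesteps_per_run
stepping_hours = total_steps / steps_per_second / 3600

print(f"{runs} runs, {total_steps:,} environment steps")
print(f"~{stepping_hours:.1f} hours of pure environment stepping")
```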

Project Structure

octax/
├── octax/                  # Core package
│   ├── emulator.py        # CHIP-8 emulator implementation
│   ├── env.py             # RL environment wrapper
│   ├── environments/      # Game-specific configurations
│   ├── instructions/      # CHIP-8 instruction handlers
│   ├── rendering.py       # Visualization utilities
│   └── wrappers.py        # Gymnax compatibility wrapper
├── examples/              # Usage examples
│   ├── example.py        # Basic rollout
│   ├── training_on_octax.py  # PPO training demo
│   └── rendering_demo.py # Visualization examples
├── tests/                 # Comprehensive test suite
└── tutorial/              # In-depth tutorials

Contributing

We welcome contributions! Ways to contribute:

Adding New Games

  1. Find a CHIP-8 ROM: Many public-domain games are available

  2. Analyze the game: Use our interactive emulator (main.py) to identify:

    • Score registers (look for BCD operations marked with 🎯)
    • Termination conditions (game over states)
    • Required controls (which keys are used)
  3. Create environment file: See octax/environments/ for examples

  4. Test and submit: Ensure score/termination work correctly
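
BCD operations are a useful hint in step 2 because the CHIP-8 instruction FX33 writes the decimal digits of a register (hundreds, tens, ones) to memory, which games typically do just before drawing a score. The sketch below reproduces that digit split in JAX; `bcd_digits` is a hypothetical helper for illustration, not an OCTAX function.

```python
import jax.numpy as jnp

def bcd_digits(value):
    """Split an 8-bit value (0-255) into hundreds, tens, ones, as FX33 does."""
    value = jnp.asarray(value, dtype=jnp.int32)
    return jnp.stack([value // 100, (value // 10) % 10, value % 10])

print(bcd_digits(137))
```

A register that is repeatedly passed through this conversion during play is a strong candidate for the score register.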

# octax/environments/my_game.py
from octax import EmulatorState

rom_file = "my_game.ch8"

def score_fn(state: EmulatorState) -> float:
    return state.V[5]  # Score in register V5

def terminated_fn(state: EmulatorState) -> bool:
    return state.V[12] == 0  # Game over when V12 reaches 0

action_set = [4, 6]  # Left/Right controls
startup_instructions = 500  # Skip menu screens

metadata = {
    "title": "My Game",
    "description": "A classic arcade game",
    "release": "2024",
    "authors": ["Author Name"]
}
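
One common way to turn a score function like the one above into a per-step reward is to take the change in score between consecutive states. The sketch below illustrates that idea with a stand-in state type. `ToyState` and `reward_fn` are hypothetical, and whether OCTAX derives rewards this way is an assumption.

```python
from typing import NamedTuple

import jax.numpy as jnp

class ToyState(NamedTuple):
    """Stand-in for EmulatorState holding only the V registers (hypothetical)."""
    V: jnp.ndarray  # 16 general-purpose 8-bit registers

def score_fn(state: ToyState) -> jnp.ndarray:
    # Cast before any arithmetic so uint8 subtraction cannot wrap around.
    return state.V[5].astype(jnp.float32)

def reward_fn(prev_state: ToyState, next_state: ToyState) -> jnp.ndarray:
    # Reward is the change in score between consecutive steps.
    return score_fn(next_state) - score_fn(prev_state)

before = ToyState(V=jnp.zeros(16, dtype=jnp.uint8).at[5].set(10))
after = ToyState(V=jnp.zeros(16, dtype=jnp.uint8).at[5].set(13))
print(reward_fn(before, after))
```

Casting to float before subtracting matters here: CHIP-8 registers are unsigned 8-bit values, so a decreasing score would otherwise wrap around to a large positive number.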

Other Contributions

  • Bug fixes and performance improvements
  • Documentation enhancements
  • Additional examples and tutorials
  • New features (Super-CHIP8 support, etc.)

Please open an issue to discuss major changes before implementing.

Citation

If you use OCTAX in your research, please cite our paper:

@misc{radji2025octax,
    title={Octax: Accelerated CHIP-8 Arcade Environments for Reinforcement Learning in JAX},
    author={Waris Radji and Thomas Michel and Hector Piteau},
    year={2025},
    eprint={2510.01764},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

License

OCTAX is released under the MIT License. See LICENSE for details.

Acknowledgments


Made with ❤️ for the RL research community
