Port of David Ha's Slime volleyball for gymnasium.
Project description
Slime Volleyball
A Gymnasium-compatible multi-agent environment for the classic Slime Volleyball game, with dual numpy/JAX backends. Originally developed by David Ha and ported from hardmaru/slimevolleygym.
Two agents (agent_left, agent_right) play volleyball. The game ends when one agent loses all lives (default: 1) or after max_steps (default: 3000) timesteps.
Installation
pip install -e .
Dependencies: numpy, gymnasium
Optional: jax / jaxlib (JAX backend), pyglet (rendering), opencv-python (pixel observations)
Quick Start
Multi-Agent (PettingZoo-style)
from slime_volleyball.slimevolley_env import SlimeVolleyEnv
env = SlimeVolleyEnv()
obs, info = env.reset(seed=42)
for _ in range(1000):
actions = {"agent_left": env.action_space["agent_left"].sample(),
"agent_right": env.action_space["agent_right"].sample()}
obs, rewards, terminateds, truncateds, info = env.step(actions)
if terminateds["__all__"] or truncateds["__all__"]:
obs, info = env.reset()
Single-Agent (Right Agent vs Baseline)
When you pass a single integer action, the right agent is controlled by that action and the left agent uses the built-in baseline RNN policy:
env = SlimeVolleyEnv()
obs, info = env.reset(seed=42)
obs, rewards, terminateds, truncateds, info = env.step(3) # right agent jumps
Boost Mode
from slime_volleyball.slimevolley_boost_env import SlimeVolleyBoostEnv
env = SlimeVolleyBoostEnv()
obs, info = env.reset(seed=42)
# 16-dim observations, Discrete(13) actions (adds boost/powerup variants)
Dual Backend
The environment supports two backends selected at construction time. When backend="numpy" (default), no JAX code runs and JAX does not need to be installed. When backend="jax", the environment uses jax.numpy for all array operations and exposes JIT-compiled functions for high-performance training.
# NumPy backend (default) — no JAX required
env = SlimeVolleyEnv()
# JAX backend
env = SlimeVolleyEnv(config={"backend": "jax"})
Internally, array operations go through a backend proxy (from slime_volleyball.backend import xp) that resolves to numpy or jax.numpy depending on the active backend. The entire game physics and observation pipeline is implemented as pure functions operating on an immutable EnvState dataclass.
JAX Functional API
With backend="jax", the environment exposes jax_step and jax_reset — JIT-compiled pure functions compatible with jax.vmap and jax.lax.scan for vectorized training (JaxMARL-style).
import jax
import jax.numpy as jnp
from slime_volleyball.slimevolley_env import SlimeVolleyEnv
env = SlimeVolleyEnv(config={"backend": "jax"})
env.reset(seed=0)
step_fn = env.jax_step # (EnvState, actions) -> (obs, state, rewards, terms, truncs, infos)
reset_fn = env.jax_reset # (rng_key) -> (obs, state, infos)
Vectorized Environments (vmap)
num_envs = 64
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
# Parallel reset
obs, states, _ = jax.vmap(reset_fn)(keys)
# obs.shape == (64, 2, 12), one obs per agent per env
# Parallel step
actions = jnp.zeros((num_envs, 2, 3), dtype=jnp.float32)
obs, states, rewards, terms, truncs, _ = jax.vmap(step_fn)(states, actions)
Rollouts with lax.scan
def rollout_step(state, _):
actions = jnp.zeros((2, 3), dtype=jnp.float32)
obs, new_state, rewards, _, _, _ = step_fn(state, actions)
return new_state, rewards
single_state = jax.tree_util.tree_map(lambda x: x[0], states)
final_state, all_rewards = jax.lax.scan(rollout_step, single_state, None, length=1000)
# all_rewards.shape == (1000, 2)
Observations
Observations are perspective-normalized: each agent sees the world as if it's on a consistent side (x-values and velocities are reflected by the agent's direction).
Base (12-dim): [x, y, vx, vy, ball_x, ball_y, ball_vx, ball_vy, opp_x, opp_y, opp_vx, opp_vy] — all divided by a scale factor of 10.
Boost (16-dim): Appends [powerup_avail, powerup_timer, opp_powerup_avail, opp_powerup_timer].
Actions
Base — Discrete(6):
| Index | Action | [fwd, bwd, jump] |
|---|---|---|
| 0 | Noop | [0, 0, 0] |
| 1 | Left | [1, 0, 0] |
| 2 | UpLeft | [1, 0, 1] |
| 3 | Up | [0, 0, 1] |
| 4 | UpRight | [0, 1, 1] |
| 5 | Right | [0, 1, 0] |
Boost — Discrete(13): Indices 0–5 as above, indices 6–12 repeat with boost activated ([fwd, bwd, jump, 1]).
The functional step pipeline takes continuous [fwd, bwd, jump] arrays directly (shape (2, 3) or (2, 4) with boost). The discrete-to-continuous mapping is handled by the environment wrapper.
Configuration
Pass a config dict to the environment constructor:
env = SlimeVolleyEnv(config={
"backend": "numpy", # "numpy" or "jax"
"max_steps": 3000, # max timesteps per episode
"survival_reward": False, # add +0.01 reward per step survived
"from_pixels": False, # pixel observations (numpy only)
"human_inputs": False, # invert left agent controls for human play
"seed": None, # initial RNG seed
})
EnvState
All mutable game state is captured in a single frozen dataclass (EnvState). State transitions use dataclasses.replace() to produce new states. On the JAX backend, EnvState is registered as a pytree for jit/vmap/scan compatibility.
| Field | Shape | Description |
|---|---|---|
ball_pos |
(2,) |
Ball position [x, y] |
ball_vel |
(2,) |
Ball velocity [vx, vy] |
ball_prev_pos |
(2,) |
Previous ball position |
agent_pos |
(2, 2) |
Agent positions [[x,y], [x,y]] |
agent_vel |
(2, 2) |
Agent velocities |
agent_desired_vel |
(2, 2) |
Desired velocities from actions |
agent_life |
(2,) |
Lives remaining per agent |
agent_powerup_avail |
(2,) |
Powerups available |
agent_powerup_timer |
(2,) |
Powerup countdown timer |
delay_life |
() |
DelayScreen frames remaining |
time |
() |
Timestep counter |
done |
() |
Episode done flag |
rng_key |
— | JAX PRNGKey or numpy Generator |
Project Structure
slime_volleyball/
├── slimevolley_env.py # SlimeVolleyEnv (Gymnasium wrapper)
├── slimevolley_boost_env.py # SlimeVolleyBoostEnv (boost/powerup variant)
├── baseline_policy.py # 120-parameter RNN baseline opponent
├── rendering.py # Pyglet-based 2D rendering
├── backend/
│ ├── __init__.py # xp proxy (numpy or jax.numpy)
│ ├── _dispatch.py # set_backend(), get_backend(), BackendProxy
│ ├── array_ops.py # set_at() for JAX .at[].set() vs numpy indexing
│ └── env_state.py # EnvState frozen dataclass + pytree registration
├── core/
│ ├── constants.py # Physics constants, rendering dimensions
│ ├── physics.py # Pure functional physics (agents, ball, collisions)
│ ├── observations.py # Perspective-normalized observation computation
│ ├── step_pipeline.py # Functional reset/step + factory builders
│ ├── game.py # Legacy OOP game engine (rendering only)
│ ├── agent.py # Legacy Agent class (rendering only)
│ ├── objects.py # Particle, Wall, RelativeState (rendering only)
│ └── utils.py # Coordinate conversion, DelayScreen
tests/
├── test_backend.py # Backend dispatch tests
├── test_env_state.py # EnvState + JAX pytree tests
├── test_physics.py # Physics function tests
├── test_observations.py # Observation computation tests
├── test_step_pipeline.py # Functional reset/step tests (numpy + JAX)
├── test_env_wrapper.py # Environment wrapper tests (numpy, JAX, boost)
├── test_rendering.py # Rendering state sync tests
└── test_cross_backend.py # Cross-backend parity tests
examples/
├── train_slimevolley_jax.py # IPPO training script (JAX)
├── episode.gif # Trained policy visualization
├── policy.onnx # ONNX-exported trained policy
└── slimevolley_training.png # Training reward curve
Training
An IPPO (Independent PPO) training script is provided in examples/train_slimevolley_jax.py. It uses the JAX functional API with jax.vmap for parallel environments and jax.lax.scan for the training loop.
python examples/train_slimevolley_jax.py
This produces:
slimevolley_training.png— reward curve over trainingepisode.gif— GIF of the trained policy playingpolicy.onnx— ONNX export of the trained actor network
Interactive Visualization
python visualize_env.py
Uses the legacy OOP game engine with pyglet rendering. Keyboard controls:
| Key | Agent | Action |
|---|---|---|
| Left Arrow | Right agent | Left |
| Right Arrow | Right agent | Right |
| Up Arrow | Right agent | Jump |
| A | Left agent | Left |
| D | Left agent | Right |
| W | Left agent | Jump |
Running Tests
pytest tests/
Credits
- Original game: David Ha (blog post)
- Original gym environment: hardmaru/slimevolleygym
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file slimevb-0.1.1.tar.gz.
File metadata
- Download URL: slimevb-0.1.1.tar.gz
- Upload date:
- Size: 57.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
510fbd2a7d7cd9336d990f0a2bf09c0e005fe5282d04efb4a848fa77488d49e4
|
|
| MD5 |
ac80cdcec792656775b7d6bb13bfcc72
|
|
| BLAKE2b-256 |
3cf948ac81d49584b910dc295dc035e3b595884cb8d1ee59249e5e42a244e291
|
File details
Details for the file slimevb-0.1.1-py3-none-any.whl.
File metadata
- Download URL: slimevb-0.1.1-py3-none-any.whl
- Upload date:
- Size: 42.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ba8ba977fa753c05278b77f979302b9f17a15db2f87df2ec40e7d408309b822e
|
|
| MD5 |
3ce8ea786755116aba9fff507ca91096
|
|
| BLAKE2b-256 |
6b0b976eaf10e1778883e66b02c9d3727e0cd315a0de933010962c35112af05e
|