A flexible JAX RL protocol.
Why Parallax?
JAX RL environments need pure functions and immutable state, but there's no standard for what that looks like. Parallax defines a minimal reset/step contract so any environment exposes the same interface.
- For JAX RL users: Write agents, experience collection, and training loops once. Swap environments without changing your code.
- For Gymnasium users: The same familiar concepts (reset, step, observation, reward) rebuilt for JAX. Pure functions instead of mutable objects, so everything works with jit, vmap, and scan.
Protocol, not a framework. No base class, no registration.
Install
pip install parallax-rl
# With adapter dependencies
pip install parallax-rl[brax] # Brax environments
pip install parallax-rl[gymnax] # Gymnax environments
pip install parallax-rl[mjx] # MuJoCo Playground (MJX) environments
pip install parallax-rl[adapters] # All adapters
Quick Start
import jax

env = GridWorld()
state = env.reset(key=jax.random.key(0))
for _ in range(200):
    action = agent(state.observation)
    state = env.step(state, action)
    if state.done:
        break
How It Works
RL environments are conventionally stateful (Gymnasium, PettingZoo, etc.). Calling env.step() mutates the environment in place. JAX needs pure functions and immutable data, so Parallax splits things in two:
Env is stateless. It has two pure functions (reset and step) with no internal state.
State is a JAX pytree that holds all the data. Every call to reset or step returns a new State with precomputed fields:
state = env.reset(key=jax.random.key(0))
state = env.step(state, action)
state.env_state # raw environment data (any pytree)
state.observation # what the agent sees
state.reward # scalar reward
state.termination # episode ended naturally
state.truncation # episode was cut short
state.done # termination | truncation
state.info # extra metadata (dict)
state.step_count # current timestep
state.key # JAX RNG key
State is pure data. All values are computed in reset/step and stored directly.
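The split between a stateless Env and an immutable State can be illustrated without Parallax itself. The sketch below is a toy under stated assumptions: `MiniState` and `CounterEnv` are hypothetical names made up for illustration and are not part of the library, and `MiniState` carries only two fields rather than the full set listed above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MiniState(NamedTuple):
    """Hypothetical stand-in for a Parallax State (not the library's class)."""
    position: jax.Array
    step_count: jax.Array


class CounterEnv:
    """Stateless toy environment: all data lives in MiniState."""

    def reset(self, *, key: jax.Array) -> MiniState:
        del key  # deterministic toy example, so the key is unused
        return MiniState(position=jnp.float32(0.0), step_count=jnp.int32(0))

    def step(self, state: MiniState, action: jax.Array) -> MiniState:
        # Build and return a new state; the input state is never modified.
        return MiniState(
            position=state.position + action,
            step_count=state.step_count + 1,
        )


env = CounterEnv()
step = jax.jit(env.step)  # step is a pure function, so it can be compiled directly

s0 = env.reset(key=jax.random.key(0))
s1 = step(s0, jnp.float32(1.5))
```

After the call, `s0` still holds the initial values and `s1` holds the updated ones, which is what lets the same function be used under jit, vmap, and scan.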
Building an Environment
Implement reset and step. Each returns a State with all fields computed:
import jax
import jax.numpy as jnp
from typing import NamedTuple
from jaxtyping import Array, PRNGKeyArray
from parallax import Space, State, spaces

class GridState(NamedTuple):
    pos: Array
    goal: Array

class GridWorld:
    action_space: Space = spaces.Discrete(4)
    observation_space: Space = spaces.Box(0.0, 4.0, (4,))

    def reset(self, *, key: PRNGKeyArray) -> State:
        key, goal_key = jax.random.split(key)
        pos = jnp.zeros(2, dtype=jnp.float32)
        goal = jax.random.randint(goal_key, (2,), minval=1, maxval=5).astype(jnp.float32)
        return State(
            env_state=GridState(pos=pos, goal=goal),
            observation=jnp.concatenate([pos, goal]),
            reward=jnp.float32(0.0),
            termination=jnp.bool_(False),
            truncation=jnp.bool_(False),
            info={},
            step_count=jnp.int32(0),
            key=key,
        )

    def step(self, state: State, action: Array) -> State:
        moves = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]], dtype=jnp.float32)
        pos = jnp.clip(state.env_state.pos + moves[action], 0.0, 4.0)
        goal = state.env_state.goal
        return State(
            env_state=GridState(pos=pos, goal=goal),
            observation=jnp.concatenate([pos, goal]),
            reward=jnp.exp(-jnp.linalg.norm(pos - goal)),
            termination=jnp.all(pos == goal),
            truncation=jnp.bool_(False),
            info={},
            step_count=state.step_count + 1,
            key=state.key,
        )
env_state is your raw environment data and can be any JAX pytree. The other fields (observation, reward, etc.) are derived from it in reset/step.
For multi-agent environments, agents are a dimension on your arrays. Reward, termination, and truncation become shape (num_agents,) while the method signatures stay the same. Environments where agents have different action space sizes will need padding and masking to keep array shapes fixed. This is a JAX constraint (compiled functions need fixed array shapes) rather than a Parallax one.
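One way to handle agents with different action space sizes is to pad every agent to the largest size and mask out the invalid entries when sampling. The sketch below uses only plain JAX; `MAX_ACTIONS`, `num_valid`, and `sample_masked` are hypothetical names chosen for illustration, not Parallax API.

```python
import jax
import jax.numpy as jnp

MAX_ACTIONS = 4  # pad every agent's action space to the largest size

# Hypothetical setup: agent 0 has 2 valid actions, agent 1 has 4.
num_valid = jnp.array([2, 4])
# Boolean mask of shape (num_agents, MAX_ACTIONS); True marks a valid action.
action_mask = jnp.arange(MAX_ACTIONS)[None, :] < num_valid[:, None]


def sample_masked(key, logits, mask):
    # Invalid actions get -inf logits, so they are never sampled.
    masked_logits = jnp.where(mask, logits, -jnp.inf)
    return jax.random.categorical(key, masked_logits, axis=-1)


logits = jnp.zeros((2, MAX_ACTIONS))  # uniform policy over the padded space
actions = sample_masked(jax.random.key(0), logits, action_mask)  # shape (num_agents,)
```

Every array keeps a fixed shape regardless of how many actions each agent really has, so the function stays compatible with jit and vmap.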
Wrappers
Wrappers compose to add functionality:
from parallax import AutoResetWrapper, TimeLimit, VmapWrapper
num_envs = 128
env = VmapWrapper(AutoResetWrapper(TimeLimit(GridWorld(), max_steps=200)), num_envs=num_envs)
state = env.reset(key=jax.random.key(0))
state = env.step(state, actions)
For manual resets (e.g. when you need terminal observations for value bootstrapping):
env = VmapWrapper(TimeLimit(GridWorld(), max_steps=200), num_envs=num_envs)
state = env.step(state, actions)
state = env.reset(key=reset_key, state=state, done=state.done)
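A reset that only affects finished environments can be thought of as a per-environment selection between a freshly reset state and the current one. The sketch below shows that idea with plain JAX on a dictionary pytree. It is an assumption about how such a selection could be written, not Parallax's actual implementation, and `select_where_done` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def select_where_done(done, fresh, current):
    """Take leaves from `fresh` where `done` is True, else from `current`."""

    def pick(new, old):
        # Reshape done from (num_envs,) to (num_envs, 1, ...) so it broadcasts
        # against leaves that have extra trailing dimensions.
        mask = done.reshape(done.shape + (1,) * (old.ndim - done.ndim))
        return jnp.where(mask, new, old)

    return jax.tree_util.tree_map(pick, fresh, current)


# Three environments; only the middle one has finished its episode.
current = {
    "obs": jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]),
    "step_count": jnp.array([5, 9, 2]),
}
fresh = {"obs": jnp.zeros((3, 2)), "step_count": jnp.zeros(3, dtype=jnp.int32)}
done = jnp.array([False, True, False])

merged = select_where_done(done, fresh, current)
```

Because the selection happens after `step` returns, the terminal observation in `current` can be read before it is overwritten.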
Adapters
Use existing JAX RL environments with Parallax via adapters:
import gymnax
from parallax import VmapWrapper
from parallax.adapters import GymnaxAdapter

env = GymnaxAdapter(gymnax.make("CartPole-v1")[0])
env = VmapWrapper(env, num_envs=128)

import brax.envs
from parallax.adapters import BraxAdapter

env = BraxAdapter(brax.envs.get_environment("ant"))
env = VmapWrapper(env, num_envs=128)

from mujoco_playground import registry
from parallax.adapters import MJXAdapter

env = MJXAdapter(registry.load("HumanoidWalk", config_overrides={"impl": "jax"}))
env = VmapWrapper(env, num_envs=128)
Adapters map foreign reset/step APIs to the Parallax protocol. Brax and MJX adapters extract episode length from the underlying environment and handle truncation internally. Brax's built-in auto-reset is stripped automatically to preserve terminal observations.
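Handling truncation from an episode length comes down to comparing the step counter against the limit. The sketch below shows one plausible way to derive the flags. `apply_time_limit` is a hypothetical helper, and the choice to leave truncation False when the episode also terminated naturally is an assumption, not something the adapters are documented to do.

```python
import jax.numpy as jnp


def apply_time_limit(termination, step_count, max_steps):
    # Assumption: an episode that terminates naturally on the final step
    # counts as terminated, not truncated.
    truncation = (step_count >= max_steps) & ~termination
    done = termination | truncation
    return truncation, done


termination = jnp.array([False, True, False])
step_count = jnp.array([200, 200, 10])
truncation, done = apply_time_limit(termination, step_count, max_steps=200)
```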
Custom Properties
Subclass State to add extra fields. For example, adding an action mask to GridWorld:
from dataclasses import dataclass

import jax
from jaxtyping import Array, Bool
from parallax import State

@jax.tree_util.register_dataclass
@dataclass
class MaskedState(State):
    action_mask: Bool[Array, "4"]
Then return MaskedState from your env's reset and step:
class MaskedGridWorld(GridWorld):
    def reset(self, *, key: PRNGKeyArray) -> MaskedState:
        state = super().reset(key=key)
        return MaskedState(**vars(state), action_mask=compute_mask(state.env_state))

    def step(self, state: MaskedState, action: Array) -> MaskedState:
        state = super().step(state, action)
        return MaskedState(**vars(state), action_mask=compute_mask(state.env_state))

state.action_mask  # fully typed, works with jit/vmap/wrappers
Collecting Experience
Use jax.lax.scan for vectorized rollouts. Manual resets let you capture terminal observations before resetting done environments, which is needed for value bootstrapping:
from dataclasses import dataclass

import jax
from parallax import VmapWrapper

@jax.tree_util.register_dataclass
@dataclass
class Experience:
    observation: jax.Array
    next_observation: jax.Array
    action: jax.Array
    reward: jax.Array
    termination: jax.Array

num_envs = 128
env = VmapWrapper(GridWorld(), num_envs=num_envs)

key = jax.random.key(0)
key, reset_key = jax.random.split(key)
state = env.reset(key=reset_key)
obs = state.observation

def step_fn(carry, _):
    state, obs, key = carry
    key, action_key, reset_key = jax.random.split(key, 3)
    # Sample one random action per environment
    action = jax.vmap(env.action_space.sample)(key=jax.random.split(action_key, num_envs))
    state = env.step(state, action)
    next_obs = state.observation
    experience = Experience(
        observation=obs,
        next_observation=next_obs,
        action=action,
        reward=state.reward,
        termination=state.termination,
    )
    # Reset environments where done, terminal obs captured above
    state = env.reset(key=reset_key, state=state, done=state.done)
    obs = state.observation
    return (state, obs, key), experience

(state, obs, key), experiences = jax.lax.scan(step_fn, (state, obs, key), None, length=256)
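The reason to record `termination` and the true `next_observation` is that a one-step value target bootstraps from the next state unless the episode actually terminated. A truncated episode still bootstraps. The sketch below shows the standard one-step TD target. `td_targets` is a hypothetical helper, and `next_value` stands in for the output of a value network evaluated on `next_observation`, which is not defined in this document.

```python
import jax.numpy as jnp


def td_targets(reward, termination, next_value, gamma=0.99):
    # Bootstrap from V(next_obs) unless the episode truly terminated.
    # Truncated episodes still bootstrap, which is why the terminal
    # observation has to be kept.
    not_terminated = 1.0 - termination.astype(jnp.float32)
    return reward + gamma * not_terminated * next_value


reward = jnp.array([1.0, 1.0])
termination = jnp.array([True, False])
next_value = jnp.array([10.0, 10.0])  # placeholder for V(next_observation)
targets = td_targets(reward, termination, next_value, gamma=0.5)
```

Here the first environment terminated, so its target is the reward alone, while the second adds the discounted next-state value.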