
A flexible JAX RL protocol.

Project description

Parallax

A JAX Reinforcement Learning Protocol



Why Parallax?

JAX RL environments need pure functions and immutable state, but there's no standard for what that looks like. Parallax defines a minimal reset/step contract so any environment exposes the same interface.

  • For JAX RL users: Write agents, experience collection, and training loops once. Swap environments without changing your code.
  • For Gymnasium users: The same familiar concepts (reset, step, observation, reward) rebuilt for JAX. Pure functions instead of mutable objects, so everything works with jit, vmap, and scan.

Protocol, not a framework. No base class, no registration.
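One way to read "protocol, not a framework" is as structural typing: any object whose reset and step methods have the right shape qualifies, with no inheritance involved. The sketch below expresses that idea with Python's typing.Protocol. EnvLike, CountingEnv, and NotAnEnv are hypothetical names for illustration, not part of Parallax, and the real protocol's type annotations may differ.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class EnvLike(Protocol):
    """Hypothetical structural type for the reset/step contract (not Parallax's own definition)."""

    def reset(self, *, key: Any) -> Any: ...

    def step(self, state: Any, action: Any) -> Any: ...


class CountingEnv:
    # No base class and no registration: having matching methods is enough.
    def reset(self, *, key):
        return 0

    def step(self, state, action):
        return state + action


class NotAnEnv:
    # Only half the contract: there is no step method.
    def reset(self, *, key):
        return 0


print(isinstance(CountingEnv(), EnvLike))  # True
print(isinstance(NotAnEnv(), EnvLike))  # False
```

Note that runtime_checkable only checks that the methods exist, not their signatures; a static type checker is what enforces the full shape.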

Install

pip install parallax-rl

Quick Start

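import jax

# GridWorld is defined under "Building an Environment" below;
# agent is a placeholder for any function mapping an observation to an action.
env = GridWorld()
state = env.reset(key=jax.random.key(0))

for _ in range(200):
    action = agent(state.observation)
    state = env.step(state, action)

    if state.done:  # outside jit, a scalar JAX bool converts to a Python bool
        break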
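The loop above runs eagerly, one Python iteration per step. Because reset and step are pure, a whole episode can also be compiled at once, for example with jax.lax.while_loop. A minimal sketch, assuming a hypothetical WalkToTen environment and a cut-down ToyState in place of GridWorld and parallax.State, with a fixed action standing in for the agent:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for parallax.State with only the fields this sketch needs.
    position: jax.Array
    done: jax.Array
    step_count: jax.Array


class WalkToTen:
    """Hypothetical 1-D environment: the episode ends when position reaches 10."""

    def reset(self, *, key):
        del key  # deterministic start
        return ToyState(jnp.int32(0), jnp.bool_(False), jnp.int32(0))

    def step(self, state, action):
        position = state.position + action
        return ToyState(position, position >= 10, state.step_count + 1)


def run_episode(env, key, max_steps=200):
    state = env.reset(key=key)

    def keep_going(state):
        # Stop on done or when the step budget runs out.
        return (~state.done) & (state.step_count < max_steps)

    def body(state):
        action = jnp.int32(1)  # placeholder policy: always move right
        return env.step(state, action)

    return jax.lax.while_loop(keep_going, body, state)


# The env object is static (hashable, not traced); the key is traced.
final = jax.jit(run_episode, static_argnums=0)(WalkToTen(), jax.random.key(0))
print(int(final.step_count))  # 10
```

while_loop needs the carried state to keep identical shapes and dtypes every iteration, which is exactly what a fixed-structure State pytree provides.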

How It Works

RL environments are conventionally stateful (Gymnasium, PettingZoo, etc.). Calling env.step() mutates the environment in place. JAX needs pure functions and immutable data, so Parallax splits things in two:

Env is stateless. It has two pure functions (reset and step) with no internal state.

State is a JAX pytree that holds all the data. Every call to reset or step returns a new State with precomputed fields:

state = env.reset(key=jax.random.key(0))
state = env.step(state, action)

state.env_state    # raw environment data (any pytree)
state.observation  # what the agent sees
state.reward       # scalar reward
state.termination  # episode ended naturally
state.truncation   # episode was cut short
state.done         # termination | truncation
state.info         # extra metadata (dict)
state.step_count   # current timestep
state.key          # JAX RNG key

State is pure data. All values are computed in reset/step and stored directly.
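Statelessness is what lets jit and vmap apply directly: stepping returns a new state and leaves the old one intact, and a batch of environments is just a batched state. A minimal sketch, assuming a hypothetical RandomStartEnv and CoinState rather than Parallax's own classes:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CoinState(NamedTuple):
    # Hypothetical minimal state: every field is a JAX array, so the whole thing is a pytree.
    position: jax.Array
    key: jax.Array


class RandomStartEnv:
    """Hypothetical stateless env: it stores no data between calls."""

    def reset(self, *, key):
        key, start_key = jax.random.split(key)
        position = jax.random.randint(start_key, (), 0, 5)
        return CoinState(position=position, key=key)

    def step(self, state, action):
        # Returns a new state; the input state is never modified.
        return state._replace(position=state.position + action)


env = RandomStartEnv()
s0 = env.reset(key=jax.random.key(0))
s1 = env.step(s0, 2)
print(int(s1.position) - int(s0.position))  # 2 (s0 still holds the old position)

# Pure functions compose with jit and vmap: reset and step 8 environments at once.
keys = jax.random.split(jax.random.key(1), 8)
batch = jax.vmap(lambda k: env.reset(key=k))(keys)
batch = jax.jit(jax.vmap(env.step))(batch, jnp.ones(8, dtype=jnp.int32))
print(batch.position.shape)  # (8,)
```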

Building an Environment

Implement reset and step. Each returns a State with all fields computed:

import jax
import jax.numpy as jnp
from typing import NamedTuple
from jaxtyping import Array, PRNGKeyArray
from parallax import Space, State, spaces


class GridState(NamedTuple):
    pos: Array
    goal: Array


class GridWorld:
    action_space: Space = spaces.Discrete(4)
    observation_space: Space = spaces.Box(0.0, 4.0, (4,))

    def reset(self, *, key: PRNGKeyArray) -> State:
        key, goal_key = jax.random.split(key)
        pos = jnp.zeros(2, dtype=jnp.float32)
        goal = jax.random.randint(goal_key, (2,), minval=1, maxval=5).astype(jnp.float32)
        return State(
            env_state=GridState(pos=pos, goal=goal),
            observation=jnp.concatenate([pos, goal]),
            reward=jnp.float32(0.0),
            termination=jnp.bool_(False),
            truncation=jnp.bool_(False),
            info={},
            step_count=jnp.int32(0),
            key=key,
        )

    def step(self, state: State, action: Array) -> State:
        # One (dx, dy) offset per discrete action; clipping keeps the agent on the 5x5 grid.
        moves = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]], dtype=jnp.float32)
        pos = jnp.clip(state.env_state.pos + moves[action], 0.0, 4.0)
        goal = state.env_state.goal
        return State(
            env_state=GridState(pos=pos, goal=goal),
            observation=jnp.concatenate([pos, goal]),
            reward=jnp.exp(-jnp.linalg.norm(pos - goal)),
            termination=jnp.all(pos == goal),
            truncation=jnp.bool_(False),
            info={},
            step_count=state.step_count + 1,
            key=state.key,
        )

env_state is your raw environment data and can be any JAX pytree. The other fields (observation, reward, etc.) are derived from it in reset/step.

For multi-agent environments, agents are a dimension on your arrays: reward, termination, and truncation become shape (num_agents,) while the method signatures stay the same. Environments whose agents have different action space sizes need padding and masking to keep array shapes fixed. This is a JAX constraint (jit-compiled code requires static array shapes), not a Parallax one.
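Padding and masking could look like this for two agents with 3 and 5 actions: pad to the largest action count, build a (num_agents, max_actions) mask, and push invalid logits to -inf before sampling. The sizes, the zero logits, and the variable names are hypothetical; how a given environment exposes its mask is up to that environment.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: agent 0 has 3 actions, agent 1 has 5. Pad both to 5.
action_sizes = jnp.array([3, 5])
max_actions = 5
num_agents = action_sizes.shape[0]

# Boolean mask of shape (num_agents, max_actions): True where the action exists.
action_mask = jnp.arange(max_actions)[None, :] < action_sizes[:, None]

# Per-agent logits from some policy (zeros here), padded to the same width.
logits = jnp.zeros((num_agents, max_actions))
masked_logits = jnp.where(action_mask, logits, -jnp.inf)  # padded actions can never be sampled

keys = jax.random.split(jax.random.key(0), num_agents)
actions = jax.vmap(jax.random.categorical)(keys, masked_logits)

# Rewards (and terminations, truncations) are per-agent arrays of shape (num_agents,).
rewards = jnp.zeros(num_agents)
print(action_mask.astype(int))
print(actions.shape, rewards.shape)  # (2,) (2,)
```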

Wrappers

Wrappers compose to add functionality:

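import jax
import jax.numpy as jnp
from parallax import AutoResetWrapper, TimeLimit, VmapWrapper

num_envs = 128
# TimeLimit truncates long episodes, AutoResetWrapper resets finished ones, VmapWrapper batches.
env = VmapWrapper(AutoResetWrapper(TimeLimit(GridWorld(), max_steps=200)), num_envs=num_envs)
state = env.reset(key=jax.random.key(0))
actions = jnp.zeros(num_envs, dtype=jnp.int32)  # one action per environment
state = env.step(state, actions)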
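Wrappers compose because each one exposes the same reset/step signature as the environment it wraps. The sketch below shows how a time-limit wrapper could be written under that contract. TimeLimitSketch, NeverEnds, and MiniState are hypothetical, and this is not Parallax's TimeLimit implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MiniState(NamedTuple):
    # Hypothetical subset of parallax.State fields used by this sketch.
    termination: jax.Array
    truncation: jax.Array
    step_count: jax.Array


class NeverEnds:
    """Hypothetical inner env that never terminates on its own."""

    def reset(self, *, key):
        return MiniState(jnp.bool_(False), jnp.bool_(False), jnp.int32(0))

    def step(self, state, action):
        return state._replace(step_count=state.step_count + 1)


class TimeLimitSketch:
    """Illustrative wrapper: same reset/step signature, so wrappers can nest."""

    def __init__(self, env, max_steps):
        self.env = env
        self.max_steps = max_steps

    def reset(self, *, key):
        return self.env.reset(key=key)

    def step(self, state, action):
        state = self.env.step(state, action)
        # Flag truncation once the step budget is used up; termination is left untouched.
        return state._replace(truncation=state.truncation | (state.step_count >= self.max_steps))


env = TimeLimitSketch(NeverEnds(), max_steps=3)
state = env.reset(key=jax.random.key(0))
for _ in range(3):
    state = env.step(state, 0)
print(bool(state.truncation), bool(state.termination))  # True False
```

Keeping truncation separate from termination is what later lets a learner bootstrap through time-limit cutoffs.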

For manual resets (e.g. when you need terminal observations for value bootstrapping):

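env = VmapWrapper(TimeLimit(GridWorld(), max_steps=200), num_envs=num_envs)
key, reset_key = jax.random.split(jax.random.key(0))
state = env.step(env.reset(key=reset_key), actions)  # state.observation holds terminal observations
key, reset_key = jax.random.split(key)
state = env.reset(key=reset_key, state=state, done=state.done)  # resets only the done environments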
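A reset that takes state and done has to leave still-running environments untouched. One common technique, shown here as a sketch and not necessarily what Parallax does internally, is to build a fresh state for every environment and select leaf by leaf with jnp.where. BatchState, fresh_state, and masked_reset are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BatchState(NamedTuple):
    # Hypothetical batched state: the leading axis is the environment index.
    observation: jax.Array  # (num_envs, obs_dim)
    step_count: jax.Array  # (num_envs,)


def fresh_state(num_envs, obs_dim):
    # Stand-in for "reset every environment".
    return BatchState(jnp.zeros((num_envs, obs_dim)), jnp.zeros(num_envs, dtype=jnp.int32))


def masked_reset(state, reset_state, done):
    """Keep `state` where done is False, take `reset_state` where done is True."""

    def select(new, old):
        # Reshape done (num_envs,) so it broadcasts against each leaf's trailing dimensions.
        mask = done.reshape(done.shape + (1,) * (old.ndim - 1))
        return jnp.where(mask, new, old)

    return jax.tree_util.tree_map(select, reset_state, state)


state = BatchState(jnp.full((3, 2), 7.0), jnp.array([5, 9, 2], dtype=jnp.int32))
done = jnp.array([False, True, False])

terminal_obs = state.observation  # read terminal observations before resetting
state = masked_reset(state, fresh_state(3, 2), done)
print(state.step_count.tolist())  # [5, 0, 2]
```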

Adapters

Use existing JAX RL environments with Parallax via adapters:

import gymnax
from parallax import VmapWrapper
from parallax.adapters import GymnaxAdapter

# gymnax.make returns (env, env_params); the adapter wraps the env
env = GymnaxAdapter(gymnax.make("CartPole-v1")[0])
env = VmapWrapper(env, num_envs=128)

import brax.envs
from parallax import VmapWrapper
from parallax.adapters import BraxAdapter

env = BraxAdapter(brax.envs.get_environment("ant"))
env = VmapWrapper(env, num_envs=128)

Adapters map foreign reset/step APIs to the Parallax protocol. Brax's built-in auto-reset is stripped automatically to preserve terminal observations.
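An adapter mostly translates call shapes. The sketch below wraps a hypothetical TupleEnv, whose reset/step return tuples and expect an explicit key on each call (loosely modelled on tuple-returning APIs such as gymnax's), into the reset/step contract, threading the RNG key through the state. TupleEnvAdapter and AdaptedState are illustrative names, not the GymnaxAdapter or BraxAdapter source, and the sketch simplifies by treating the foreign done flag as termination.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class AdaptedState(NamedTuple):
    # Hypothetical subset of parallax.State fields.
    env_state: Any
    observation: jax.Array
    reward: jax.Array
    termination: jax.Array
    key: jax.Array


class TupleEnv:
    """Hypothetical foreign env with a tuple-returning, key-per-call API."""

    def reset(self, key):
        return jnp.zeros(2), {"t": jnp.int32(0)}

    def step(self, key, env_state, action):
        t = env_state["t"] + 1
        obs = jnp.full(2, t, dtype=jnp.float32)
        return obs, {"t": t}, jnp.float32(1.0), t >= 3, {}


class TupleEnvAdapter:
    """Maps the tuple API onto reset(*, key) / step(state, action) (illustrative only)."""

    def __init__(self, env):
        self.env = env

    def reset(self, *, key):
        key, reset_key = jax.random.split(key)
        obs, env_state = self.env.reset(reset_key)
        return AdaptedState(env_state, obs, jnp.float32(0.0), jnp.bool_(False), key)

    def step(self, state, action):
        # The foreign API wants a fresh key each step, so split the one carried in the state.
        key, step_key = jax.random.split(state.key)
        obs, env_state, reward, done, _ = self.env.step(step_key, state.env_state, action)
        return AdaptedState(env_state, obs, reward, done, key)


env = TupleEnvAdapter(TupleEnv())
state = env.reset(key=jax.random.key(0))
for _ in range(3):
    state = env.step(state, 0)
print(bool(state.termination), state.observation.tolist())  # True [3.0, 3.0]
```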

Custom Properties

Subclass State to add extra fields. For example, adding an action mask to GridWorld:

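from dataclasses import dataclass

import jax
from jaxtyping import Array, Bool
from parallax import State

@jax.tree_util.register_dataclass
@dataclass
class MaskedState(State):
    action_mask: Bool[Array, "4"]  # one flag per GridWorld action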

Then return MaskedState from your env's reset and step:

class MaskedGridWorld(GridWorld):
    # compute_mask is user-defined: it maps env_state to a Bool[Array, "4"] action mask
    def reset(self, *, key: PRNGKeyArray) -> MaskedState:
        state = super().reset(key=key)
        return MaskedState(**vars(state), action_mask=compute_mask(state.env_state))

    def step(self, state: MaskedState, action: Array) -> MaskedState:
        state = super().step(state, action)
        return MaskedState(**vars(state), action_mask=compute_mask(state.env_state))

state.action_mask  # fully typed, works with jit/vmap/wrappers
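compute_mask is left to you. For GridWorld, one hypothetical choice is to mark a move invalid when it would leave the 0 to 4 grid (where clipping would turn it into a no-op). A sketch, reusing the move table from GridWorld.step and a local copy of GridState so it runs on its own:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GridState(NamedTuple):
    # Local copy of the GridState defined above, so this sketch is self-contained.
    pos: jax.Array
    goal: jax.Array


# Same action-to-offset table as GridWorld.step: one (dx, dy) row per discrete action.
MOVES = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]], dtype=jnp.float32)


def compute_mask(env_state: GridState, low: float = 0.0, high: float = 4.0) -> jax.Array:
    """Hypothetical mask: True for each move whose target cell stays inside the grid."""
    next_pos = env_state.pos[None, :] + MOVES  # (4, 2): candidate positions
    return jnp.all((next_pos >= low) & (next_pos <= high), axis=-1)  # (4,)


corner = GridState(pos=jnp.zeros(2), goal=jnp.full(2, 3.0))
print(compute_mask(corner).tolist())  # [True, False, True, False]
```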

Collecting Experience

Use jax.lax.scan to run a vectorized environment for many steps inside one compiled loop. Manual resets let you capture terminal observations before resetting finished environments, which value bootstrapping needs:

from dataclasses import dataclass

import jax
from parallax import VmapWrapper

@jax.tree_util.register_dataclass
@dataclass
class Experience:
    observation: jax.Array
    next_observation: jax.Array
    action: jax.Array
    reward: jax.Array
    termination: jax.Array

num_envs = 128
env = VmapWrapper(GridWorld(), num_envs=num_envs)

key = jax.random.key(0)
key, reset_key = jax.random.split(key)
state = env.reset(key=reset_key)
obs = state.observation

def step_fn(carry, _):
    state, obs, key = carry
    key, action_key, reset_key = jax.random.split(key, 3)
    action = jax.vmap(env.action_space.sample)(key=jax.random.split(action_key, num_envs))

    state = env.step(state, action)
    next_obs = state.observation

    experience = Experience(
        observation=obs,
        next_observation=next_obs,
        action=action,
        reward=state.reward,
        termination=state.termination,
    )

    # Reset environments where done, terminal obs captured above
    state = env.reset(key=reset_key, state=state, done=state.done)
    obs = state.observation

    return (state, obs, key), experience

# scan stacks the per-step outputs: each Experience leaf has shape (256, num_envs, ...)
(state, obs, key), experiences = jax.lax.scan(step_fn, (state, obs, key), None, length=256)
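The stored termination flag and pre-reset next_observation are what a one-step TD target needs: bootstrap from the value of the next observation unless the episode truly terminated, so truncated transitions still bootstrap. A minimal sketch on hand-made arrays with the (num_steps, num_envs) layout that scan produces; td_targets and the numbers are hypothetical, and in practice next_value would come from a value network applied to experiences.next_observation.

```python
import jax.numpy as jnp


def td_targets(reward, termination, next_value, gamma=0.99):
    """One-step TD targets: r + gamma * (1 - termination) * V(next_obs).

    Only true termination stops bootstrapping; truncation does not, which is
    why the pre-reset next_observation is kept in each Experience.
    """
    return reward + gamma * (1.0 - termination.astype(reward.dtype)) * next_value


# Hypothetical stand-ins with the (num_steps, num_envs) layout produced by scan.
reward = jnp.array([[1.0, 0.0], [0.5, 1.0]])
termination = jnp.array([[False, True], [False, False]])
next_value = jnp.array([[10.0, 10.0], [2.0, 4.0]])

targets = td_targets(reward, termination, next_value, gamma=0.5)
print(targets.tolist())  # [[6.0, 0.0], [1.5, 3.0]]
```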



Download files


Source Distribution

parallax_rl-0.2.3.tar.gz (18.7 kB)

Uploaded Source

Built Distribution


parallax_rl-0.2.3-py3-none-any.whl (14.2 kB)

Uploaded Python 3

File details

Details for the file parallax_rl-0.2.3.tar.gz.

File metadata

  • Download URL: parallax_rl-0.2.3.tar.gz
  • Upload date:
  • Size: 18.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for parallax_rl-0.2.3.tar.gz
Algorithm    Hash digest
SHA256       945b4c569eaecfcc94a3b3d0a3b78f746f4c795b5cc26b79dcfd18ab36f84a99
MD5          8df0cbc6c55c55c5835e8d44e5e121d9
BLAKE2b-256  15ee26de26cc28b33cfd819678fcad12fedeeba11b40af15a5d3be35280401f7


Provenance

The following attestation bundles were made for parallax_rl-0.2.3.tar.gz:

Publisher: publish.yml on Auxeno/parallax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file parallax_rl-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: parallax_rl-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 14.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for parallax_rl-0.2.3-py3-none-any.whl
Algorithm    Hash digest
SHA256       ff532dbabeccb4129e8094d8bcdde9c496961fe4fec79a8e467e6541fd17910f
MD5          6c40ac7ae995699ac7291a2fb5236cb4
BLAKE2b-256  dfb29497e7cb1b86cf89d9ff3aa42c0fe38337f46b67dfc3a52252ba5e4331e7


Provenance

The following attestation bundles were made for parallax_rl-0.2.3-py3-none-any.whl:

Publisher: publish.yml on Auxeno/parallax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
