Gymnasium-style API standard for RL environment creation in JAX.
Envrax is a lightweight open-source JAX-native Reinforcement Learning (RL) environment API standard for single-agents, equivalent to the Gymnasium package. It includes: base classes, spaces, wrappers, and a shared registry for building and utilizing RL environments with ease.
All environment logic follows a stateless functional design that builds on top of the JAX and Chex packages to benefit from JAX accelerator efficiency.
Why Envrax?
One of the persistent challenges in RL research is sample efficiency. Often the environment becomes the main bottleneck for model training because it is built around, and restricted to, CPU execution.
For example, the Atari suite is CPU constrained and, in our experience, increasing the number of environments running in parallel drastically increases the wall-clock time of a single training step. Gradient computations on a GPU could take ~30 seconds while sample retrieval takes over 2 minutes (at least 4× as long) because of the CPU bottleneck, and that's with efficiency tricks!
This raised a much deeper question:
What if we could eliminate the CPU bottleneck by loading the environment onto the same accelerator as the model?
Packages like Brax and Gymnax have shown the incredible benefits of JAX based environment approaches. However, they are limited to their unique approaches without a unified API standard. Gymnasium has always been a personal favourite of mine because of its API simplicity, but there is no JAX equivalent. Thus, Envrax was born.
Requirements
- Python 3.13+
- JAX 0.9+ (CPU, CUDA, or TPU backend)
Installation
pip install envrax
Or from source with uv:
git clone https://github.com/Achronus/envrax
cd envrax
uv sync
API Standard
Envrax enforces a small, strict interface so that every environment, regardless of the suite created, behaves identically under jax.jit, jax.vmap, and jax.lax.scan.
Every environment is a stateless Python object. Environment states (envrax.EnvState) are explicit chex.dataclass PyTrees that are passed to and returned from every call, so the full reset → step pipeline works under jax.jit, jax.vmap, and jax.lax.scan without modification.
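To make the reset → step pipeline concrete, here is a minimal sketch of a jitted rollout driven by jax.lax.scan. It does not use Envrax itself: ToyState, toy_reset, toy_step, and rollout are hypothetical names, and a NamedTuple stands in for a chex.dataclass PyTree.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical state; a NamedTuple stands in for a chex.dataclass."""
    rng: jax.Array
    step: jax.Array
    pos: jax.Array


def toy_reset(rng: jax.Array):
    # Keep one half of the key in the state, use the other for initialisation
    rng, init_rng = jax.random.split(rng)
    pos = jax.random.uniform(init_rng)
    return pos, ToyState(rng=rng, step=jnp.int32(0), pos=pos)


def toy_step(state: ToyState, action: jax.Array):
    # Pure function: builds a new state instead of mutating the old one
    new_pos = jnp.clip(state.pos + 0.01 * action, 0.0, 1.0)
    new_state = state._replace(step=state.step + 1, pos=new_pos)
    reward = -jnp.abs(new_pos - 0.5)
    return new_pos, new_state, reward


@jax.jit
def rollout(rng: jax.Array, actions: jax.Array):
    _, state = toy_reset(rng)

    def body(carry, action):
        _, new_state, reward = toy_step(carry, action)
        return new_state, reward  # (next carry, per-step output)

    return jax.lax.scan(body, state, actions)


final_state, rewards = rollout(jax.random.key(0), jnp.ones(100))
```

Because the state is an explicit PyTree, it can serve directly as the scan carry, and the whole rollout compiles into a single XLA program.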
At a glance, all environments inherit from the envrax.JaxEnv base class and then implement their own envrax.Spaces, methods, envrax.EnvState, and envrax.EnvConfig. By design, JaxEnv is generic over four type parameters: the observation space, the action space, the state type, and the config type (JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT]) to maximise IDE support.
Here are the core components:
import jax
from envrax import JaxEnv, EnvState, EnvConfig
from envrax.spaces import Box, Discrete
# Core inheritable items
config = EnvConfig() # static configuration
env = MyEnv(config=config) # e.g., MyEnv extends JaxEnv with JaxEnv[Box, Discrete, MyEnvState, EnvConfig]
# Required inputs
rng = jax.random.key(42) # PRNG key (only for reset)
# Core properties
obs_space = env.observation_space
action_space = env.action_space
# Core methods
obs, state = env.reset(rng) # rng is consumed and stored on state
obs, new_state, reward, done, info = env.step(state, action)
This differs slightly from the Gymnasium API standard to maintain JAX compatibility so that we can still trace pure functions without introducing side effects to JIT compilation. Specifically:
- config lives on the env instance: we set the config once at construction so that it never has to be passed to reset or step.
- rng lives in the state: our reset method consumes a PRNG key and stores its remainder in state.rng. The step method then splits the state.rng internally for any per-step randomness. This means we get to keep the stateless behaviour while threading randomness through the PyTree.
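The two points above can be sketched in plain JAX. This is an illustration under stated assumptions, not Envrax code: NoisyState and NoisyEnv are hypothetical, and a NamedTuple stands in for a chex.dataclass.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NoisyState(NamedTuple):
    rng: jax.Array    # remainder of the key passed to reset
    step: jax.Array
    value: jax.Array


class NoisyEnv:
    """Hypothetical env: config on the instance, rng inside the state."""

    def __init__(self, noise_scale: float = 0.1):
        self.noise_scale = noise_scale  # static, fixed at construction

    def reset(self, rng):
        rng, init_rng = jax.random.split(rng)
        value = jax.random.normal(init_rng)
        return value, NoisyState(rng=rng, step=jnp.int32(0), value=value)

    def step(self, state, action):
        # Per-step randomness comes from the key carried in the state
        rng, noise_rng = jax.random.split(state.rng)
        noise = self.noise_scale * jax.random.normal(noise_rng)
        value = state.value + action + noise
        return value, NoisyState(rng=rng, step=state.step + 1, value=value)


env = NoisyEnv(noise_scale=0.1)
step = jax.jit(env.step)
obs0, state0 = env.reset(jax.random.key(42))

# Stepping twice from the same state gives the same result: no hidden RNG
obs_a, state_a = step(state0, jnp.float32(1.0))
obs_b, state_b = step(state0, jnp.float32(1.0))
```

Since the key travels with the state, step needs no extra rng argument and stays a pure function of (state, action).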
State and Config as PyTrees
The environment state (EnvState) and configuration (EnvConfig) are chex.dataclass PyTrees. You extend them with game-specific fields such as positions, velocities, timers, while maintaining full compatibility with JAX serialisation and batched transforms.
In their base forms we have:
@chex.dataclass
class EnvState:
    rng: chex.PRNGKey  # PRNG key threaded through the episode
    step: jax.Array    # current timestep
    done: jax.Array    # episode termination flag

@chex.dataclass
class EnvConfig:
    max_steps: int = 1000  # maximum number of steps per episode
EnvConfig holds static configuration values that are declared once and never changed, while EnvState changes over the environment's lifecycle: each reset and step call returns an updated state rather than mutating one in place.
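As a rough illustration of what "PyTree" buys you, the sketch below batches a state with jax.vmap and then slices one element back out with tree_map. PointState and make_state are hypothetical, and a NamedTuple is used in place of chex.dataclass to keep the example dependency-free.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PointState(NamedTuple):
    step: jax.Array
    x: jax.Array
    y: jax.Array


def make_state(x0):
    return PointState(step=jnp.int32(0), x=x0, y=jnp.float32(0.0))


# vmap builds a batch of 8 states; every leaf gains a leading axis
batched = jax.vmap(make_state)(jnp.linspace(0.0, 1.0, 8))
leaf_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(batched)]

# tree_map applies a function to every leaf: here, pick the state at index 3
third = jax.tree_util.tree_map(lambda leaf: leaf[3], batched)
```

The container type is preserved throughout, so third is again a PointState with scalar leaves.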
Generics and Type Safety
Every JaxEnv subclass declares its concrete observation, action, state, and config types:
class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]): ...
This gives you full IDE autocomplete and static type-checking on env.observation_space, env.action_space, env.config, and the state returned by reset/step.
Spaces
Envrax provides some of the same core space types as Gymnasium (Discrete, Box, and MultiDiscrete) with the same names, semantics, and sample/contains methods.
Spaces are pure metadata that act as contracts for defining the environment spec. They describe the shapes, bounds, and dtypes for the items used in the RL environment.
| Symbol | Description |
|---|---|
| Discrete(n) | n integer actions in [0, n) |
| Box(low, high, shape, dtype) | Continuous array space |
| MultiDiscrete(nvec) | Vector of independent discrete sub-spaces |
Static fields like Discrete.n and Box.shape are Python values, so they can be used directly for static decisions in your env logic.
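A small sketch of why static fields matter: plain Python integers can size a network or a lookup table before any tracing happens. ToyDiscrete, ToyBox (1-D only), and policy_layer_sizes are hypothetical stand-ins written for this example, not the Envrax classes.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class ToyDiscrete:
    """Hypothetical discrete space: n integer actions in [0, n)."""
    n: int

    def contains(self, x: int) -> bool:
        return 0 <= x < self.n


@dataclass(frozen=True)
class ToyBox:
    """Hypothetical 1-D box space with scalar bounds."""
    low: float
    high: float
    shape: Tuple[int, ...]

    def contains(self, values: Sequence[float]) -> bool:
        in_shape = len(values) == self.shape[0]
        return in_shape and all(self.low <= v <= self.high for v in values)


def policy_layer_sizes(obs: ToyBox, act: ToyDiscrete, hidden: int = 64) -> List[int]:
    # shape and n are ordinary Python ints, usable for static decisions
    return [obs.shape[0], hidden, act.n]


obs_space = ToyBox(low=0.0, high=1.0, shape=(2,))
act_space = ToyDiscrete(n=4)
sizes = policy_layer_sizes(obs_space, act_space)
```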
Wrappers & Composition
Envrax ports several of Gymnasium's most useful wrappers to the JAX-native interface. They follow the same nesting pattern Gymnasium uses, where each wrapper takes an inner env and transforms its observations, rewards, or state. Each one follows the standard convention: it exposes the same reset/step signatures as a plain JaxEnv but uses composition to extend the original environment's functionality without a complete rewrite.
| Wrapper | Kind | Input obs | Output obs | Description |
|---|---|---|---|---|
| JitWrapper | pass-through | any env | same obs | Compiles reset + step with jax.jit; caches kernels to disk |
| GrayscaleObservation | pass-through | uint8[H, W, 3] | uint8[H, W] | NTSC luminance conversion |
| ResizeObservation(h, w) | pass-through | uint8[H, W] or uint8[H, W, C] | uint8[h, w] or uint8[h, w, C] | Bilinear resize (default 84×84) |
| NormalizeObservation | pass-through | uint8[...] | float32[...] in [0, 1] | Divide by 255 |
| ClipReward | pass-through | any reward | float32 ∈ {−1, 0, +1} | Sign clipping |
| ExpandDims | pass-through | any env | same obs | Adds trailing 1 dim to reward and done |
| EpisodeDiscount | pass-through | any env | same obs | Converts done bool to float32 discount (1.0 / 0.0) |
| RecordVideo | pass-through | any env | same obs | Saves episode frames to MP4 (not JIT-compatible) |
| FrameStackObservation(n_stack) | stateful | uint8[H, W] | uint8[H, W, n_stack] | Rolling frame buffer (default 4); state: FrameStackState |
| RecordEpisodeStatistics | stateful | any env | same obs | Tracks episode return + length in info["episode"]; state: EpisodeStatisticsState |
Wrappers come in two flavours:
- Pass-through: these preserve the inner state type without changes (e.g. ClipReward, GrayscaleObservation).
- Stateful: these introduce a new outer state type (a chex.dataclass extending EnvState) that wraps the inner state in an env_state field (e.g. FrameStackObservation, RecordEpisodeStatistics).
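The difference between the two flavours can be sketched in plain Python. Everything here is hypothetical (CounterEnv, SignClip, TrackReturn, ReturnState) and only mirrors the pattern described above; the real wrappers operate on JAX arrays and chex dataclasses.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CounterState:
    step: int


class CounterEnv:
    """Hypothetical inner env whose reward grows with the step count."""

    def reset(self, rng):
        return 0, CounterState(step=0)

    def step(self, state, action):
        new_state = CounterState(step=state.step + 1)
        reward = float(action * new_state.step)
        return new_state.step, new_state, reward, new_state.step >= 3, {}


class SignClip:
    """Pass-through: the inner state is returned unchanged."""

    def __init__(self, env):
        self.env = env

    def reset(self, rng):
        return self.env.reset(rng)

    def step(self, state, action):
        obs, state, reward, done, info = self.env.step(state, action)
        clipped = float((reward > 0) - (reward < 0))  # sign of the reward
        return obs, state, clipped, done, info


@dataclass(frozen=True)
class ReturnState:
    env_state: object        # the inner state lives here
    episode_return: float


class TrackReturn:
    """Stateful: wraps the inner state in a new outer state."""

    def __init__(self, env):
        self.env = env

    def reset(self, rng):
        obs, inner = self.env.reset(rng)
        return obs, ReturnState(env_state=inner, episode_return=0.0)

    def step(self, state, action):
        obs, inner, reward, done, info = self.env.step(state.env_state, action)
        total = state.episode_return + reward
        return obs, ReturnState(inner, total), reward, done, {**info, "episode_return": total}


env = TrackReturn(SignClip(CounterEnv()))
obs, state = env.reset(rng=None)
for _ in range(3):
    obs, state, reward, done, info = env.step(state, action=5)
```

Note how TrackReturn unwraps state.env_state before delegating, which is what lets stateful wrappers nest around any inner env.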
Wrappers can be applied either as a list passed to make(), mixing plain wrapper classes and pre-configured instances (no functools.partial needed), or composed manually. Envrax handles the rest automatically.
import envrax
from envrax.wrappers import (
    ClipReward,
    FrameStackObservation,
    GrayscaleObservation,
    ResizeObservation,
)

# Mix of plain classes and pre-configured wrappers; no `partial` needed
env = envrax.make(
    "BallEnv-v0",
    wrappers=[
        GrayscaleObservation,
        ResizeObservation(h=84, w=84),
        FrameStackObservation(n_stack=4),
        ClipReward,
    ],
)
The same wrappers also work as direct calls if you want to compose them manually:
env = GrayscaleObservation(env)
env = ResizeObservation(env, h=84, w=84)
env = FrameStackObservation(env, n_stack=4)
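To show how a wrapper list relates to manual nesting, here is a minimal sketch that folds a list of factories over a base env so that the first entry ends up innermost. Base, Tag, and apply_wrappers are hypothetical helpers; how Envrax resolves the list internally is not shown in this document.

```python
from functools import reduce


class Base:
    def describe(self) -> str:
        return "base"


class Tag:
    """Hypothetical wrapper that records its position in the stack."""

    def __init__(self, env, label: str):
        self.env = env
        self.label = label

    def describe(self) -> str:
        return f"{self.label}({self.env.describe()})"


def apply_wrappers(env, factories):
    # The first factory wraps the base env, so it ends up innermost
    return reduce(lambda inner, factory: factory(inner), factories, env)


factories = [
    lambda e: Tag(e, "gray"),
    lambda e: Tag(e, "resize"),
    lambda e: Tag(e, "stack"),
]
stacked = apply_wrappers(Base(), factories)
manual = Tag(Tag(Tag(Base(), "gray"), "resize"), "stack")
```

Both routes produce the same nesting, which is why the order of the list matters as much as the order of manual calls.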
Registry, Factory & Suite Catalog
Envrax houses a shared registry that lets any installed suite package expose its environments through a single entry point. The registry stores EnvSpec objects keyed by canonical name, and make() retrieves them with optional wrappers and JIT compilation.
As Envrax is the base API standard, it ships with zero environments so the registry starts out empty. Environments are contributed by downstream suite packages that call register() (or register_suite() for bulk registration) at import time. Examples of existing packages will be coming in the future.
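The registry idea can be sketched as a dictionary of frozen specs. This is a toy model, not the Envrax implementation: ToySpec, toy_register, toy_make, and toy_registered_names are hypothetical, and raising on duplicate names is an assumption made for the example.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class ToySpec:
    name: str
    env_class: type
    default_config: Any
    suite: str = ""


_REGISTRY: Dict[str, ToySpec] = {}


def toy_register(name: str, env_class: type, default_config: Any, *, suite: str = "") -> None:
    # Assumption: duplicate names are rejected rather than overwritten
    if name in _REGISTRY:
        raise ValueError(f"{name!r} is already registered")
    _REGISTRY[name] = ToySpec(name, env_class, default_config, suite)


def toy_make(name: str, *, config: Any = None):
    spec = _REGISTRY[name]
    chosen = config if config is not None else spec.default_config
    return spec.env_class(config=chosen)


def toy_registered_names() -> List[str]:
    return sorted(_REGISTRY)


class EchoEnv:
    def __init__(self, config):
        self.config = config


toy_register("Echo-v0", EchoEnv, {"max_steps": 500}, suite="demo")
env = toy_make("Echo-v0")
custom = toy_make("Echo-v0", config={"max_steps": 10})
```

Storing the default config next to the class is what lets a factory build an env from a name alone while still allowing overrides.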
The suite catalog is made up of three core components: EnvSpec, EnvSuite, and EnvSet:
| Class | Purpose |
|---|---|
| EnvSpec | Frozen dataclass holding (name, env_class, default_config, suite). This is the unit of registration: both register() and register_suite() build these and store them in the registry. |
| EnvSuite | A named, versioned collection of EnvSpecs from one suite (e.g. all MuJoCo Playground tasks). Each holds a prefix, the suite category, the suite version, its required_packages, and a list of specs (EnvSpecs). Suites support slicing, iteration, and package availability checks. |
| EnvSet | An ordered collection of EnvSuite instances, for users who want to compose multiple suites into one heterogeneous benchmark (e.g. EnvSet(EnvSuiteA(), EnvSuiteB())). |
Single-env Registration
Use register() when you want to add one environment manually:
import envrax
from envrax import EnvConfig
envrax.register("MyEnv-v0", MyEnv, EnvConfig(), suite="my-pkg")
env = envrax.make("MyEnv-v0")
Bulk Registration via a Suite
Use register_suite() when shipping a whole benchmark suite. It requires the EnvSuite.specs list to be populated to detect new environments for the registry. For example:
from dataclasses import dataclass, field
from typing import List

import envrax
from envrax import EnvSpec, EnvSuite, register_suite

# Our custom suite of environments
from demo_envs.games.cartpole import CartpoleEnv, CartpoleConfig
from demo_envs.games.ant import AntEnv, AntConfig

@dataclass
class DemoSuite(EnvSuite):
    prefix: str = "demo"
    category: str = "Demo Suite"
    version: str = "v0"
    required_packages: List[str] = field(default_factory=lambda: ["demo_suite"])
    specs: List[EnvSpec] = field(  # Must be populated
        default_factory=lambda: [
            EnvSpec("cartpole", CartpoleEnv, CartpoleConfig()),
            EnvSpec("ant", AntEnv, AntConfig()),
        ]
    )

    def get_name(self, name: str, version: str | None = None) -> str:
        return f"{self.prefix}/{name}-{version or self.version}"

# Register every spec in one call, so none can be forgotten
register_suite(DemoSuite())

# Now usable from anywhere via the standard registry
env = envrax.make("demo/cartpole-v0")
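For reference, the canonical ID format used by get_name above can be built and taken apart with a few lines of standard-library Python. The parse_canonical helper and its regular expression are assumptions for illustration; the document does not describe a parsing function.

```python
import re
from typing import Optional, Tuple


def canonical_name(prefix: str, name: str, version: str) -> str:
    # Same format as DemoSuite.get_name: "<prefix>/<name>-<version>"
    return f"{prefix}/{name}-{version}"


# Hypothetical parser: assumes versions look like "v<digits>"
_PATTERN = re.compile(r"^(?P<prefix>[^/]+)/(?P<name>.+)-(?P<version>v\d+)$")


def parse_canonical(env_id: str) -> Optional[Tuple[str, str, str]]:
    match = _PATTERN.match(env_id)
    if match is None:
        return None
    return match["prefix"], match["name"], match["version"]


env_id = canonical_name("demo", "cartpole", "v0")
parts = parse_canonical(env_id)
```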
Quick Start
Creating a New Environment
To get started, you first need to create a new environment that inherits from JaxEnv. Here's an example:
import chex
import jax
import jax.numpy as jnp

from envrax import JaxEnv, EnvState, EnvConfig
from envrax.spaces import Box, Discrete

@chex.dataclass
class BallState(EnvState):
    ball_x: jax.Array
    ball_y: jax.Array

@chex.dataclass
class BallConfig(EnvConfig):
    friction: float = 0.98
    reward_scale: float = 1.0

class BallEnv(JaxEnv[Box, Discrete, BallState, BallConfig]):
    @property
    def observation_space(self) -> Box:
        return Box(low=0.0, high=1.0, shape=(2,), dtype=jnp.float32)

    @property
    def action_space(self) -> Discrete:
        return Discrete(n=4)

    def reset(self, rng: chex.PRNGKey):
        rng, init_rng = jax.random.split(rng)
        rng_x, rng_y = jax.random.split(init_rng)
        state = BallState(
            rng=rng,
            step=jnp.int32(0),
            done=jnp.bool_(False),
            ball_x=jax.random.uniform(rng_x),
            ball_y=jax.random.uniform(rng_y),
        )
        obs = jnp.array([state.ball_x, state.ball_y])
        return obs, state

    def step(self, state: BallState, action: jax.Array):
        rng, _ = jax.random.split(state.rng)

        # Use action to get new obs
        # action: 0=left, 1=right, 2=up, 3=down
        dx = jnp.array([-0.01, 0.01, 0.0, 0.0])[action] * self.config.friction
        dy = jnp.array([0.0, 0.0, -0.01, 0.01])[action] * self.config.friction

        # Get bounds
        low, high = self.observation_space.low, self.observation_space.high

        # Increment obs
        new_x = jnp.clip(state.ball_x + dx, low, high)
        new_y = jnp.clip(state.ball_y + dy, low, high)

        # Update new state
        new_state = state.replace(
            rng=rng,
            step=state.step + 1,
            ball_x=new_x,
            ball_y=new_y,
        )

        # Set new obs
        obs = jnp.array([new_state.ball_x, new_state.ball_y])

        # Compute reward, done, and info
        reward = jnp.float32(1.0) * self.config.reward_scale
        done = new_state.step >= self.config.max_steps
        info = {"current_step": new_state.step}
        return obs, new_state.replace(done=done), reward, done, info
This code should work "as is".
Making Parallel Copies of It
Like Gymnasium's vector module, Envrax has its own VecEnv wrapper that can wrap any JaxEnv to run N parallel instances via jax.vmap. Each environment auto-resets independently when its episode ends.
import jax
import jax.numpy as jnp
from envrax import VecEnv, EnvConfig
env = BallEnv()
vec_env = VecEnv(env, num_envs=512)
obs, states = vec_env.reset(jax.random.key(42)) # obs: float32[512, 2]
actions = jnp.zeros(512, dtype=jnp.int32)
obs, states, rewards, dones, infos = vec_env.step(states, actions)
# rewards: float32[512]
# dones: bool[512]
This code should work "as is" with the custom BallEnv.
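For intuition about what vectorisation with independent auto-reset involves, here is a self-contained sketch in plain JAX. It is one possible approach, not the VecEnv implementation: WalkState, reset, step, step_autoreset, and MAX_STEPS are hypothetical, and a NamedTuple stands in for a chex.dataclass.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

MAX_STEPS = 5  # hypothetical episode length


class WalkState(NamedTuple):
    rng: jax.Array
    step: jax.Array
    pos: jax.Array


def reset(rng):
    rng, init_rng = jax.random.split(rng)
    pos = jax.random.uniform(init_rng)
    return pos, WalkState(rng=rng, step=jnp.int32(0), pos=pos)


def step(state, action):
    pos = state.pos + 0.1 * action
    new_state = WalkState(rng=state.rng, step=state.step + 1, pos=pos)
    done = new_state.step >= MAX_STEPS
    return pos, new_state, jnp.float32(1.0), done


def step_autoreset(state, action):
    # Reserve a key for a possible reset, keep the rest in the state
    rng, reset_rng = jax.random.split(state.rng)
    obs, stepped, reward, done = step(state, action)
    reset_obs, fresh = reset(reset_rng)
    # Branch-free selection: take the fresh episode only where done is True
    next_state = WalkState(
        rng=rng,
        step=jnp.where(done, fresh.step, stepped.step),
        pos=jnp.where(done, fresh.pos, stepped.pos),
    )
    return jnp.where(done, reset_obs, obs), next_state, reward, done


num_envs = 4
keys = jax.random.split(jax.random.key(0), num_envs)
obs, states = jax.vmap(reset)(keys)
vec_step = jax.jit(jax.vmap(step_autoreset))
actions = jnp.ones(num_envs)
for _ in range(MAX_STEPS):
    obs, states, rewards, dones = vec_step(states, actions)
```

Computing both branches and selecting with jnp.where keeps the function traceable under vmap, at the cost of evaluating a reset on every step.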
Managing Multiple Environments
Envrax also comes out-of-the-box with multi environment handling. This is useful for meta-learning, multi-task training, or any scenario where you need M different environments running simultaneously. For these cases, use MultiEnv or MultiVecEnv:
import jax
import jax.numpy as jnp
import envrax
# Create M heterogeneous environments (different classes, configs, shapes)
# pre_warm=False by default — compilation is deferred
multi = envrax.make_multi(["BallEnv-v0", "CartPole-v0", "BallEnv-v0"])
# Compile all JIT-wrapped envs in one setup phase (with progress bar)
multi.compile()
# Reset all M envs with a single PRNG key (split internally)
obs_list, states = multi.reset(jax.random.key(0))
# Step all M envs
actions = [jnp.int32(0) for _ in range(multi.num_envs)]
obs_list, states, rewards, dones, infos = multi.step(states, actions)
# Reset a single env (e.g., when its lifetime budget expires)
obs_list[0], states[0] = multi.reset_at(0, jax.random.key(1))
MultiVecEnv follows the same pattern but wraps VecEnv instances:
multi_vec = envrax.make_multi_vec(["BallEnv-v0", "CartPole-v0"], n_envs=64)
multi_vec.compile()
obs_list, states = multi_vec.reset(jax.random.key(0))
# obs_list[0].shape == (64, ...) — each element is already batched
Both classes return lists of values (not stacked arrays) since heterogeneous envs may have different observation shapes. Use multi.class_groups to identify which indices share a class for downstream batching.
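The grouping that class_groups describes (class name to list of indices) can be sketched in a few lines. The function below and the ToyBall and ToyCartPole classes are hypothetical stand-ins for illustration only.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def class_groups(envs: Sequence[object]) -> Dict[str, List[int]]:
    # Map each env's class name to the indices where it appears
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, env in enumerate(envs):
        groups[type(env).__name__].append(idx)
    return dict(groups)


class ToyBall:
    pass


class ToyCartPole:
    pass


groups = class_groups([ToyBall(), ToyCartPole(), ToyBall()])
```

Indices that share a class also share observation and action shapes, so their outputs can be stacked into one array for batched processing.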
make() — create with JIT and wrappers
import jax
import envrax
from envrax import EnvConfig
# Register our custom env (suite packages do this on import)
envrax.register("BallEnv-v0", BallEnv, EnvConfig(max_steps=500))
# JIT-compiled by default; warm-up step runs at construction time
env = envrax.make("BallEnv-v0")
obs, state = env.reset(jax.random.key(0))
# Apply wrappers (innermost-first)
from envrax.wrappers import NormalizeObservation, ClipReward
env = envrax.make(
    "BallEnv-v0",
    wrappers=[NormalizeObservation, ClipReward],
    jit_compile=False,
)
# Vectorised environments
vec_env = envrax.make_vec("BallEnv-v0", n_envs=64)
obs, states = vec_env.reset(jax.random.key(0)) # obs: [64, ...]
# Multiple unique environments at once (pre_warm=False by default)
multi = envrax.make_multi(["BallEnv-v0", "ExtraEnv-v0"])
multi.compile() # separate setup phase
Training Loop
Same simple training loop as Gymnasium but JAXified!
import envrax
import jax

# Init the environment
env = envrax.make("BallEnv-v0")

# Set its initial state
key = jax.random.key(42)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)

# Iterate through 1000 timesteps
for _ in range(1000):
    # Split off a fresh key so each sampled action differs
    key, action_key = jax.random.split(key)
    action = env.action_space.sample(action_key)
    obs, state, reward, done, info = env.step(state, action)

    # If episode has ended, reset to start a new one
    if done:
        key, reset_key = jax.random.split(key)
        obs, state = env.reset(reset_key)
JitWrapper — manual JIT control
import jax
from envrax.wrappers import JitWrapper
# Compile immediately (default)
env = JitWrapper(BallEnv())
obs, state = env.reset(jax.random.key(0))
# Defer compilation to a separate setup phase
env = JitWrapper(BallEnv(), pre_warm=False)
env.compile() # trigger XLA compilation explicitly
obs, state = env.reset(jax.random.key(0))
VecEnv also exposes a compile() method for the same purpose:
vec_env = VecEnv(BallEnv(), num_envs=64)
vec_env.compile() # warm up the vmapped reset + step
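Separating compilation from execution is something plain JAX supports through its ahead-of-time API. The sketch below shows that general mechanism with a hypothetical toy_step function; whether JitWrapper uses this route or a warm-up call internally is not stated here.

```python
import jax
import jax.numpy as jnp


def toy_step(pos, action):
    return jnp.clip(pos + 0.01 * action, 0.0, 1.0)


jitted = jax.jit(toy_step)

# Lazy path: calling jitted(...) would compile on the first call.
# Eager path: lower with example inputs, then compile during setup.
example_inputs = (jnp.float32(0.5), jnp.float32(1.0))
compiled = jitted.lower(*example_inputs).compile()

# The compiled object is callable with inputs of the same shape and dtype
result = compiled(jnp.float32(0.2), jnp.float32(1.0))
```

Paying the compilation cost in a dedicated setup phase keeps the first training step from absorbing it.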
API Reference
Core Classes (envrax.env)
| Symbol | Description |
|---|---|
| EnvState | chex.dataclass with rng: PRNGKey, step: int32, done: bool. Extend to add game-specific fields. |
| EnvConfig | chex.dataclass with max_steps: int = 1000. Extend to add game-specific config. |
| JaxEnv[ObsSpaceT, ActSpaceT, StateT, ConfigT] | Generic abstract base. Implement reset, step, observation_space, action_space. |
Factory Functions (envrax.make)
| Symbol | Description |
|---|---|
| make(name, *, config, wrappers, jit_compile, pre_warm, cache_dir) | Create a single env with optional wrappers and JIT. Returns a JaxEnv. |
| make_vec(name, n_envs, *, config, wrappers, ...) | Create a VecEnv of n_envs parallel environments. |
| make_multi(names, *, wrappers, ...) | Create a MultiEnv managing M heterogeneous environments using each env's registered default config. pre_warm defaults to False. |
| make_multi_vec(names, n_envs, *, wrappers, ...) | Create a MultiVecEnv managing M heterogeneous vectorised environments using each env's registered default config. pre_warm defaults to False. |
Multi-Env Managers (envrax.multi_env, envrax.multi_vec_env)
| Symbol | Description |
|---|---|
| MultiEnv(envs) | Manages M heterogeneous JaxEnv instances. reset(rng), step(states, actions), reset_at(idx, rng), step_at(idx, state, action). Returns lists. |
| MultiVecEnv(vec_envs) | Manages M heterogeneous VecEnv instances. Same API as MultiEnv, but each element is already batched. |
| .compile(progress=True) | Trigger XLA compilation for all inner envs/VecEnvs with an optional tqdm progress bar. |
| .class_groups | Dict[str, List[int]] mapping env class name to indices, for downstream same-class batching. |
| .observation_shapes / .action_shapes (MultiEnv) / .single_observation_shapes / .single_action_shapes (MultiVecEnv) | Per-env shapes as List[Tuple[int, ...]]. |
| .observation_sizes / .action_sizes (MultiEnv) / .single_observation_sizes / .single_action_sizes (MultiVecEnv) | Per-env flat sizes (prod(shape)) as List[int]. |
| .observation_dtypes / .action_dtypes (MultiEnv) / .single_observation_dtypes / .single_action_dtypes (MultiVecEnv) | Per-env dtypes. |
| .pad_dims() | (max(action_sizes), max(observation_sizes)): flat sizes large enough to fit any env, for padding when vmap-ing a single policy over a heterogeneous fleet. |
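The arithmetic behind flat sizes and pad_dims can be worked through in plain Python. The helper names and the example shapes are hypothetical, and treating a scalar action shape () as size 1 is an assumption that follows from prod(()) == 1.

```python
from math import prod
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def flat_sizes(shapes: Sequence[Shape]) -> List[int]:
    # prod(()) is 1, so a scalar shape counts as a single element
    return [prod(shape) for shape in shapes]


def pad_dims(action_shapes: Sequence[Shape], observation_shapes: Sequence[Shape]) -> Tuple[int, int]:
    return max(flat_sizes(action_shapes)), max(flat_sizes(observation_shapes))


def pad_flat(values: Sequence[float], size: int) -> List[float]:
    # Right-pad a flattened vector with zeros up to the shared size
    return list(values) + [0.0] * (size - len(values))


observation_shapes = [(2,), (4,), (84, 84)]
action_shapes = [(), (), (3,)]

max_action, max_obs = pad_dims(action_shapes, observation_shapes)
padded = pad_flat([0.1, 0.2], 4)
```

Padding every env's flattened observation to the same length is what allows one policy network to be applied across envs with different native shapes.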
Registry & Catalog (envrax.registry, envrax.envs)
| Symbol | Description |
|---|---|
| EnvSpec(name, env_class, default_config, suite="") | Frozen dataclass; the unit of registration. Stored in the registry under its canonical name. |
| EnvSuite | Base class for declaring a suite of environments. Subclasses pin prefix, category, version, required_packages, and a specs: List[EnvSpec]. Override get_name() to produce canonical IDs. |
| EnvSet(*suites) | Collection of EnvSuite instances. Supports +, iteration, verify_packages(), and from_names() for rebuilding from persisted canonical IDs. |
| register(name, cls, default_config, *, suite="") | Register a single JaxEnv under a name. Builds an EnvSpec internally. |
| register_suite(suite, *, version=None) | Register every spec in an EnvSuite under its canonical IDs. |
| get_spec(name) | Return the full registered EnvSpec for an environment. |
| registered_names() | Sorted list of all registered environment names. |