
Single-Agent Reinforcement Learning with JAX


A JAX-Native Interface for Reinforcement Learning Environments

🚀 Welcome to Stoa

Stoa provides a lightweight, JAX-native interface for reinforcement learning environments. It defines a common abstraction layer so that environments from different libraries can be adapted to a single API and used interchangeably in JAX workflows.

⚠️ Early Development – Core abstractions are in place, but the library is still growing!


🎯 What Stoa Provides

  • Common Interface: A standardized Environment base class that defines the contract for RL environments in JAX.
  • JAX-Native Design: Pure-functional step and reset operations compatible with JAX transformations like jit and vmap.
  • Environment Wrappers: A flexible system for composing and extending environments with additional functionality.
  • Space Definitions: Structured representations for observation, action, and state spaces.
  • TimeStep Protocol: A standardized TimeStep structure to represent environment transitions with clear termination and truncation signals.
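The distinction between termination and truncation can be illustrated with a minimal pure-Python sketch. It assumes a dm_env-style convention in which a `discount` of 0.0 marks true termination, while a truncated episode still ends with `discount` 1.0. The class, field, and helper names (`StepType`, `terminated()`, `truncated()`, `td_target`) are hypothetical and may not match Stoa's actual TimeStep.

```python
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Dict


class StepType(IntEnum):
    """Position of a transition within an episode (dm_env-style convention)."""
    FIRST = 0
    MID = 1
    LAST = 2


@dataclass(frozen=True)
class TimeStep:
    """Hypothetical TimeStep; field names are illustrative, not Stoa's API."""
    step_type: StepType
    reward: float
    discount: float  # assumed: 0.0 on termination, 1.0 otherwise (incl. truncation)
    observation: Any
    extras: Dict[str, Any] = field(default_factory=dict)

    def first(self) -> bool:
        return self.step_type == StepType.FIRST

    def last(self) -> bool:
        return self.step_type == StepType.LAST

    def terminated(self) -> bool:
        # Episode ended because the environment reached a terminal state.
        return self.last() and self.discount == 0.0

    def truncated(self) -> bool:
        # Episode was cut off (e.g. by a step limit) before reaching a terminal state.
        return self.last() and self.discount != 0.0


def td_target(ts: TimeStep, next_value: float, gamma: float = 0.99) -> float:
    """One-step TD target: the discount removes the bootstrap only on termination."""
    return ts.reward + gamma * ts.discount * next_value
```

Keeping the two signals apart matters for value learning: bootstrapping from the next state is still valid after truncation, but not after termination.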

🛠️ Installation

You can install the core stoa library via pip:

pip install stoa-env

This minimal installation includes the core API and wrappers but no specific environment adapters.

Environment Adapters

Adapters for external environment libraries are available as optional extras. You can install them individually or all at once.

Install a specific adapter:

# Example for Gymnax
pip install "stoa-env[gymnax]"

# Example for Brax
pip install "stoa-env[brax]"

Install all available adapters:

pip install "stoa-env[all]"

🧩 Available Adapters

Stoa currently supports the following JAX-native environment libraries:

  • Brax
  • Gymnax
  • Jumanji
  • Kinetix
  • Navix
  • PGX (Game environments)
  • MuJoCo Playground
  • XMinigrid
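Each adapter depends on its backend library being installed, so a quick way to see which of the libraries above are present is to probe for their modules without importing them. This is a hedged sketch: the import names in `BACKEND_MODULES` are the usual top-level module names and are assumptions here, and the helper checks the backend libraries themselves, not Stoa's adapter modules.

```python
import importlib.util
from typing import Dict, Mapping

# Assumed top-level import names for the backends listed above.
BACKEND_MODULES: Dict[str, str] = {
    "Brax": "brax",
    "Gymnax": "gymnax",
    "Jumanji": "jumanji",
    "Kinetix": "kinetix",
    "Navix": "navix",
    "PGX": "pgx",
    "MuJoCo Playground": "mujoco_playground",
    "XMinigrid": "xminigrid",
}


def available_backends(modules: Mapping[str, str] = BACKEND_MODULES) -> Dict[str, bool]:
    """Report which backend modules can be found, without importing them."""
    return {name: importlib.util.find_spec(mod) is not None for name, mod in modules.items()}


# Example: print a simple installed/missing table.
# for name, ok in available_backends().items():
#     print(f"{name:18s} {'installed' if ok else 'missing'}")
```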

✨ Available Wrappers

Stoa provides a rich set of wrappers to modify and extend environment behavior:

  • Core Wrappers: AutoResetWrapper, RecordEpisodeMetrics, AddRNGKey, VmapWrapper.
  • Observation Wrappers: FlattenObservationWrapper, FrameStackingWrapper, ObservationExtractWrapper, AddActionMaskWrapper, AddStartFlagAndPrevAction, AddStepCountWrapper, MakeChannelLast, ObservationTypeWrapper.
  • Action Space Wrappers: MultiDiscreteToDiscreteWrapper, MultiBoundedToBoundedWrapper.
  • Utility Wrappers: EpisodeStepLimitWrapper, ConsistentExtrasWrapper.
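The wrapper pattern can be sketched in plain Python with a toy environment. Everything here (`CountdownEnv`, `Wrapper`, `StepLimit`, `AutoReset`, and the `done`/`truncated`/`extras["next_obs"]` fields) is hypothetical and only approximates what `EpisodeStepLimitWrapper` and `AutoResetWrapper` are described as doing. It also shows one concrete reason why wrapper order matters.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class TimeStep:
    observation: Any
    reward: float
    done: bool       # episode ended (terminated or truncated)
    truncated: bool  # ended by a step limit rather than a terminal state
    extras: Dict[str, Any]


class CountdownEnv:
    """Toy env: the state counts down by `action`; the episode terminates at zero."""

    def reset(self, seed: int) -> Tuple[int, TimeStep]:
        state = 5
        return state, TimeStep(state, 0.0, False, False, {})

    def step(self, state: int, action: int) -> Tuple[int, TimeStep]:
        new_state = max(state - action, 0)
        return new_state, TimeStep(new_state, 1.0, new_state == 0, False, {})


class Wrapper:
    """Base wrapper: forwards reset/step and any other attribute to the inner env."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself.
        return getattr(self._env, name)

    def reset(self, seed):
        return self._env.reset(seed)

    def step(self, state, action):
        return self._env.step(state, action)


class StepLimit(Wrapper):
    """Adds a step counter to the state and truncates after `max_steps` steps."""

    def __init__(self, env, max_steps: int):
        super().__init__(env)
        self.max_steps = max_steps

    def reset(self, seed):
        inner, ts = self._env.reset(seed)
        return (inner, 0), ts

    def step(self, state, action):
        inner, t = state
        inner, ts = self._env.step(inner, action)
        t += 1
        if t >= self.max_steps and not ts.done:
            ts = replace(ts, done=True, truncated=True)
        return (inner, t), ts


class AutoReset(Wrapper):
    """On episode end, returns a fresh state and keeps the final observation in extras."""

    def step(self, state, action):
        state, ts = self._env.step(state, action)
        if ts.done:
            final_obs = ts.observation
            # A JAX version would split an RNG key here instead of using a fixed seed.
            state, reset_ts = self._env.reset(0)
            ts = replace(ts, observation=reset_ts.observation,
                         extras={**ts.extras, "next_obs": final_obs})
        return state, ts


env = AutoReset(StepLimit(CountdownEnv(), max_steps=3))
```

With `AutoReset(StepLimit(env))`, the auto-reset sees the truncation signal and also resets the step counter. In the reverse order, the counter would keep running across episodes and a truncated episode would never be restarted.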

⚡ Usage Example

Here's how to adapt a Gymnax environment and compose it with Stoa wrappers:

import jax
import gymnax
from stoa import GymnaxToStoa, AutoResetWrapper, RecordEpisodeMetrics

# 1. Instantiate a base environment from a supported library
gymnax_env, env_params = gymnax.make("CartPole-v1")

# 2. Adapt the environment to the Stoa interface
env = GymnaxToStoa(gymnax_env, env_params)

# 3. Apply standard wrappers
# Note: The order of wrappers matters.
env = AutoResetWrapper(env, next_obs_in_extras=True)
env = RecordEpisodeMetrics(env)

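# JIT-compile the pure reset and step functions for performance
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# 4. Interact with the environment
rng_key = jax.random.PRNGKey(0)
rng_key, reset_key = jax.random.split(rng_key)
state, timestep = reset_fn(reset_key)
total_reward = 0

for _ in range(100):
    # Split off a fresh key each step; reusing one key would sample the same action every time
    rng_key, action_key = jax.random.split(rng_key)
    action = env.action_space().sample(action_key)
    state, timestep = step_fn(state, action)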
    total_reward += timestep.reward

    if timestep.last():
        # Access metrics recorded by the RecordEpisodeMetrics wrapper
        episode_return = timestep.extras['episode_metrics']['episode_return']
        print(f"Episode finished. Return: {episode_return}")

        # The state has been auto-reset, so we can continue the loop
        total_reward = 0
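Because reset and step are pure functions, whole rollouts can be compiled and batched rather than driven by a Python loop. The sketch below uses a hypothetical toy environment (not a Stoa adapter) to show the underlying JAX pattern: `jax.lax.scan` replaces the loop over time steps, and `jax.vmap` runs many independent environments in parallel. Stoa's `VmapWrapper` is presumably its built-in route to batching; this sketch does not use it.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 8    # independent environments run in parallel
NUM_STEPS = 20  # steps per rollout


def reset(key):
    """Hypothetical toy env: the state is an integer position in [0, 10)."""
    return jax.random.randint(key, (), 0, 10)


def step(state, action):
    """Action 0 moves left, 1 moves right; reward 1.0 whenever the position is 5."""
    new_state = jnp.clip(state + 2 * action - 1, 0, 9)
    reward = jnp.where(new_state == 5, 1.0, 0.0)
    return new_state, reward


def rollout(key):
    """Run NUM_STEPS random actions in one environment and return the summed reward."""
    reset_key, action_key = jax.random.split(key)
    state = reset(reset_key)
    actions = jax.random.randint(action_key, (NUM_STEPS,), 0, 2)
    # `step` already has the (carry, x) -> (carry, y) signature lax.scan expects.
    _, rewards = jax.lax.scan(step, state, actions)
    return rewards.sum()


# vmap over a batch of keys, then compile the batched rollout once.
batched_rollout = jax.jit(jax.vmap(rollout))

keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)
returns = batched_rollout(keys)  # one return per environment, shape (NUM_ENVS,)
```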

🛣️ Roadmap

  • Documentation: Expand documentation with detailed tutorials and API references.
  • More Wrappers: Add more common utility wrappers (e.g., observation normalization, reward clipping).
  • Integration Examples: Provide examples of how to integrate stoa with popular JAX-based RL libraries.

🤝 Contributing

We're building Stoa to provide a common foundation for JAX-based RL research. Contributions are welcome!


📚 Related Projects

  • Stoix – Distributed single-agent RL in JAX
  • Gymnax – Classic control environments in JAX
  • Brax – Physics-based environments in JAX
  • Jumanji – Board games and optimization problems in JAX
  • Navix – Grid-world environments in JAX
  • PGX – Classic board and card game environments in JAX
  • Kinetix – 2D physics-based control environments in JAX



Download files


Source Distribution

stoa_env-0.1.1.tar.gz (42.3 kB)

Uploaded Source

Built Distribution


stoa_env-0.1.1-py3-none-any.whl (54.5 kB)

Uploaded Python 3

File details

Details for the file stoa_env-0.1.1.tar.gz.

File metadata

  • Download URL: stoa_env-0.1.1.tar.gz
  • Upload date:
  • Size: 42.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stoa_env-0.1.1.tar.gz
  • SHA256: 560c510bab4e3d88c58facd03d3987a7c5322adabad1fc88baa61110b23c325f
  • MD5: 9cecf00a3d6da52b6a7199c097fb68b6
  • BLAKE2b-256: 6f62857a178490d592e78ba14f351e9df5826a32918fb19622b110dd8b55aa1e
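To check a downloaded file against its published SHA256 digest, hash it locally and compare. This is a minimal stdlib sketch; the helper names `sha256_of` and `verify` are illustrative. pip can also enforce this itself in hash-checking mode, for example with a requirements-file line such as `stoa-env==0.1.1 --hash=sha256:<digest>`.

```python
import hashlib

# SHA256 digest published for stoa_env-0.1.1.tar.gz.
SDIST_SHA256 = "560c510bab4e3d88c58facd03d3987a7c5322adabad1fc88baa61110b23c325f"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 in chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_hex: str) -> bool:
    """Return True if the file's SHA-256 matches a published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Usage, after downloading the archive into the current directory:
# verify("stoa_env-0.1.1.tar.gz", SDIST_SHA256)
```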


Provenance

The following attestation bundles were made for stoa_env-0.1.1.tar.gz:

Publisher: release.yaml on EdanToledo/Stoa

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file stoa_env-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: stoa_env-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 54.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for stoa_env-0.1.1-py3-none-any.whl
  • SHA256: f1ae996970e40d00a36f7eaf9b2f49c1536af8be247b67c8daf6bd0b38addd7e
  • MD5: 379fa7c5a069df86b7a842507762688c
  • BLAKE2b-256: fd0be863b183a48835e160ce01ecb1b99932c7de4d8585098a4ebab194cc5d41


Provenance

The following attestation bundles were made for stoa_env-0.1.1-py3-none-any.whl:

Publisher: release.yaml on EdanToledo/Stoa

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
