
Single-Agent Reinforcement Learning with JAX



A JAX-Native Interface for Reinforcement Learning Environments

🚀 Welcome to Stoa

Stoa provides a lightweight, JAX-native interface for reinforcement learning environments. It defines a common abstraction layer so that environments from different libraries can be used interchangeably in JAX workflows.

⚠️ Early Development – Core abstractions are in place, but the library is still growing!

🎯 What Stoa Provides

  • Common Interface: A standardized Environment base class that defines the contract for RL environments in JAX.
  • JAX-Native Design: Pure-functional step and reset operations compatible with JAX transformations like jit and vmap.
  • Environment Wrappers: A flexible system for composing and extending environments with additional functionality.
  • Space Definitions: Structured representations for observation, action, and state spaces.
  • TimeStep Protocol: A standardized TimeStep structure to represent environment transitions with clear termination and truncation signals.
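To make the contract above concrete, here is a minimal sketch of a pure-functional environment that returns a TimeStep-style record. All names (`TimeStep`, `EnvState`, `CounterEnv`, the `FIRST`/`MID`/`LAST` codes) and the field layout are illustrative assumptions, not Stoa's actual classes. The discount convention (0.0 on termination, 1.0 on truncation) follows the common dm_env-style convention and is likewise an assumption.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Hypothetical step-type codes, in the spirit of dm_env's StepType.
FIRST, MID, LAST = 0, 1, 2


class TimeStep(NamedTuple):
    """Hypothetical transition record. NamedTuples are pytrees, so JAX can trace them."""

    step_type: jax.Array
    reward: jax.Array
    discount: jax.Array  # 0.0 on termination; 1.0 otherwise, including truncation
    observation: jax.Array

    def last(self) -> jax.Array:
        return self.step_type == LAST


class EnvState(NamedTuple):
    position: jax.Array
    step_count: jax.Array


class CounterEnv:
    """Toy 1-D task: move left (0) or right (1) until reaching `goal`.

    Reaching the goal terminates the episode; hitting `max_steps` truncates it.
    """

    def __init__(self, goal: int = 3, max_steps: int = 10):
        self.goal = goal
        self.max_steps = max_steps

    def reset(self, key: jax.Array) -> tuple[EnvState, TimeStep]:
        # Pure function of the key: random start position in [-2, 2].
        position = jax.random.randint(key, (), -2, 3)
        state = EnvState(position, jnp.int32(0))
        ts = TimeStep(jnp.int32(FIRST), jnp.float32(0.0), jnp.float32(1.0),
                      position.astype(jnp.float32))
        return state, ts

    def step(self, state: EnvState, action: jax.Array) -> tuple[EnvState, TimeStep]:
        # Returns a new state; nothing on `self` is mutated.
        position = state.position + jnp.where(action == 1, 1, -1)
        step_count = state.step_count + 1
        terminated = position == self.goal
        truncated = step_count >= self.max_steps
        ts = TimeStep(
            step_type=jnp.where(terminated | truncated, LAST, MID).astype(jnp.int32),
            reward=terminated.astype(jnp.float32),
            discount=jnp.where(terminated, 0.0, 1.0).astype(jnp.float32),
            observation=position.astype(jnp.float32),
        )
        return EnvState(position, step_count), ts


# Pure reset/step compose directly with JAX transformations.
env = CounterEnv()
keys = jax.random.split(jax.random.PRNGKey(0), 4)
states, timesteps = jax.jit(jax.vmap(env.reset))(keys)
states, timesteps = jax.jit(jax.vmap(env.step))(states, jnp.ones(4, dtype=jnp.int32))
```

Because `reset` and `step` receive everything they need as arguments and return new values rather than mutating `self`, batching four environments is just `jax.vmap`, and compiling is just `jax.jit`.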

🛠️ Installation

You can install the core stoa library via pip:

pip install stoa-env

This minimal installation includes the core API and wrappers but no specific environment adapters.

Environment Adapters

Adapters for external environment libraries are available as optional extras. You can install them individually or all at once.

Install a specific adapter:

# Example for Gymnax
pip install "stoa-env[gymnax]"

# Example for Brax
pip install "stoa-env[brax]"

Install all available adapters:

pip install "stoa-env[all]"

🧩 Available Adapters

Stoa currently supports the following JAX-native environment libraries:

  • Brax
  • Gymnax
  • Jumanji
  • Kinetix
  • Navix
  • PGX (board and card game environments)
  • MuJoCo Playground
  • XMinigrid

✨ Available Wrappers

Stoa provides a rich set of wrappers to modify and extend environment behavior:

  • Core Wrappers: AutoResetWrapper, RecordEpisodeMetrics, AddRNGKey, VmapWrapper.
  • Observation Wrappers: FlattenObservationWrapper, FrameStackingWrapper, ObservationExtractWrapper, AddActionMaskWrapper, AddStartFlagAndPrevAction, AddStepCountWrapper, MakeChannelLast, ObservationTypeWrapper.
  • Action Space Wrappers: MultiDiscreteToDiscreteWrapper, MultiBoundedToBoundedWrapper.
  • Utility Wrappers: EpisodeStepLimitWrapper, ConsistentExtrasWrapper.

⚡ Usage Example

Here's how to adapt a Gymnax environment to Stoa and compose it with wrappers:

import jax
import gymnax
from stoa import GymnaxToStoa, AutoResetWrapper, RecordEpisodeMetrics

# 1. Instantiate a base environment from a supported library
gymnax_env, env_params = gymnax.make("CartPole-v1")

# 2. Adapt the environment to the Stoa interface
env = GymnaxToStoa(gymnax_env, env_params)

# 3. Apply standard wrappers
# Note: The order of wrappers matters.
env = AutoResetWrapper(env, next_obs_in_extras=True)
env = RecordEpisodeMetrics(env)

# JIT-compile the reset and step functions for performance
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# 4. Interact with the environment
rng_key = jax.random.PRNGKey(0)
rng_key, reset_key = jax.random.split(rng_key)
state, timestep = reset_fn(reset_key)
total_reward = 0.0

for _ in range(100):
    # Split off a fresh key each step; reusing one key would sample the same action every time
    rng_key, action_key = jax.random.split(rng_key)
    action = env.action_space().sample(action_key)
    state, timestep = step_fn(state, action)
    total_reward += timestep.reward

    if timestep.last():
        # Access metrics recorded by the RecordEpisodeMetrics wrapper
        episode_return = timestep.extras['episode_metrics']['episode_return']
        print(f"Episode finished. Return: {episode_return}")

        # The state has been auto-reset, so we can continue the loop
        total_reward = 0.0
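The loop above steps the environment from Python and relies on the wrapper to auto-reset. The block below is a hedged sketch of what can happen underneath, and of how the same interaction can be compiled end to end. It implements a branch-free auto-reset and a `jax.lax.scan` rollout that splits the RNG key on every step. `CountdownEnv`, `AutoReset`, `rollout`, and the `final_observation` field are hypothetical. Stoa's `AutoResetWrapper` may differ, including in what exactly it places in `extras` when `next_obs_in_extras=True`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TimeStep(NamedTuple):
    observation: jax.Array        # first obs of the NEXT episode when done is True
    reward: jax.Array
    done: jax.Array
    final_observation: jax.Array  # the true last obs (hypothetical field)


class CountdownEnv:
    """Toy env: the observation counts down from a random start; the episode ends at <= 0."""

    def reset(self, key):
        obs = jax.random.randint(key, (), 2, 5)
        return obs, obs  # (state, observation)

    def step(self, state, action):
        obs = state - action
        return obs, obs, jnp.float32(1.0), obs <= 0  # state, obs, reward, done


class AutoReset:
    """Hypothetical auto-reset wrapper; the RNG key for resets lives in the state."""

    def __init__(self, env):
        self.env = env

    def reset(self, key):
        key, sub = jax.random.split(key)
        inner, obs = self.env.reset(sub)
        return (inner, key), TimeStep(obs, jnp.float32(0.0), jnp.asarray(False), obs)

    def step(self, state, action):
        inner, key = state
        inner, obs, reward, done = self.env.step(inner, action)
        key, sub = jax.random.split(key)
        reset_inner, reset_obs = self.env.reset(sub)

        def pick(fresh, old):
            # Keep the freshly reset value only where the episode just ended.
            return jnp.where(done, fresh, old)

        new_inner = jax.tree_util.tree_map(pick, reset_inner, inner)
        ts = TimeStep(pick(reset_obs, obs), reward, done, obs)
        return (new_inner, key), ts


def rollout(env, key, num_steps):
    """Compile the whole interaction loop with lax.scan, splitting keys every step."""
    key, reset_key = jax.random.split(key)
    state, _ = env.reset(reset_key)

    def body(carry, _):
        state, key = carry
        key, action_key = jax.random.split(key)
        action = jax.random.randint(action_key, (), 1, 3)  # fresh key -> fresh action
        state, ts = env.step(state, action)
        return (state, key), ts

    _, trajectory = jax.lax.scan(body, (state, key), None, length=num_steps)
    return trajectory


traj = jax.jit(rollout, static_argnums=(0, 2))(AutoReset(CountdownEnv()), jax.random.PRNGKey(0), 20)
```

Computing a reset on every step and selecting with `jnp.where` keeps auto-reset free of Python control flow, which is what lets it run inside `jit`, `vmap`, and `scan`. The trade-off is that the reset is evaluated even when the episode continues.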

🛣️ Roadmap

  • Documentation: Expand documentation with detailed tutorials and API references.
  • More Wrappers: Add more common utility wrappers (e.g., observation normalization, reward clipping).
  • Integration Examples: Provide examples of how to integrate stoa with popular JAX-based RL libraries.

🤝 Contributing

We're building Stoa to provide a common foundation for JAX-based RL research. Contributions are welcome!

📚 Related Projects

  • Stoix – Distributed single-agent RL in JAX
  • Gymnax – Classic control environments in JAX
  • Brax – Physics-based environments in JAX
  • Jumanji – Board games and optimization problems in JAX
  • Navix – Grid-world environments in JAX
  • PGX – Classic board and card game environments in JAX
  • Kinetix – Open-ended 2D physics-based control environments in JAX

Citation

If you use Stoa, please cite it!

@misc{toledo2025stoa,
  author = {Edan Toledo},
  title  = {Stoa: A JAX-Native Interface for Reinforcement Learning Environments},
  year   = {2025},
  url    = {https://github.com/EdanToledo/Stoa}
}

