
A JAX-native environment interface with powerful wrappers and adapters for popular RL environment suites


💌 Envelope: a JAX-native environment interface

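import jax
import matplotlib.pyplot as plt

import envelope

# Create environments from JAX-native suites you have installed, ...
env = envelope.create("gymnax::CartPole-v1")

# ... interact with the environments using a simple interface, ...
init_key, action_key = jax.random.split(jax.random.PRNGKey(0))
state, info = env.init(init_key)
actions = jax.random.randint(action_key, (200,), 0, 2)  # 200 random CartPole actions in {0, 1}
final_state, infos = jax.lax.scan(env.step, state, actions)  # scan returns (final carry, stacked infos)
plt.plot(infos.reward.cumsum())

# ... and enjoy a powerful ecosystem of wrappers.
env = envelope.wrappers.AutoResetWrapper(env)
env = envelope.wrappers.VmapWrapper(env)
env = envelope.wrappers.ObservationNormalizationWrapper(env)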
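To see why jax.lax.scan(env.step, state, actions) works, here is a self-contained sketch in plain JAX, with no envelope dependency, of an environment that follows the same init(key) -> State, Info and step(state, action) -> State, Info convention. ToyEnv, State, Info, MAX_STEPS and rollout are hypothetical names for illustration, not envelope's classes, and the fields on Info are assumptions. Because ToyEnv is a NamedTuple, JAX already treats it as a pytree, so a batch of environments with different parameters can go straight through jax.vmap.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

MAX_STEPS = 10  # hypothetical time limit after which episodes are truncated


class State(NamedTuple):
    position: jax.Array
    time: jax.Array


class Info(NamedTuple):
    obs: jax.Array
    reward: jax.Array
    terminated: jax.Array  # the episode really ended (walked past |position| = 1)
    truncated: jax.Array  # the episode was cut off by the time limit


class ToyEnv(NamedTuple):
    """Hypothetical 1-D walk; as a NamedTuple it is a pytree, so its parameter can be traced."""

    step_size: jax.Array

    def init(self, key):
        position = jax.random.uniform(key, (), minval=-0.1, maxval=0.1)
        no = jnp.bool_(False)
        return State(position, jnp.int32(0)), Info(position, jnp.float32(0.0), no, no)

    def step(self, state, action):
        # action 0 moves left, action 1 moves right; there is no auto-reset
        position = state.position + self.step_size * (2 * action - 1)
        time = state.time + 1
        info = Info(
            obs=position,
            reward=-jnp.abs(position),
            terminated=jnp.abs(position) > 1.0,
            truncated=time >= MAX_STEPS,
        )
        return State(position, time), info


def rollout(env, key, actions):
    state, _ = env.init(key)
    # step(state, action) -> (State, Info) matches scan's (carry, x) -> (carry, y)
    _, infos = jax.lax.scan(env.step, state, actions)
    return infos


# Two environments that differ only in a parameter, batched as one pytree.
envs = ToyEnv(step_size=jnp.array([0.1, 0.5]))
keys = jax.random.split(jax.random.PRNGKey(0), 2)
actions = jnp.ones((2, 20), dtype=jnp.int32)  # always move right
infos = jax.jit(jax.vmap(rollout))(envs, keys, actions)
print(infos.reward.shape)  # (2, 20): one reward per env and step
```

Both environments run in a single jitted call. Since nothing auto-resets, the faster walker keeps stepping after it reports terminated=True.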

🌍 Simple, expressive interaction!

  • Environments are pytrees. Squish them through JAX transformations and trace their parameters.
  • An idiomatic, JAX-y interface: init(key: Key) -> State, Info and step(state: State, action: PyTree) -> State, Info. You can jax.lax.scan directly over step(...)!
  • Spaces are super simple. No Tuple or Dict nonsense! There are just two base spaces, Continuous and Discrete, which you can compose into a PyTreeSpace.
  • Episode truncation is reported explicitly, separately from termination, so value-function targets can still be bootstrapped correctly when an episode is cut off.
  • No auto-reset by default. Resetting every step can be expensive!
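Here is a minimal sketch, in plain JAX and independent of envelope, of why separating truncation from termination matters for value-function targets. bootstrapped_returns and its arguments are hypothetical names; it assumes each step exposes separate boolean terminated and truncated flags and that a critic has already produced next_values[t], an estimate of V(s_{t+1}).

```python
import jax
import jax.numpy as jnp


def bootstrapped_returns(rewards, next_values, terminated, truncated, gamma=0.99):
    """Discounted returns for one rollout (time on axis 0), computed backwards.

    Both flags end an episode, so returns never leak across an episode
    boundary, but only `terminated` drops the bootstrap value: after a
    truncation (e.g. a time limit) s_{t+1} is still a valid state.
    """

    def backward(next_return, step):
        reward, next_value, term, trunc = step
        # At an episode boundary, restart from V(s_{t+1}) instead of the
        # return that belongs to the following episode.
        next_return = jnp.where(term | trunc, next_value, next_return)
        # A true terminal state has value 0, so there is nothing to bootstrap.
        ret = reward + jnp.where(term, 0.0, gamma * next_return)
        return ret, ret

    # The last step has no successor inside the rollout, so start from V(s_T).
    _, returns = jax.lax.scan(
        backward,
        next_values[-1],
        (rewards, next_values, terminated, truncated),
        reverse=True,
    )
    return returns


rewards = jnp.ones(3)
next_values = jnp.full(3, 10.0)
no_flag = jnp.zeros(3, dtype=bool)
flag_at_t1 = jnp.array([False, True, False])

# Truncated at t=1: still bootstrap from V(s_2) = 10
truncated_returns = bootstrapped_returns(rewards, next_values, no_flag, flag_at_t1, gamma=0.5)
# -> [4., 6., 6.]
# Terminated at t=1: nothing to bootstrap
terminated_returns = bootstrapped_returns(rewards, next_values, flag_at_t1, no_flag, gamma=0.5)
# -> [1.5, 1., 6.]
```

If truncation were folded into termination, the truncated rollout would wrongly get the terminated targets.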

💪 Powerful, composable wrappers!

  • Carry state across episodes to track running statistics, for example to normalize observations.
  • Composable wrappers can be stacked in any order. For example, ObservationNormalizationWrapper before vs. after VmapWrapper gives per-env vs. global normalization.
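The per-env vs. global distinction comes down to whether the statistics update runs inside or outside the vmap. A wrapper applied before VmapWrapper ends up inside it, so each env updates its own statistics; applied after, it sees the whole batch and pools them. The sketch below shows this in plain JAX. It is not envelope's ObservationNormalizationWrapper, and RunningStats, init_stats, update_stats and normalize are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RunningStats(NamedTuple):
    """Running mean/variance; a NamedTuple is a pytree, so it can be carried in state."""

    num_samples: jax.Array
    mean: jax.Array
    var: jax.Array


def init_stats(obs_dim):
    return RunningStats(jnp.zeros(()), jnp.zeros(obs_dim), jnp.zeros(obs_dim))


def update_stats(stats, batch):
    """Merge a batch of observations (samples on axis 0) into the running stats.

    Uses the standard parallel mean/variance merge, so statistics keep
    accumulating across steps and episodes.
    """
    n_batch = batch.shape[0]
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    total = stats.num_samples + n_batch
    delta = batch_mean - stats.mean
    mean = stats.mean + delta * n_batch / total
    m2 = (
        stats.var * stats.num_samples
        + batch_var * n_batch
        + delta**2 * stats.num_samples * n_batch / total
    )
    return RunningStats(total, mean, m2 / total)


def normalize(stats, obs, eps=1e-8):
    return (obs - stats.mean) / jnp.sqrt(stats.var + eps)


# Observations from 2 parallel envs, each with a 1-D observation.
obs = jnp.array([[0.0], [2.0]])

# Update outside the vmap: one set of statistics pooled over all envs.
global_stats = update_stats(init_stats(1), obs)  # mean [1.], var [1.]

# Update inside the vmap: every env keeps its own statistics.
per_env_stats = jax.vmap(lambda o: update_stats(init_stats(1), o[None]))(obs)
# mean [[0.], [2.]], var [[0.], [0.]]
```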

🔌 Adapters for existing suites

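| 📦 Suite | # 🤖 (🕺 single-agent / 👯 multi-agent) | # 🌍 Environments |
|---|---|---|
| brax | 🕺 | 12 |
| craftax | 🕺 | 4 |
| gymnax | 🕺 | 24 |
| jumanji | 🕺 / 👯 | 25 / 1 |
| kinetix | 🕺 | 74 |
| mujoco_playground | 🕺 | 54 |
| navix | 🕺 | 41 |
| Total | 🕺 / 👯 | 234 / 1 |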
envelope.create("📦::🌍") lets you create environments from any of the above, e.g. envelope.create("gymnax::CartPole-v1")!

📝 Testing

  • Default (no optional adapter dependencies required): uv run pytest -m "not adapters"
  • Adapter tests (require the full adapters dependency group):
    • uv sync --group adapters
    • uv run pytest -m adapters
    • If any adapter dependency is missing or broken, the run fails fast with an error telling you what to install.

🏗️ Installation

pip install jax-envelope

💞 Related projects

  • stoa is a very similar project that provides adapters and wrappers for the jumanji-like interface.
  • Check out all the great suites we have adapters for! gymnax, brax, jumanji, kinetix, craftax, navix, mujoco_playground.
  • We will be adding support for jaxmarl and pgx in the future, as soon as we have figured out the best MARL interface ever for JAX!



Download files


Source Distribution

jax_envelope-0.4.1.tar.gz (222.2 kB)

Built Distribution


jax_envelope-0.4.1-py3-none-any.whl (36.7 kB)

File details

Details for the file jax_envelope-0.4.1.tar.gz.

File metadata

  • Download URL: jax_envelope-0.4.1.tar.gz
  • Size: 222.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_envelope-0.4.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9d2a64af11f7e8417444d4ca90d213762aef45e96f39ff0b5c0e19190671121a |
| MD5 | 766bf5ac356faa1b03537251b024dc89 |
| BLAKE2b-256 | 92b8ea4fecb7dc112d1e4622cb01562d9cb6eb78e831abe0f676f01d97dd3cd7 |


Provenance

The following attestation bundles were made for jax_envelope-0.4.1.tar.gz:

Publisher: publish.yml on keraJLi/envelope

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jax_envelope-0.4.1-py3-none-any.whl.

File metadata

  • Download URL: jax_envelope-0.4.1-py3-none-any.whl
  • Size: 36.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for jax_envelope-0.4.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7fb7474daf4ce7ecd9c8e1ca0db9820bc551b147e438216569262f103b6313c5 |
| MD5 | b4979bc5e4ec5aca8e21e6e1287e00ce |
| BLAKE2b-256 | 4a6517a29ae6e4c9b4e7bfbbd1bd353efe9b6020909af483552dbbd4fb9ee581 |


Provenance

The following attestation bundles were made for jax_envelope-0.4.1-py3-none-any.whl:

Publisher: publish.yml on keraJLi/envelope

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
