A Jax interface for Reinforcement Learning environments.
jit_env: A Jax Compatible RL Environment API
jit_env is a library that aims to adhere closely to the dm_env interface
while allowing for jax transformations inside Environment implementations
and defining clear type annotations.
Like dm_env our API consists of the main components:
- jit_env.Environment: An abstract base class for RL environments.
- jit_env.TimeStep: A container class representing the outputs of the environment on each time step (transition).
- jit_env.specs: A module containing primitives that are used to describe the format of the actions consumed by an environment, as well as the observations, rewards, and discounts it returns.
This is extended with the components:
- jit_env.Wrapper: An interface built on top of Environment that allows modular transformations of the base Environment.
- jit_env.Action, jit_env.Observation, jit_env.State: Explicit types that concern Agent-Environment IO.
- jit_env.compat: A module containing API hooks to other Agent-Environment interfaces like dm_env or gymnasium.
- jit_env.wrappers: A module containing a few generally useful implementations for Wrapper (that simultaneously serve as a reference).
Note that this module is only an interface and does not implement any
Environments itself. The implementations in examples serve to illustrate the syntax.
For a more thorough review of the semantics, please refer to the dm-env
wiki and compare our implementation of jit_env.Environment with dm_env.Environment and the conversion as given in compat.py.
Installation
jit_env can be installed with (it is recommended to install jax first):
python -m pip install jit-env
You can also install it directly from our GitHub repository using pip:
python -m pip install git+https://github.com/joeryjoery/jit_env.git
or alternatively by checking out a local copy of our repository and running:
python -m pip install /path/to/local/jit_env/
The Big Difference with dm_env
The main difference between this API and the standard dm_env API is
that our definition of jit_env.Environment is functionally pure.
This allows the logic to be, e.g., batched over or accelerated
using jax.vmap or jax.jit.
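To make the idea of functional purity concrete, here is a minimal sketch written in plain jax. It does not use the jit_env classes: CounterState, reset, and step are hypothetical names chosen for illustration, and the toy dynamics (a counter that accumulates the action) are an assumption. The point is only that the state is passed in and returned explicitly, which is what lets jax.jit and jax.vmap transform the logic.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class CounterState(NamedTuple):
    """Hypothetical environment state: a single integer counter."""
    count: jax.Array


def reset(key: jax.Array) -> CounterState:
    # The key is unused in this toy example; it marks where randomness would enter.
    del key
    return CounterState(count=jnp.zeros((), dtype=jnp.int32))


def step(state: CounterState, action: jax.Array) -> Tuple[CounterState, jax.Array]:
    # Pure function: a new state is returned and nothing is mutated in place.
    new_state = CounterState(count=state.count + action)
    reward = new_state.count.astype(jnp.float32)
    return new_state, reward


# Compile a single step with jax.jit.
jit_step = jax.jit(step)
state, reward = jit_step(reset(jax.random.PRNGKey(0)), jnp.int32(2))

# Batch the same logic over 4 independent environments with jax.vmap.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
batch_state = jax.vmap(reset)(keys)
batch_actions = jnp.arange(4, dtype=jnp.int32)
batch_state, batch_reward = jax.vmap(step)(batch_state, batch_actions)
```

Because nothing is stored on an object between calls, the same step function serves the single-environment and the batched case without modification.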
On top of that, we extend the specs logic of what dm_env provides.
The specs module defines primitives for how the Agent interacts with
the Environment. We explicitly implement additional specs that are
compatible with jax-based PyTree objects.
This allows for tree-based operations on the spec objects themselves,
which in turn gives some added flexibility in designing desired
state-action spaces.
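As a rough illustration of what tree-based operations on specs can look like, the sketch below uses jax.ShapeDtypeStruct as a stand-in for a spec. This is an assumption made to keep the example self-contained; it is not the jit_env.specs API, and observation_spec and zeros_like_spec are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a nested spec: each leaf only records shape and dtype.
observation_spec = {
    "position": jax.ShapeDtypeStruct((2,), jnp.float32),
    "inventory": jax.ShapeDtypeStruct((3,), jnp.int32),
}


def zeros_like_spec(spec: jax.ShapeDtypeStruct) -> jax.Array:
    # Build a placeholder value that matches the shape and dtype of one leaf.
    return jnp.zeros(spec.shape, spec.dtype)


# Map over every leaf of the spec tree to generate a dummy observation.
dummy_observation = jax.tree_util.tree_map(zeros_like_spec, observation_spec)

# The flattened leaves can be inspected in the same way.
leaf_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(observation_spec)]
```

The same pattern extends to arbitrarily nested dicts, tuples, or dataclass-like containers, which is where the added flexibility for composite state-action spaces comes from.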
Some other modified behaviours include:
- restart provides a reference value for reward and discount in place of None.
- StepType is no longer an enum type, as jax.jit would convert enum types to jax primitives anyway. It remains a namespace for defining episode boundaries.
- TimeStep is now a frozen chex.dataclass to allow usage of replace within the public API (which is private for NamedTuple).
- TimeStep carries an additional extras field for optional data (metrics) not shown to the agent.
- All helpers (restart, transition, etc.) now take a shape value to generate the reference reward or discount fields.
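The list above can be sketched with a simplified stand-in. The code below uses the standard library dataclasses module instead of chex.dataclass, and its StepType, TimeStep, and restart are hypothetical re-creations rather than the jit_env definitions. The integer values follow the dm_env ordering, and the reference values (zero reward, unit discount) are assumptions.

```python
import dataclasses
from typing import Any, Optional, Tuple

import jax.numpy as jnp


class StepType:
    """Plain namespace of integer constants instead of an enum."""
    FIRST = 0
    MID = 1
    LAST = 2


@dataclasses.dataclass(frozen=True)
class TimeStep:
    """Simplified stand-in for a frozen TimeStep with an extras field."""
    step_type: int
    reward: Any
    discount: Any
    observation: Any
    extras: Optional[dict] = None


def restart(observation: Any, shape: Tuple[int, ...] = ()) -> TimeStep:
    # Reference arrays replace the None values that dm_env uses on the first step.
    # Assumption: zeros for the reward and ones for the discount.
    return TimeStep(
        step_type=StepType.FIRST,
        reward=jnp.zeros(shape),
        discount=jnp.ones(shape),
        observation=observation,
    )


first = restart(jnp.zeros(3), shape=())

# Frozen dataclasses are updated by building a modified copy with replace.
with_metrics = dataclasses.replace(first, extras={"episode_length": 0})
```

Giving reward and discount a fixed shape on the first step keeps the structure of every TimeStep identical, which is what traced control flow under jax.jit generally needs.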
Why jit_env
I developed this module to have a reliable Environment backend that is less subject
to refactoring changes than other libraries, while providing free compatibility with both jax
transforms and any other popular type of Agent-Environment interface.
The hope is that this library will not require much maintenance or alteration (aside from some type-hint updates) after an official 1.0.0 release.
Cite us
If you are a particularly nice person and this work was useful to you, you can cite this repository as:
@misc{jit_env_2023,
author={Joery A. de Vries},
title={jit\_env: A Jax interface for reinforcement learning environments},
year={2023},
url={http://github.com/joeryjoery/jit_env}
}
References
This library was heavily inspired by the following libraries:
- dm-env: https://github.com/deepmind/dm_env
- jumanji: https://github.com/instadeepai/jumanji
- gymnax: https://github.com/RobertTLange/gymnax
- gymnasium: https://github.com/Farama-Foundation/Gymnasium