Deep Reinforcement Learning with JAX and Equinox.

Project description

Lerax

This is a work-in-progress implementation of a JAX-based reinforcement learning library built on Equinox. The main feature is support for models based on Neural Differential Equations (NDEs). It is meant as a cleaner and more complete continuation of earlier work in the NCDE-RL repo. NDEs can be extraordinarily computationally intensive, so this library is intended to provide an optimized implementation of NDEs and other RL algorithms using just-in-time (JIT) compilation. Paired with environments that support JIT, high performance is possible using the Anakin architecture for fully GPU-based RL.
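
As a rough illustration of the idea behind an Anakin-style setup, the sketch below keeps the environment, the policy, and the update step inside a single jitted function, so nothing leaves the accelerator during training. Everything here is an assumption made for illustration: the toy environment, the linear policy, and names such as env_step and train_step are hypothetical and are not part of Lerax. For brevity the sketch differentiates directly through the toy dynamics instead of using a policy-gradient estimator such as PPO.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 64         # parallel environments (assumed value)
NUM_STEPS = 32        # rollout length per environment (assumed value)
LEARNING_RATE = 0.01  # plain gradient-descent step size (assumed value)


def env_reset(key):
    # Hypothetical toy environment: the state is a 2-vector in [-1, 1]^2.
    return jax.random.uniform(key, (2,), minval=-1.0, maxval=1.0)


def env_step(state, action):
    # Toy dynamics: the action nudges the state; reward penalises distance
    # from the origin.
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state**2)
    return next_state, reward


def policy(params, state):
    # Deterministic linear policy with a tanh squashing, kept minimal.
    return jnp.tanh(params @ state)


def episode_return(params, key):
    # lax.scan unrolls the rollout inside the compiled program.
    def step(state, _):
        next_state, reward = env_step(state, policy(params, state))
        return next_state, reward

    _, rewards = jax.lax.scan(step, env_reset(key), None, length=NUM_STEPS)
    return jnp.sum(rewards)


def loss_fn(params, keys):
    # vmap runs all environments in parallel with shared parameters.
    returns = jax.vmap(episode_return, in_axes=(None, 0))(params, keys)
    return -jnp.mean(returns)


@jax.jit
def train_step(params, key):
    # Rollout, gradient, and update are compiled together as one program.
    keys = jax.random.split(key, NUM_ENVS)
    loss, grads = jax.value_and_grad(loss_fn)(params, keys)
    return params - LEARNING_RATE * grads, loss
```

Calling train_step(jnp.zeros((2, 2)), jax.random.PRNGKey(0)) returns the updated parameters and the loss for that batch of rollouts. This only works because the environment itself is written in JAX; a Python-side environment would force a host round trip on every step.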

I'm working on this in my free time, so it may take a while to get to a usable state. I'm also mainly developing this for personal research, so it may not be suitable for all use cases.

Credit

Much of the code is a close translation of code from the Stable Baselines 3 and Gymnasium libraries, both of which are under the MIT license. The developers of these excellent libraries have done a great job of creating a solid foundation for reinforcement learning in Python, and I have learned a lot from their code.

In addition, the NDE code is heavily inspired by the work of Patrick Kidger and the entire library is based on his excellent Equinox library along with some use of Diffrax and jaxtyping.

Usage

Installation

Install via pip:

pip install lerax@git+https://github.com/RunnersNum40/lerax.git

Or clone the repo and install in editable mode:

git clone https://github.com/RunnersNum40/lerax.git
cd lerax
pip install -e .

Running an example

python examples/ppo.py

Running TensorBoard

tensorboard --logdir runs

Then open your browser to http://localhost:6006.

Creating your own models and environments

Check out the MLP Actor Critic for a simple example of how to create your own actor-critic model. Check out the PPO example for how to use your model in training. Check out the CartPole environment for how to create your own environment. Check out the Gymnasium wrapper for how to wrap Gymnasium environments (these will run slower than a fully JAX environment).

TODO

  • Optimise for performance under JIT compilation
    • Sharding support for distributed training
  • Expand policy support beyond Box and Discrete spaces
  • Documentation
    • Standardize docstring formats
    • Write documentation for all public APIs
    • Add API to docs when Zensical supports it
  • Testing
    • Unit testing
    • Integration testing
    • Full Jaxtyping
      • Ensure all functions and classes have proper type annotations
  • Use it
    • Personal research
  • Round out features
    • Expand RL variants to include more algorithms
    • Create a more comprehensive set of environments
      • Brax based environments
    • Save and load models

Code Style

This code follows Equinox's abstract/final pattern for code structure and uses Black formatting. This is intended to make the code more readable and maintainable, and to keep it consistent with the Equinox library. If you want to contribute, please follow these conventions.

Download files

Source Distribution

lerax-0.0.1a1.tar.gz (52.0 kB)

Uploaded Source

Built Distribution

lerax-0.0.1a1-py3-none-any.whl (82.3 kB)

Uploaded Python 3

File details

Details for the file lerax-0.0.1a1.tar.gz.

File metadata

  • Download URL: lerax-0.0.1a1.tar.gz
  • Upload date:
  • Size: 52.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for lerax-0.0.1a1.tar.gz
Algorithm Hash digest
SHA256 8bf7f74a1bcf365267b22bd969fe4aa94ba82f1541ccf962fabfdc535a2b0150
MD5 88f34282ab5eba0b6f8a39eed4da9d51
BLAKE2b-256 c0bc17ba5c51130d4db9a986518b76508661e5e46d58da26358a1f83c40a8855
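
A downloaded file can be checked against the SHA256 digest listed above with a few lines of standard-library Python. This is a minimal sketch; the helper names sha256_of and matches_published_hash are made up for illustration, and the path passed in is wherever the file was saved locally.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in chunks so large archives need not fit in memory.
    digest = hashlib.sha256()
    with Path(path).open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    # Compare against the published digest, ignoring case and whitespace.
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, matches_published_hash("lerax-0.0.1a1.tar.gz", "<SHA256 from the table>") should return True for an intact download.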

Provenance

The following attestation bundles were made for lerax-0.0.1a1.tar.gz:

Publisher: release.yml on RunnersNum40/lerax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file lerax-0.0.1a1-py3-none-any.whl.

File metadata

  • Download URL: lerax-0.0.1a1-py3-none-any.whl
  • Upload date:
  • Size: 82.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for lerax-0.0.1a1-py3-none-any.whl
Algorithm Hash digest
SHA256 c83d3021ff65f7715e97480a026ed57d7bea6e20e4632a7f1c549103a78c6105
MD5 ec5484e5c7c5228af354ff5ae421f16a
BLAKE2b-256 41c98be85cf3187e4bf284ceda138160ad4b144b5487feef72abd58f0068f99b

Provenance

The following attestation bundles were made for lerax-0.0.1a1-py3-none-any.whl:

Publisher: release.yml on RunnersNum40/lerax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
