Deep Reinforcement Learning with JAX and Equinox.

Project description

Lerax

This is a work-in-progress implementation of a JAX-based reinforcement learning library built on Equinox. The main feature is support for models based on Neural Differential Equations (NDEs). NDEs can be extraordinarily computationally intensive, so this library aims to provide optimised implementations of NDEs and other RL algorithms using just-in-time (JIT) compilation. Paired with environments that support JIT, high performance is possible using the Anakin architecture for fully GPU-based RL.
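
To illustrate the idea of a JIT-compiled NDE, here is a minimal sketch in plain JAX. It is not Lerax's API: the names `init_params`, `vector_field`, and `solve` are hypothetical, and a fixed-step Euler integrator stands in for a proper Diffrax solver to keep the example dependency-free.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_params(key, state_dim, hidden_dim):
    """Weights for a one-hidden-layer MLP vector field (hypothetical helper)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (hidden_dim, state_dim)) / jnp.sqrt(state_dim),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (state_dim, hidden_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(state_dim),
    }


def vector_field(params, y):
    """dy/dt = f_theta(y), parameterised by a small MLP."""
    hidden = jnp.tanh(params["w1"] @ y + params["b1"])
    return params["w2"] @ hidden + params["b2"]


@partial(jax.jit, static_argnames="num_steps")
def solve(params, y0, t1, num_steps):
    """Integrate from t=0 to t=t1 with fixed-step Euler inside lax.scan."""
    dt = t1 / num_steps

    def step(y, _):
        y_next = y + dt * vector_field(params, y)
        return y_next, y_next

    # scan keeps the whole loop inside one compiled XLA computation
    y_final, trajectory = jax.lax.scan(step, y0, None, length=num_steps)
    return y_final, trajectory


params = init_params(jax.random.PRNGKey(0), state_dim=3, hidden_dim=16)
y_final, trajectory = solve(params, jnp.ones(3), 1.0, num_steps=20)
```

Because the solve is an ordinary JAX function, `jax.grad` can differentiate through it with respect to `params`, which is what makes the vector field trainable.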

I'm working on this in my free time, so it may take a while to get to a usable state. I'm also mainly developing this for personal research, so it may not be suitable for all use cases.

Code Style

This code follows Equinox's abstract/final pattern for code structure and uses Black formatting. This is intended to make the code more readable and maintainable, and to keep it consistent with the Equinox library. If you want to contribute, please follow these conventions.

Credit

Much of the code is a close translation of code from the Stable Baselines3 and Gymnasium libraries, both of which are under the MIT license. The developers of these excellent libraries have created a solid foundation for reinforcement learning in Python, and I have learned a lot from their code.

In addition, the NDE code is heavily inspired by the work of Patrick Kidger, and the entire library is built on his excellent Equinox library, along with some use of Diffrax and jaxtyping.

TODO

  • Expand support beyond Box and Discrete spaces
  • Logging
    • Code flow logging
    • Training logging
    • Migrate from TensorBoard to Aim
  • Documentation
    • Standardize docstring formats
    • Write documentation for all public APIs
    • Publish docs
  • Testing
    • More thorough unit testing
    • Integration testing
    • Runtime jaxtyping
  • Use it
    • Personal research
  • Optimise for performance under JIT compilation
    • Good vectorization support
    • Sharding support for distributed training
  • Round out features
    • Rendering support
    • Expand RL variants to include more algorithms
    • Create a more comprehensive set of environments

Download files

Source Distribution

lerax-0.0.1a0.tar.gz (45.2 kB)

Uploaded Source

Built Distribution

lerax-0.0.1a0-py3-none-any.whl (67.6 kB)

Uploaded Python 3

File details

Details for the file lerax-0.0.1a0.tar.gz.

File metadata

  • Download URL: lerax-0.0.1a0.tar.gz
  • Upload date:
  • Size: 45.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.7

File hashes

Hashes for lerax-0.0.1a0.tar.gz

  • SHA256: 80ad6043e54d6d6400a2c9d391bdf02a8af690cb873730c170f806958539c235
  • MD5: e82842c3d6e5d929f445a9c51e590007
  • BLAKE2b-256: d27b70102fdc28f0e27d144229e4a64eb3d02abe6b0b6c846cb230b430e92852

File details

Details for the file lerax-0.0.1a0-py3-none-any.whl.

File metadata

  • Download URL: lerax-0.0.1a0-py3-none-any.whl
  • Upload date:
  • Size: 67.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.7

File hashes

Hashes for lerax-0.0.1a0-py3-none-any.whl

  • SHA256: 970e8bcd9a49e91ea4956e05c994cfad0e81d31266c2594423016ec0ee4c1e25
  • MD5: 4b734d2cbed7c95d8ec756b46c74ce43
  • BLAKE2b-256: dd66281d5892a62599ae9a412593beb2e45f8ae4aea7db33bf68cc0ac5bbdc73
