Deep Reinforcement Learning with JAX and Equinox.

Project description

Lerax: Fully JITable reinforcement learning with JAX.

Lerax is a reinforcement learning library built on top of JAX for creating, training, and evaluating RL agents in a fully JITable manner. It provides modular components for building custom environments, policies, and training algorithms.

Built on top of JAX, Equinox, and Diffrax.
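
To illustrate what "fully JITable" means in practice, here is a minimal sketch in plain JAX. It does not use the Lerax API: the `step` dynamics and the `rollout` function are hypothetical names invented for this example. The point is that when the environment step is a pure function, a whole rollout can be expressed with `jax.lax.scan` and compiled once with `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import random as jr


def step(state, action):
    """Toy 1-D point-mass dynamics (hypothetical, not a Lerax environment)."""
    position, velocity = state
    velocity = velocity + 0.1 * action
    position = position + 0.1 * velocity
    reward = -jnp.abs(position)  # reward is highest at the origin
    return (position, velocity), reward


@partial(jax.jit, static_argnames="num_steps")
def rollout(key, num_steps):
    """Run num_steps of random actions inside one compiled function."""
    # num_steps is static because it determines the shape of `actions`.
    actions = jr.uniform(key, (num_steps,), minval=-1.0, maxval=1.0)
    init_state = (jnp.float32(1.0), jnp.float32(0.0))
    # scan replaces the Python loop, so the loop body is traced only once.
    final_state, rewards = jax.lax.scan(step, init_state, actions)
    return final_state, rewards


final_state, rewards = rollout(jr.key(0), num_steps=100)
total_reward = rewards.sum()
```

Because there is no Python-level loop or side effect inside `rollout`, the same function could also be vectorized over many keys with `jax.vmap`.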

Installation

pip install lerax

Documentation

Check out: lerax.tedpinkerton.ca

Training Example

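from jax import random as jr

from lerax.algorithm import PPO
from lerax.env import CartPole
from lerax.policy import MLPActorCriticPolicy

# Create the environment, an actor-critic MLP policy, and the PPO algorithm.
env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=jr.key(0))
algo = PPO()

# Train for 2**16 (65,536) timesteps; learn returns the trained policy.
policy = algo.learn(env, policy, total_timesteps=2**16, key=jr.key(1))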
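
The example above seeds policy initialization and training with two separate keys. A common JAX idiom, shown here as a sketch that uses only `jax.random` and no Lerax calls, is to derive both keys from a single root seed so that one integer reproduces the whole run. The names `root_key`, `policy_key`, and `learn_key` are illustrative.

```python
from jax import random as jr

# One root seed for the whole experiment.
root_key = jr.key(0)

# Split it into independent keys, one per consumer of randomness.
policy_key, learn_key = jr.split(root_key)

# The two derived keys carry different underlying key data.
keys_differ = not bool((jr.key_data(policy_key) == jr.key_data(learn_key)).all())
```

The derived keys could then be passed wherever the example passes `jr.key(0)` and `jr.key(1)`.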

TODO

  • Optimize for performance under JIT compilation
    • Sharding support for distributed training
  • Documentation
    • Standardize docstring formats
    • Write documentation for all public APIs
    • Add API to docs when Zensical supports it
  • Testing
    • Unit testing
    • Integration testing
    • Full Jaxtyping
      • Ensure all functions and classes have proper type annotations
  • Round out features
    • Expand RL variants to include more algorithms
      • Any off-policy algorithms
    • Create a more comprehensive set of environments
      • Brax based environments

Download files

Source Distribution

lerax-0.0.3.tar.gz (68.4 kB)

Uploaded Source

Built Distribution

lerax-0.0.3-py3-none-any.whl (111.3 kB)

Uploaded Python 3

File details

Details for the file lerax-0.0.3.tar.gz.

File metadata

  • Download URL: lerax-0.0.3.tar.gz
  • Upload date:
  • Size: 68.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for lerax-0.0.3.tar.gz
Algorithm    Hash digest
SHA256       25e30f2f833aec3f546826d83387a40c4678737a2aebbe8121331cc094492754
MD5          670b1f45bbda2c0815531631e67c1f47
BLAKE2b-256  4b609a38521e196275e438c3327a0ae096f228182e4475e6aee3481f6791d27b
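
A downloaded file can be checked against the published SHA256 digest with the Python standard library. The following is a minimal sketch; the helper names `sha256_of` and `matches_published_digest` are illustrative, and it assumes the archive has been saved locally under its original file name.

```python
import hashlib

# SHA256 digest listed above for lerax-0.0.3.tar.gz.
EXPECTED_SHA256 = "25e30f2f833aec3f546826d83387a40c4678737a2aebbe8121331cc094492754"


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_digest(path, expected=EXPECTED_SHA256):
    """Compare a local file's digest with the published one."""
    return sha256_of(path) == expected.lower()


# Usage, assuming the file is in the current directory:
# matches_published_digest("lerax-0.0.3.tar.gz")
```

pip can perform the same check at install time through its hash-checking mode, which uses `--hash` entries in a requirements file.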

Provenance

The following attestation bundles were made for lerax-0.0.3.tar.gz:

Publisher: release.yml on RunnersNum40/lerax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file lerax-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: lerax-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 111.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for lerax-0.0.3-py3-none-any.whl
Algorithm    Hash digest
SHA256       5762788f1f6bf7c97e121f8b94606c13b769dea835a4f783352dff9b34b6b6c2
MD5          16f85b2269e0bac42012fc54376667b3
BLAKE2b-256  4f1d8d5c0438359ea1e65861091967bf0bf22b0186ad9f3dc60028bcf556e1af

Provenance

The following attestation bundles were made for lerax-0.0.3-py3-none-any.whl:

Publisher: release.yml on RunnersNum40/lerax