
A high throughput, end-to-end RL library for infinite horizon tasks.

Project description

https://raw.githubusercontent.com/theOGognf/rl8/main/docs/_static/rl8-logo.png

rl8 is a minimal end-to-end RL library that simulates highly parallelized, infinite horizon environments and trains a PPO policy on them, achieving up to 1M environment transitions (and one policy update) per second on a single NVIDIA RTX 2080.

The figure below depicts rl8’s experiment tracking integration with MLflow and rl8’s ability to solve reinforcement learning problems within seconds.

Consistently solving CartPole within seconds.

Quick Start

Installation

Install with pip for the latest (stable) version.

pip install rl8

Install from GitHub for the latest (unstable) version.

git clone https://github.com/theOGognf/rl8.git
pip install ./rl8/

Basic Usage

Train a policy with PPO and log training progress with MLflow using the high-level trainer interface (this updates the policy indefinitely).

from rl8 import Trainer
from rl8.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv)
trainer.run()

Collect environment transitions and update a policy directly using the low-level algorithm interface (this updates the policy once).

from rl8 import Algorithm
from rl8.env import DiscreteDummyEnv

algo = Algorithm(DiscreteDummyEnv)
algo.collect()
algo.step()

The trainer interface is the most popular interface for policy training workflows, whereas the algorithm interface is useful for lower-level customization of those workflows.
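
As a sketch of that lower-level customization, the loop below simply repeats the documented collect and step calls a fixed number of times. The iteration count is arbitrary, and any logging or use of return values is omitted because it isn't documented here; treat this as a minimal outline rather than rl8's prescribed training loop.

from rl8 import Algorithm
from rl8.env import DiscreteDummyEnv

algo = Algorithm(DiscreteDummyEnv)

# Minimal custom training loop: each iteration collects a batch of
# environment transitions and then applies one PPO update to the policy.
for _ in range(100):
    algo.collect()
    algo.step()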

Concepts

rl8 is minimal in that it limits the number of interfaces required for training a policy with PPO, without restricting observation and action specs, custom models, or custom action distributions.

rl8 is built around six key concepts:

  • The environment: The simulation that the policy learns to interact with (a rough sketch of this idea follows the list below).

  • The model: The policy parameterization that determines how the policy processes environment observations and how parameters for the action distribution are generated.

  • The action distribution: The mechanism for representing actions conditioned on environment observations and model outputs. Environment actions are ultimately sampled from the action distribution.

  • The policy: The combination of the model and the action distribution; the policy calls the model to produce distribution parameters and then samples actions from the resulting action distribution.

  • The algorithm: The PPO implementation that uses the environment to train the policy (i.e., update the model’s parameters).

  • The trainer: The high-level interface for using the algorithm to train indefinitely or until some condition is met. The trainer directly integrates with MLflow to track experiments and training progress.
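
To make the environment concept concrete, the rough, hypothetical skeleton below shows the shape such a simulation tends to take: parallelism is handled inside the environment by batching everything over a num_envs dimension, and there are no terminal flags because tasks are treated as infinite horizon. The class name, method names, and signatures here are illustrative assumptions, not rl8's actual base-class interface; see the examples for real rl8-compatible environments.

import torch


class MyParallelEnv:
    """Toy, batched simulation sketch (not rl8's actual environment API)."""

    def __init__(self, num_envs: int, device: str = "cpu") -> None:
        self.num_envs = num_envs
        self.device = device
        self.state = torch.zeros(num_envs, 4, device=device)

    def reset(self) -> torch.Tensor:
        # Reset every parallel environment at once and return batched observations.
        self.state = torch.randn(self.num_envs, 4, device=self.device)
        return self.state

    def step(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Apply a batch of actions in one tensor op and return batched
        # observations and rewards. There are no terminal flags: infinite
        # horizon environments are reset at fixed horizon intervals instead.
        self.state = self.state + action
        reward = -self.state.abs().sum(dim=-1)
        return self.state, reward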

Quick Examples

These short snippets showcase rl8’s main features. See the examples for complete implementations of rl8-compatible environments and models.

Customizing Training Runs

Use a custom distribution and custom hyperparameters by passing options to the trainer (or algorithm) interface.

from rl8 import SquashedNormal, Trainer
from rl8.env import ContinuousDummyEnv

trainer = Trainer(
    ContinuousDummyEnv,
    distribution_cls=SquashedNormal,
    gae_lambda=0.99,
    gamma=0.99,
)
trainer.run()

Training a Recurrent Policy

Swap to the recurrent flavor of the trainer (or algorithm) interface to train a recurrent model and policy.

from rl8 import RecurrentTrainer
from rl8.env import DiscreteDummyEnv

trainer = RecurrentTrainer(DiscreteDummyEnv)
trainer.run()

Training on a GPU

Specify the device used across the environment, model, and algorithm.

from rl8 import Trainer
from rl8.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv, device="cuda")
trainer.run()

Minimizing GPU Memory Usage

Enable policy updates with gradient accumulation and/or Automatic Mixed Precision (AMP) to minimize GPU memory usage so you can simulate more environments or use larger models.

import torch.optim as optim

from rl8 import Trainer
from rl8.env import DiscreteDummyEnv

trainer = Trainer(
    DiscreteDummyEnv,
    optimizer_cls=optim.SGD,
    accumulate_grads=True,
    enable_amp=True,
    sgd_minibatch_size=8192,
    device="cuda",
)
trainer.run()

Specifying Training Stop Conditions

Specify conditions based on training statistics to stop training early.

from rl8 import Trainer
from rl8.conditions import Plateaus
from rl8.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv)
trainer.run(stop_conditions=[Plateaus("returns/mean", rtol=0.05)])
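Since stop_conditions is a list, it is presumably possible to pass more than one condition at once. In the hedged sketch below, only Plateaus and the "returns/mean" key come from the example above; the second key ("returns/std") is a hypothetical training statistic used purely for illustration.

from rl8 import Trainer
from rl8.conditions import Plateaus
from rl8.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv)
trainer.run(
    stop_conditions=[
        Plateaus("returns/mean", rtol=0.05),
        # "returns/std" is a hypothetical key used only for illustration.
        Plateaus("returns/std", rtol=0.05),
    ]
)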

Why rl8?

TL;DR: rl8 focuses on a niche subset of RL that simplifies the overall library while allowing fast and fully customizable environments, models, and action distributions.

There are many high quality, open-source RL libraries. Most of them take on the daunting task of being a monolithic, one-stop shop for everything RL, attempting to support as many algorithms, environments, models, and compute capabilities as possible. Naturally, this monolithic goal has some drawbacks:

  • The software becomes more dense with each supported feature, making the library all the more difficult to customize for a specific use case.

  • The software becomes less performant for a specific use case. RL practitioners typically end up accepting the cost of transitioning to expensive and difficult-to-manage compute clusters to get results faster.

Rather than focusing on being a monolithic RL library, rl8 fills the niche of maximizing training performance for a few key assumptions:

  • Environments are highly parallelized and their parallelization is entirely managed within the environment. This allows rl8 to ignore distributed computing design considerations.

  • Environments are infinite horizon (i.e., they have no terminal conditions). This allows rl8 to reset environments at the same, fixed horizon intervals, greatly simplifying environment and algorithm implementations.

  • The only supported ML framework is PyTorch and the only supported algorithm is PPO. This allows rl8 to ignore layers upon layers of abstraction, greatly simplifying the overall library implementation.

The end result is a minimal and high throughput library that can train policies to solve complex tasks within minutes on consumer grade compute devices.

Unfortunately, this means rl8 doesn’t support as many use cases as a monolithic RL library might. In fact, rl8 is probably a bad fit for your use case if:

  • Your environment isn’t parallelizable.

  • Your environment must contain terminal conditions and can’t be reformulated as an infinite horizon task.

  • You want to use an ML framework that isn’t PyTorch or you want to use an algorithm that isn’t a variant of PPO.

However, if rl8 does fit your use case, it can do wonders for your RL workflow.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

rl8-0.2.0.tar.gz (151.9 kB)

Uploaded Source

Built Distribution

rl8-0.2.0-py3-none-any.whl (87.8 kB)

Uploaded Python 3

File details

Details for the file rl8-0.2.0.tar.gz.

File metadata

  • Download URL: rl8-0.2.0.tar.gz
  • Upload date:
  • Size: 151.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for rl8-0.2.0.tar.gz

  • SHA256: 9b746ac39cd74f971bfe5dea8cacb8639134e8a6b05691614f51d9ab4d01bc64
  • MD5: f1806c1f68ec4e9aa7b55e24c17ff7b9
  • BLAKE2b-256: ba5896706a8f42943d52c229e24b7b242c1f88cae241cb358ec5a832dba70634

See more details on using hashes here.

File details

Details for the file rl8-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: rl8-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 87.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for rl8-0.2.0-py3-none-any.whl

  • SHA256: 1a7e4eb48e0700531a334c97608505ac59694f03fecbe2f6023233c543293015
  • MD5: 2345af7ae306b49b16c766cd459b5bc0
  • BLAKE2b-256: eedde6b59d19b0d99de10cb4b43758c5c784ba66dd5b19a82870f51b0b5d5fc0

See more details on using hashes here.
