A minimal RL library for infinite horizon tasks.

Project description

rlstack: A Minimal RL Library

rlstack is a minimal RL library for simulating highly parallelized, infinite horizon environments and training a PPO policy on those environments, achieving up to 500k environment transitions (and one policy update) per second on a single NVIDIA RTX 2080.

Quick Start

Installation

Install with pip for the latest (stable) version.

pip install rlstack

Install from GitHub for the latest (unstable) version.

git clone https://github.com/theOGognf/rlstack.git
pip install ./rlstack/

Basic Usage

Collect environment transitions and update a policy directly using the low-level algorithm interface (this updates the policy once).

from rlstack import Algorithm
from rlstack.env import DiscreteDummyEnv

algo = Algorithm(DiscreteDummyEnv)
algo.collect()
algo.step()

Train a policy with PPO and log training progress with MLFlow using the high-level trainer interface (this updates the policy indefinitely).

from rlstack import Trainer
from rlstack.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv)
trainer.run()
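
The trainer logs training statistics to MLFlow as it runs. Assuming MLFlow's default local, file-based tracking store, you can browse those metrics during or after training by running mlflow ui from the working directory (or by pointing MLFlow at whatever tracking URI you've configured).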

Concepts

rlstack is minimal in that it limits the number of interfaces required to train a policy with PPO, even for customized policies, without restricting observation and action specs, custom models, or custom action distributions.

rlstack is built around six key concepts:

  • The environment: The simulation that the policy learns to interact with. The environment is always user-defined (see the sketch after this list).

  • The model: The policy parameterization that determines how the policy processes environment observations and how parameters for the action distribution are generated. The model is usually user-defined (default models are sometimes sufficient depending on the environment’s observation and action specs).

  • The action distribution: The mechanism for representing actions conditioned on environment observations and model outputs. Environment actions are ultimately sampled from the action distribution. The action distribution is sometimes user-defined (default action distributions are usually sufficient depending on the environment’s observation and action specs).

  • The policy: The combination of the model and the action distribution that calls the model and samples from the action distribution. The policy handles some pre-/post-processing of its I/O to make it more convenient to sample from the model and action distribution together. The policy is almost never user-defined.

  • The algorithm: The PPO implementation that uses the environment to train the policy (i.e., update the model’s parameters). All hyperparameters and customizations are set with the algorithm. The algorithm is almost never user-defined.

  • The trainer: The high-level interface for using the algorithm to train indefinitely or until some condition is met. The trainer directly integrates with MLFlow to track experiments and training progress. The trainer is never user-defined.
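
Since the environment is always user-defined, a custom, batched environment is usually the first thing you write. Below is a minimal sketch of the general shape such an environment might take; the constructor arguments, method names, and return types that rlstack actually expects are assumptions here and should be taken from rlstack.env (e.g., the DiscreteDummyEnv and ContinuousDummyEnv sources) rather than from this example.

import torch


class MyParallelEnv:
    """Toy infinite horizon environment that steps many instances as one batch."""

    def __init__(self, num_envs: int, device: str = "cpu") -> None:
        self.num_envs = num_envs
        self.device = device

    def reset(self) -> torch.Tensor:
        # One observation row per parallel environment instance.
        self.state = torch.zeros(self.num_envs, 4, device=self.device)
        return self.state

    def step(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Advance all instances with a single batched tensor op. There are no
        # terminal conditions; rollouts are reset at fixed horizon intervals.
        self.state = self.state + action.to(self.device).float()
        reward = -self.state.abs().sum(dim=-1, keepdim=True)
        return self.state, reward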

Quick Examples

Customizing Training Runs

Use a custom distribution and custom hyperparameters with the low-level algorithm interface. The algorithm uses default feedforward models depending on the environment’s action spec.

from rlstack import Algorithm, SquashedNormal
from rlstack.env import ContinuousDummyEnv

algo = Algorithm(
    ContinuousDummyEnv,
    distribution_cls=SquashedNormal,
    gae_lambda=0.99,
    gamma=0.99,
)
algo.collect()
algo.step()

Specify the same settings using the high-level trainer interface.

from rlstack import SquashedNormal, Trainer
from rlstack.env import ContinuousDummyEnv

trainer = Trainer(
    ContinuousDummyEnv,
    algorithm_config={
        "distribution_cls": SquashedNormal,
        "gae_lambda": 0.99,
        "gamma": 0.99,
    }
)
trainer.run()

Training a Recurrent Policy

Use the low-level algorithm interface to seamlessly switch between feedforward and recurrent algorithms. The recurrent algorithm uses default recurrent models depending on the environment’s action spec.

from rlstack import RecurrentAlgorithm
from rlstack.env import DiscreteDummyEnv

algo = RecurrentAlgorithm(DiscreteDummyEnv)
algo.collect()
algo.step()

Specify the algorithm type using the high-level trainer interface (which usually defaults to a feedforward version of the algorithm).

from rlstack import RecurrentAlgorithm, Trainer
from rlstack.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv, algorithm_cls=RecurrentAlgorithm)
trainer.run()

Training on a GPU

Use the low-level algorithm interface to specify training on a GPU.

from rlstack import Algorithm
from rlstack.env import DiscreteDummyEnv

algo = Algorithm(DiscreteDummyEnv, device="cuda")
algo.collect()
algo.step()

Specify training on a GPU using the high-level trainer interface.

from rlstack import Trainer
from rlstack.env import DiscreteDummyEnv

trainer = Trainer(DiscreteDummyEnv, algorithm_config={"device": "cuda"})
trainer.run()

Minimizing GPU Memory Usage

Use the low-level algorithm interface to enable policy updates with gradient accumulation and/or Automatic Mixed Precision (AMP) to minimize GPU memory usage so you can simulate more environments or use larger models.

import torch.optim as optim

from rlstack import Algorithm
from rlstack.env import DiscreteDummyEnv

algo = Algorithm(
    DiscreteDummyEnv,
    optimizer_cls=optim.SGD,
    accumulate_grads=True,
    enable_amp=True,
    sgd_minibatch_size=8192,
    device="cuda",
)
algo.collect()
algo.step()

Enable memory-minimization settings using the high-level trainer interface.

import torch.optim as optim

from rlstack import Trainer
from rlstack.env import DiscreteDummyEnv

trainer = Trainer(
    DiscreteDummyEnv,
    algorithm_config={
        "optimizer_cls": optim.SGD,
        "accumulate_grads": True,
        "enable_amp": True,
        "sgd_minibatch_size": 8192,
        "device": "cuda",
    },
)
trainer.run()

Specifying Training Stop Conditions

Specify training stop conditions based on training statistics using the high-level trainer interface.

from rlstack import Trainer
from rlstack.conditions import Plateaus
from rlstack.env import DiscreteDummyEnv

trainer = Trainer(
    DiscreteDummyEnv,
    stop_conditions=[Plateaus("returns/mean", rtol=0.05)],
)
trainer.run()
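
Here, the Plateaus condition is presumably intended to stop the run once the tracked statistic (the mean return, "returns/mean") stops improving beyond the 5% relative tolerance; see rlstack.conditions for the available stop conditions and their exact semantics.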

Why rlstack?

TL;DR: rlstack focuses on a niche subset of RL that simplifies the overall library while allowing fast and fully customizable environments, models, and action distributions.

There are many high-quality, open-source RL libraries. Most of them take on the daunting task of being a monolithic, one-stop shop for everything RL, attempting to support as many algorithms, environments, models, and compute capabilities as possible. Naturally, this monolithic goal has some drawbacks:

  • The software becomes more dense with each supported feature, making the library all the more difficult to customize for a specific use case.

  • The software becomes less performant for a specific use case. RL practitioners typically end up accepting the cost of transitioning to expensive and difficult-to-manage compute clusters to get results faster.

There are a handful of high-quality, open-source RL libraries that trade off feature richness to reduce these drawbacks. However, each still either doesn’t provide enough of a speed benefit to warrant switching from a monolithic library or remains too complex to adapt to a specific use case.

rlstack is a niche RL library that finds a Goldilocks zone in the tradeoff between feature support and speed/complexity by making some key assumptions:

  • Environments are highly parallelized and their parallelization is entirely managed within the environment. This allows rlstack to ignore distributed computing design considerations.

  • Environments are infinite horizon (i.e., they have no terminal conditions). This allows rlstack to reset environments at the same, fixed horizon intervals, greatly simplifying environment and algorithm implementations.

  • The only supported ML framework is PyTorch and the only supported algorithm is PPO. This allows rlstack to ignore layers upon layers of abstraction, greatly simplifying the overall library implementation.

The end result is a minimal, high-throughput library that can train policies to solve complex tasks on a single NVIDIA RTX 2080 within minutes.

Unfortunately, this means rlstack doesn’t support as many use cases as a monolithic RL library might. In fact, rlstack is probably a bad fit for your use case if:

  • Your environment isn’t parallelizable.

  • Your environment must contain terminal conditions and can’t be reformulated as an infinite horizon task.

  • You want to use an ML framework that isn’t PyTorch or you want to use an algorithm that isn’t a variant of PPO.

However, if rlstack does fit your use case, it can do wonders for your RL workflow.

