A library of reinforcement learning building blocks in JAX.

RLax

RLax (pronounced "relax") is a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents. Full documentation can be found at rlax.readthedocs.io.

Installation

You can install the latest released version of RLax from PyPI via:

pip install rlax

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/rlax.git

All RLax code may then be just-in-time compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.
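As a minimal sketch of what this looks like in practice, the snippet below compiles a hand-written one-step TD error with jax.jit. The function td_error is a hypothetical stand-in written for this example, not an RLax function; the same wrapping pattern is assumed to apply to RLax operations.

```python
import jax
import jax.numpy as jnp


def td_error(v_tm1, r_t, discount_t, v_t):
  """One-step temporal-difference error (hypothetical helper, not from RLax)."""
  return r_t + discount_t * v_t - v_tm1


# jax.jit traces the function and compiles it with XLA for the default backend.
td_error_jit = jax.jit(td_error)

args = (jnp.array(1.0), jnp.array(0.5), jnp.array(0.9), jnp.array(2.0))
eager = td_error(*args)         # runs op by op
compiled = td_error_jit(*args)  # runs the compiled version
```

Both calls compute the same quantity; the compiled version pays a one-off tracing cost on its first call.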

To run the code in examples/, you will also need to clone the repository and install the additional requirements: optax, haiku, and bsuite.

Content

The operations and functions provided are not complete algorithms, but implementations of reinforcement learning specific mathematical operations that are needed when building fully-functional agents capable of learning:

  • Values, including both state and action-values;
  • Values for non-linear generalizations of the Bellman equations;
  • Return Distributions, aka distributional value functions;
  • General Value Functions, for cumulants other than the main reward;
  • Policies, via policy gradients in both continuous and discrete action spaces.

The library supports both on-policy and off-policy learning (i.e. learning from data sampled from a policy different from the agent's policy).

See file-level and function-level doc-strings for the documentation of these functions and for references to the papers that introduced and/or used them.

Usage

See examples/ for examples of using some of the functions in RLax to implement a few simple reinforcement learning agents and to demonstrate learning on BSuite's version of the Catch environment (a common unit test for agent development in the reinforcement learning literature).

Other examples of JAX reinforcement learning agents using rlax can be found in bsuite.

Background

Reinforcement learning studies the problem of a learning system (the agent), which must learn to interact with the universe it is embedded in (the environment).

Agent and environment interact on discrete steps. On each step the agent selects an action and receives in return a (partial) snapshot of the state of the environment (the observation) and a scalar feedback signal (the reward).

The behaviour of the agent is characterized by a probability distribution over actions, conditioned on past observations of the environment (the policy). The agent seeks a policy that, from any given step, maximises the discounted cumulative reward that will be collected from that point onwards (the return).

Often the agent's policy or the environment dynamics are themselves stochastic. In this case the return is a random variable, and the optimal policy is more precisely specified as one that maximises the expectation of the return (the value) under the agent's and environment's stochasticity.
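To make the definition of the return concrete, here is a small sketch that computes the discounted return from every step of a finite reward sequence, using the backward recursion G_k = r_k + discount * G_{k+1}. The helper name discounted_returns, the reward values, and the single fixed discount are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def discounted_returns(r_t, discount):
  """Returns G_k = r_k + discount * G_{k+1} for every step k (hypothetical helper)."""

  def step(g_next, r):
    g = r + discount * g_next
    return g, g  # carry the running return backwards and also emit it

  # reverse=True scans from the last step to the first, as the recursion requires.
  _, returns = jax.lax.scan(step, jnp.zeros((), r_t.dtype), r_t, reverse=True)
  return returns


rewards = jnp.array([1.0, 0.0, 2.0])
returns = discounted_returns(rewards, 0.5)  # [1.5, 1.0, 2.0]
```

When rewards are random, each entry is a single sample of the return; the value is the expectation of that quantity.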

Reinforcement Learning Algorithms

There are three prototypical families of reinforcement learning algorithms:

  1. those that estimate the value of states and actions, and infer a policy by inspection (e.g. by selecting the action with highest estimated value);
  2. those that learn a model of the environment (capable of predicting the observations and rewards) and infer a policy via planning;
  3. those that parameterize a policy that can be directly executed.

In any case, policies, values or models are just functions. In deep reinforcement learning such functions are represented by a neural network. In this setting, it is common to formulate reinforcement learning updates as differentiable pseudo-loss functions (analogously to (un-)supervised learning). Under automatic differentiation, the original update rule is recovered.
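As an illustration of the last two sentences, consider tabular TD(0) with a squared-error pseudo-loss. If the bootstrap target is wrapped in jax.lax.stop_gradient, the negative gradient of the loss with respect to the value table equals the TD error at the source state and is zero elsewhere, which is the direction of the classic TD update. The names td_pseudo_loss, values, s_tm1, and s_t are assumptions for this sketch.

```python
import jax
import jax.numpy as jnp


def td_pseudo_loss(values, s_tm1, r_t, discount_t, s_t):
  """0.5 * (TD error)^2 with the bootstrap target held fixed (illustrative)."""
  target = jax.lax.stop_gradient(r_t + discount_t * values[s_t])
  return 0.5 * (target - values[s_tm1]) ** 2


values = jnp.array([1.0, 2.0, 0.0])  # tabular value estimates for three states
s_tm1, r_t, discount_t, s_t = 0, 0.5, 0.9, 1

td_error = r_t + discount_t * values[s_t] - values[s_tm1]
grads = jax.grad(td_pseudo_loss)(values, s_tm1, r_t, discount_t, s_t)
# -grads is nonzero only at index s_tm1, where it equals the TD error.
```

Without the stop_gradient, the gradient would also flow into values[s_t], giving a different (residual-gradient) update rather than TD(0).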

Note, however, that the updates are only valid if the input data is sampled in the correct manner. For example, a policy gradient loss is only valid if the input trajectory is an unbiased sample from the current policy, i.e. the data are on-policy. The library cannot check or enforce such constraints. Links to papers describing how each operation is used are, however, provided in the functions' doc-strings.
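For instance, a REINFORCE-style pseudo-loss for a softmax policy over discrete actions can be sketched as below. The function name, shapes, and numbers are assumptions for illustration. The code differentiates whatever data it is given, so nothing in it can detect whether a_t was in fact sampled from the policy defined by logits_t.

```python
import jax
import jax.numpy as jnp


def pg_pseudo_loss(logits_t, a_t, adv_t):
  """-sum_t adv_t * log pi(a_t), with advantages treated as constants (illustrative)."""
  log_pi = jax.nn.log_softmax(logits_t)  # [T, A]
  log_pi_a = jnp.take_along_axis(log_pi, a_t[:, None], axis=-1)[:, 0]  # [T]
  return -jnp.sum(jax.lax.stop_gradient(adv_t) * log_pi_a)


logits_t = jnp.zeros((2, 3))    # uniform policy at two timesteps, three actions
a_t = jnp.array([0, 2])         # actions taken (assumed sampled from this policy)
adv_t = jnp.array([1.0, -1.0])  # advantage estimates for those actions

grads = jax.grad(pg_pseudo_loss)(logits_t, a_t, adv_t)
```

A gradient descent step on this loss raises the logit of the action with positive advantage and lowers the logit of the action with negative advantage.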

Naming Conventions and Developer Guidelines

We define functions and operations for agents interacting with a single stream of experience. The JAX construct vmap can be used to apply these same functions to batches (e.g. to support replay and parallel data generation).

Many functions consider policies, actions, rewards, and values in consecutive timesteps in order to compute their outputs. In this case the suffixes _t and _tm1 are often used to clarify on which step each input was generated, e.g.:

  • q_tm1: the action values in the source state of a transition.
  • a_tm1: the action that was selected in the source state.
  • r_t: the resulting reward collected in the destination state.
  • discount_t: the discount associated with a transition.
  • q_t: the action values in the destination state.
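The naming convention and the single-stream design can be sketched together. The function below is a hand-written Q-learning TD error for one transition, using the argument names from the list above; it is a stand-in written for this example, not the library's implementation. jax.vmap then lifts it to a batch of transitions.

```python
import jax
import jax.numpy as jnp


def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
  """TD error for a single transition; q_tm1 and q_t have shape [num_actions]."""
  target_tm1 = r_t + discount_t * jnp.max(q_t)
  return jax.lax.stop_gradient(target_tm1) - q_tm1[a_tm1]


# A batch of two transitions: vmap maps over the leading axis of every argument.
q_tm1 = jnp.array([[1.0, 2.0], [0.0, 1.0]])
a_tm1 = jnp.array([0, 1])
r_t = jnp.array([1.0, 0.0])
discount_t = jnp.array([0.9, 0.0])  # in this example, 0.0 stands for a terminal transition
q_t = jnp.array([[3.0, 4.0], [5.0, 6.0]])

batched_td_error = jax.vmap(q_learning_td_error)
td_errors = batched_td_error(q_tm1, a_tm1, r_t, discount_t, q_t)  # shape [2]
```

Writing the function for a single transition keeps indexing such as q_tm1[a_tm1] simple, and leaves the choice of batch axes to the caller.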

Extensive testing is provided for each function. All tests should also verify the output of rlax functions when compiled to XLA using jax.jit and when performing batch operations using jax.vmap.

Citing RLax

This repository is part of the DeepMind JAX Ecosystem. To cite RLax, please use the following citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}
