
A PyTorch library for model-based reinforcement learning research


MBRL-Lib

mbrl is a toolbox for facilitating development of Model-Based Reinforcement Learning algorithms. It provides easily interchangeable modeling and planning components, and a set of utility functions that allow writing model-based RL algorithms with only a few lines of code.

See also our companion paper.

Getting Started

Installation

Standard Installation

mbrl requires Python 3.7+ and PyTorch (>= 1.7). To install the latest stable version, run

pip install mbrl

Developer installation

If you are interested in modifying the library, clone the repository and set up a development environment as follows:

git clone https://github.com/facebookresearch/mbrl-lib.git
cd mbrl-lib
pip install -e ".[dev]"

Then test it by running the following from the root folder of the repository:

python -m pytest tests/core
python -m pytest tests/algorithms

Basic example

As a starting point, check out our tutorial notebook on how to write the PETS algorithm (Chua et al., NeurIPS 2018) using our toolbox and run it on a continuous version of the cartpole environment.
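PETS plans by optimizing action sequences with the cross-entropy method (CEM). As a library-independent illustration of the idea, here is a minimal 1-D CEM minimizer on a toy quadratic; this is a sketch of the general technique, not the library's implementation:

```python
import random
import statistics

def cem_minimize(f, mean=0.0, std=2.0, iters=15, pop=50, n_elite=10, seed=0):
    """Minimal 1-D cross-entropy method: sample candidates from a Gaussian,
    keep the lowest-cost ones (the elites), and refit the Gaussian to them."""
    rng = random.Random(seed)
    for _ in range(iters):
        xs = [rng.gauss(mean, std) for _ in range(pop)]
        xs.sort(key=f)                         # best (lowest cost) first
        elites = xs[:n_elite]
        mean = statistics.mean(elites)
        std = statistics.stdev(elites) + 1e-6  # keep a little exploration
    return mean

# Minimize a toy quadratic with optimum at x = 3.
best = cem_minimize(lambda x: (x - 3.0) ** 2)
```

In PETS, the same loop runs over whole action sequences, with the cost given by rolling them out through the learned dynamics model.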

Provided algorithm implementations

MBRL-Lib provides implementations of popular MBRL algorithms as examples of how to use this library. You can find them in the mbrl/algorithms folder. Currently, we have implemented PETS, MBPO, and PlaNet, and we plan to keep growing this list in the future.

The implementations rely on Hydra to handle configuration. You can see the configuration files in this folder. The overrides subfolder contains environment-specific configurations for each environment, overriding the default configurations with the best hyperparameter values we have found so far for each combination of algorithm and environment. You can run training by passing the desired override option via the command line. For example, to run MBPO on the gym version of HalfCheetah, you would call

python -m mbrl.examples.main algorithm=mbpo overrides=mbpo_halfcheetah 

By default, all algorithms will save results in a csv file called results.csv, inside a folder whose path looks like ./exp/mbpo/default/gym___HalfCheetah-v2/yyyy.mm.dd/hhmmss; you can change the root directory (./exp) by passing root_dir=path-to-your-dir, and the experiment sub-folder (default) by passing experiment=your-name. The logger will also save a file called model_train.csv with training information for the dynamics model.
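The results file can be read with standard CSV tooling. Below is a sketch of loading such a file with the standard library; the column names (env_step, episode_reward) and the sample values are illustrative assumptions, not guaranteed by the logger:

```python
import csv
import io

# Illustrative file contents standing in for a real results.csv; the actual
# column names written by the logger may differ.
sample = io.StringIO(
    "env_step,episode_reward\n"
    "1000,-120.5\n"
    "2000,-60.2\n"
)

def load_rewards(fp):
    """Read (step, reward) pairs from a results.csv-style file object."""
    return [
        (int(row["env_step"]), float(row["episode_reward"]))
        for row in csv.DictReader(fp)
    ]

rows = load_rewards(sample)
```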

Beyond the override defaults, you can also change other configuration options, such as the type of dynamics model (e.g., dynamics_model=basic_ensemble), or the number of models in the ensemble (e.g., dynamics_model.model.ensemble_size=some-number). To learn more about all the available options, take a look at the provided configuration files.
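Hydra overrides are plain key=value strings, so launch scripts can assemble them programmatically. A small convenience sketch (the helper below is hypothetical, not part of the library):

```python
def hydra_overrides(options):
    """Format a dict of config options as Hydra key=value override strings."""
    return [f"{key}={value}" for key, value in options.items()]

# Dotted keys address nested config fields, as in the examples above.
args = hydra_overrides({
    "algorithm": "mbpo",
    "overrides": "mbpo_halfcheetah",
    "dynamics_model.model.ensemble_size": 7,
})
# The resulting strings can be appended to: python -m mbrl.examples.main ...
```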

Note

Running the provided examples requires Mujoco, but you can try out the library components (and algorithms) on other environments by creating your own entry script and Hydra configuration (see the provided examples).

If you do have a working Mujoco installation (and license), you can check that it works correctly with our library by running the following (this also requires dm_control):

python -m pytest tests/mujoco

Visualization tools

Our library also contains a set of visualization tools, meant to facilitate diagnostics and development of models and controllers. These currently require a Mujoco installation (see previous subsection), but we are planning to add support for other environments and extensions in the future. Currently, the following tools are provided:

  • Visualizer: Creates a video to qualitatively assess model predictions over a rolling horizon. Specifically, it runs a user-specified policy in a given environment, and at each time step, computes the model's predicted observations/rewards over a lookahead horizon for the same policy. The predictions are plotted as line plots, one for each observation dimension (blue lines) and reward (red line), along with the result of applying the same policy to the real environment (black lines). The model's uncertainty is visualized by plotting lines for the maximum and minimum predictions at each time step. The model and policy are specified by passing directories containing configuration files for each; they can be trained independently. The following gif shows an example of 200 steps of a pre-trained MBPO policy on the Inverted Pendulum environment.

    Example of Visualizer

  • DatasetEvaluator: Loads a pre-trained model and a dataset (can be loaded from separate directories), and computes predictions of the model for each output dimension. The evaluator then creates a scatter plot for each dimension comparing the ground truth output vs. the model's prediction. If the model is an ensemble, the plot shows the mean prediction as well as the individual predictions of each ensemble member.

    Example of DatasetEvaluator

  • FineTuner: Can be used to train a model on a dataset produced by a given agent/controller. The model and agent can be loaded from separate directories, and the fine-tuner will run the environment for some number of steps using actions obtained from the controller. The final model and dataset will then be saved under directory "model_dir/diagnostics/subdir", where subdir is provided by the user.

  • True Dynamics Multi-CPU Controller: This script can run a trajectory optimizer agent on the true environment using Python's multiprocessing. Each environment runs on its own CPU, which can significantly speed up costly sampling algorithms such as CEM. The controller will also save a video if the render argument is passed. Below is an example on HalfCheetah-v2 using CEM for trajectory optimization.

    Control Half-Cheetah True Dynamics

  • TrainingBrowser: This script launches a lightweight training browser for plotting rewards obtained after training runs (as long as the runs use our logger). The browser allows aggregating multiple runs and displaying their mean/std, and also lets the user save the image to disk. The legend and axes labels can be edited in the pane at the bottom left. Requires installing PyQt5. Thanks to a3ahmad for the contribution.

    Training Browser Example
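The uncertainty bands drawn by the Visualizer, and the per-member predictions shown by the DatasetEvaluator, both reduce to per-time-step statistics across ensemble members. A minimal sketch of that aggregation, using synthetic numbers and plain Python (not the library's internal code):

```python
import statistics

# Hypothetical per-member predictions for one observation dimension over a
# 5-step horizon (rows: ensemble members, columns: time steps).
ensemble_preds = [
    [0.9, 1.1, 1.3, 1.6, 2.0],
    [1.0, 1.2, 1.5, 1.9, 2.4],
    [1.1, 1.3, 1.6, 2.1, 2.7],
]

def bands(preds):
    """Per-time-step mean, minimum, and maximum across ensemble members."""
    by_step = list(zip(*preds))  # transpose: one tuple per time step
    return (
        [statistics.mean(step) for step in by_step],
        [min(step) for step in by_step],
        [max(step) for step in by_step],
    )

mean, lo, hi = bands(ensemble_preds)
```

Note how member predictions typically spread out with the horizon, which is exactly what the min/max lines make visible in the plots.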

Note that, except for the training browser, all the tools above require a Mujoco installation and are specific to models of type OneDimTransitionRewardModel. We are planning to extend this in the future; if you have useful suggestions, don't hesitate to raise an issue or submit a pull request!

Documentation

Please check out our documentation and don't hesitate to raise issues or contribute if anything is unclear!

License

mbrl is released under the MIT license. See LICENSE for additional details about it. See also our Terms of Use and Privacy Policy.

Citing

If you use this project in your research, please cite:

@Article{Pineda2021MBRL,
  author  = {Luis Pineda and Brandon Amos and Amy Zhang and Nathan O. Lambert and Roberto Calandra},
  journal = {arXiv},
  title   = {MBRL-Lib: A Modular Library for Model-based Reinforcement Learning},
  year    = {2021},
  url     = {https://arxiv.org/abs/2104.10159},
}
