Data-driven Deep Reinforcement Learning Library as an Out-of-the-box Tool

d3rlpy: A data-driven deep reinforcement learning library as an out-of-the-box tool

d3rlpy is a data-driven deep reinforcement learning library designed to work out of the box.

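from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import CQL

# MDPDataset takes arrays of state transitions:
# observations, actions, rewards and terminals are your own NumPy arrays,
# each with one entry per recorded step
dataset = MDPDataset(observations, actions, rewards, terminals)

# train data-driven deep RL
cql = CQL()
cql.fit(dataset.episodes)

# ready to control: x is a batch of observations
actions = cql.predict(x)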
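
To make the input format more concrete, here is a minimal sketch of how four parallel per-step sequences can describe several episodes. It uses plain Python lists in place of NumPy arrays, and the helper `split_into_episodes` is hypothetical: it is not part of d3rlpy and only illustrates the idea that a terminal flag closes an episode.

```python
def split_into_episodes(observations, actions, rewards, terminals):
    """Group parallel per-step sequences into episodes, cutting after each terminal flag."""
    if not (len(observations) == len(actions) == len(rewards) == len(terminals)):
        raise ValueError("all inputs must have one entry per step")
    episodes, current = [], []
    for step in zip(observations, actions, rewards, terminals):
        current.append(step)
        if step[3]:  # a truthy terminal flag closes the current episode
            episodes.append(current)
            current = []
    if current:  # assumption: trailing steps without a terminal flag form one last episode
        episodes.append(current)
    return episodes


# toy data: five recorded steps forming two episodes
toy_observations = [[0.0, 0.1], [0.1, 0.2], [0.2, 0.3], [0.0, 0.0], [0.5, 0.5]]
toy_actions = [0, 1, 1, 0, 1]
toy_rewards = [0.0, 0.0, 1.0, 0.0, 1.0]
toy_terminals = [0, 0, 1, 0, 1]

toy_episodes = split_into_episodes(
    toy_observations, toy_actions, toy_rewards, toy_terminals
)
episode_lengths = [len(episode) for episode in toy_episodes]
print(episode_lengths)
```

The real `MDPDataset` may treat edge cases such as unterminated trailing steps differently, so check the documentation before relying on this behavior.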

Documentation: https://d3rlpy.readthedocs.io

key features

:zap: Designed for Data-Driven Deep Reinforcement Learning

d3rlpy is designed for data-driven deep reinforcement learning algorithms, which learn a good policy from a given dataset. This suits tasks where online interaction is not feasible. d3rlpy also supports the conventional online training paradigm, so it covers both settings.

:beginner: Easy-To-Use API

d3rlpy provides state-of-the-art algorithms through scikit-learn style APIs without compromising flexibility: detailed configuration options remain available for professional users. Moreover, d3rlpy is not just designed like scikit-learn, but is also fully compatible with scikit-learn utilities.

:rocket: Beyond State-Of-The-Art

d3rlpy provides further tweaks that improve the performance of state-of-the-art algorithms, potentially beyond their original papers. This lets users reach professional-level performance in just a few lines of code.

installation

$ pip install d3rlpy

supported algorithms

| algorithm | discrete control | continuous control | data-driven RL? |
|:-|:-:|:-:|:-:|
| Behavior Cloning (supervised learning) | :white_check_mark: | :white_check_mark: | |
| Deep Q-Network (DQN) | :white_check_mark: | :no_entry: | |
| Double DQN | :white_check_mark: | :no_entry: | |
| Deep Deterministic Policy Gradients (DDPG) | :no_entry: | :white_check_mark: | |
| Twin Delayed Deep Deterministic Policy Gradients (TD3) | :no_entry: | :white_check_mark: | |
| Soft Actor-Critic (SAC) | :no_entry: | :white_check_mark: | |
| Random Ensemble Mixture (REM) | :construction: | :no_entry: | :white_check_mark: |
| Batch Constrained Q-learning (BCQ) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Bootstrapping Error Accumulation Reduction (BEAR) | :no_entry: | :white_check_mark: | :white_check_mark: |
| Advantage-Weighted Regression (AWR) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Advantage-weighted Behavior Model (ABM) | :construction: | :construction: | :white_check_mark: |
| Conservative Q-Learning (CQL) (recommended) | :white_check_mark: | :white_check_mark: | :white_check_mark: |

supported Q functions

other features

Basically, all features are available with every algorithm.

  • evaluation metrics in a scikit-learn scorer function style
  • embedded preprocessors
  • export the greedy policy as TorchScript or ONNX
  • ensemble Q function with bootstrapping
  • delayed policy updates
  • parallel cross-validation with multiple GPUs
  • online training
  • data augmentation
  • model-based algorithm
  • user-defined custom network

scikit-learn compatibility

This library is designed to feel native to scikit-learn, so you can use scikit-learn's utilities directly to increase your productivity.

from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={'td_error': td_error_scorer})
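
As background for the scorer name used above, the following sketch shows the quantity a temporal-difference (TD) error measures: the gap between a Q-value and its one-step bootstrapped target. It is a toy tabular version under stated assumptions (a greedy next action, a hypothetical `squared_td_errors` helper, a dictionary as the Q-function) and is not how `td_error_scorer` is implemented.

```python
def squared_td_errors(q_table, transitions, gamma=0.99):
    """Return the squared one-step TD error for each transition."""
    errors = []
    for state, action, reward, next_state, terminal in transitions:
        # no bootstrapping from the next state once the episode has ended
        next_value = 0.0 if terminal else max(q_table[next_state])
        target = reward + gamma * next_value
        errors.append((q_table[state][action] - target) ** 2)
    return errors


# toy Q-table: state -> list of Q-values, one per discrete action
toy_q_table = {"s0": [0.5, 1.0], "s1": [2.0, 0.0]}
# (state, action, reward, next_state, terminal)
toy_transitions = [
    ("s0", 1, 0.0, "s1", False),
    ("s1", 0, 1.0, "s1", True),
]

td_errors = squared_td_errors(toy_q_table, toy_transitions, gamma=0.5)
mean_td_error = sum(td_errors) / len(td_errors)
print(td_errors, mean_td_error)
```

A lower mean squared TD error on held-out episodes suggests the learned Q-function is more self-consistent on data it was not trained on.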

You can naturally perform cross-validation.

from sklearn.model_selection import cross_validate

scores = cross_validate(cql, dataset, scoring={'td_error': td_error_scorer})
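
To illustrate what cross-validation does with the dataset, here is a small sketch of contiguous, unshuffled k-fold splitting over episode indices. The function `k_fold_splits` is a hypothetical stand-in written for this example; scikit-learn performs the actual splitting inside `cross_validate`.

```python
def k_fold_splits(n_items, n_splits=5):
    """Yield (train_indices, test_indices) pairs for contiguous, unshuffled folds."""
    if not 2 <= n_splits <= n_items:
        raise ValueError("n_splits must be between 2 and n_items")
    indices = list(range(n_items))
    base, extra = divmod(n_items, n_splits)
    start = 0
    for fold in range(n_splits):
        # the first `extra` folds receive one additional item
        size = base + (1 if fold < extra else 0)
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, test
        start += size


# seven episodes split into three folds
folds = list(k_fold_splits(7, n_splits=3))
for train_indices, test_indices in folds:
    print("train:", train_indices, "test:", test_indices)
```

Each episode lands in the test set exactly once, so every fold is scored on episodes the model did not see during that fold's training.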

And more.

from sklearn.model_selection import GridSearchCV

gscv = GridSearchCV(estimator=cql,
                    param_grid={'actor_learning_rate': [3e-3, 3e-4, 3e-5]},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)
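
A grid search trains one model per combination of candidate values. The sketch below expands a parameter grid into those combinations using only the standard library. The helper `expand_grid` is hypothetical, and the second parameter is added here only to show how the number of candidates multiplies.

```python
from itertools import product


def expand_grid(param_grid):
    """Return every combination of candidate values as a list of dicts."""
    names = sorted(param_grid)
    value_lists = [param_grid[name] for name in names]
    return [dict(zip(names, values)) for values in product(*value_lists)]


toy_grid = {
    "actor_learning_rate": [3e-3, 3e-4, 3e-5],
    "n_critics": [1, 2],  # illustrative second axis, not in the example above
}
candidates = expand_grid(toy_grid)
print(len(candidates))
for candidate in candidates:
    print(candidate)
```

Because each candidate is also trained once per cross-validation fold, the total training cost grows quickly with grid size.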

examples

Atari 2600

from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_epochs=100,
                  n_frames=4,
                  n_critics=3,
                  bootstrap=True,
                  q_func_type='qr',
                  scaler='pixel',
                  use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

[Images: performance and demo for breakout]

See more Atari datasets at d4rl-atari.
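
The Atari example passes `n_frames=4` and `scaler='pixel'`. Assuming these refer to stacking the most recent frames and scaling 8-bit pixel values into the range [0, 1], the idea can be sketched as follows. The `FrameStacker` class is hypothetical, uses flat lists as tiny stand-in frames, and pads the first stack by repeating the first frame, which is an assumption rather than the library's documented behavior.

```python
from collections import deque


class FrameStacker:
    """Keep the most recent n_frames frames, with pixels scaled to [0, 1]."""

    def __init__(self, n_frames):
        self.frames = deque(maxlen=n_frames)

    def push(self, frame):
        scaled = [pixel / 255.0 for pixel in frame]
        if not self.frames:
            # assumption: fill the stack with copies of the first frame
            self.frames.extend([scaled] * self.frames.maxlen)
        else:
            # the deque drops the oldest frame automatically
            self.frames.append(scaled)
        return list(self.frames)


stacker = FrameStacker(n_frames=4)
first_stack = stacker.push([0, 255])
for value in (51, 102, 153, 204):
    latest_stack = stacker.push([value, value])
print(len(latest_stack), latest_stack[0], latest_stack[-1])
```

Stacking several frames gives the network access to motion information that a single still frame cannot carry.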

PyBullet

from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(n_epochs=300,
          actor_learning_rate=1e-3,
          critic_learning_rate=1e-3,
          temp_learning_rate=1e-3,
          alpha_learning_rate=1e-3,
          n_critics=10,
          bootstrap=True,
          update_actor_interval=2,
          q_func_type='qr',
          use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

[Images: performance and demo for hopper]

See more PyBullet datasets at d4rl-pybullet.
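
The PyBullet example sets `update_actor_interval=2`, which corresponds to the "delayed policy updates" feature listed earlier. As a rough sketch, assuming the parameter means the actor is updated once every that many critic updates, the schedule looks like this. The function `update_schedule` is hypothetical and only records which networks would be updated at each step.

```python
def update_schedule(n_steps, update_actor_interval):
    """List which networks are updated at each gradient step."""
    schedule = []
    for step in range(1, n_steps + 1):
        updates = ["critic"]  # the critic is updated at every step
        if step % update_actor_interval == 0:
            updates.append("actor")  # the actor is updated less often
        schedule.append(updates)
    return schedule


toy_schedule = update_schedule(n_steps=4, update_actor_interval=2)
for step, updates in enumerate(toy_schedule, start=1):
    print(step, updates)
```

Updating the actor less often lets the value estimates settle between policy changes, which is the motivation given for this technique in TD3.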

Online Training

import gym
import pybullet_envs  # assumed necessary: registers HopperBulletEnv-v0 with gym

from d3rlpy.algos import SAC
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.iterators import train

# setup environment
env = gym.make('HopperBulletEnv-v0')
eval_env = gym.make('HopperBulletEnv-v0')

# setup algorithm
sac = SAC(n_epochs=100, use_gpu=True)

# setup replay buffer
buffer = ReplayBuffer(maxlen=1000000, env=env)

# start training
train(env, sac, buffer, eval_env=eval_env)
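
For readers new to online training, the sketch below shows the basic role of a replay buffer with a maximum length: it stores recent transitions, discards the oldest ones when full, and serves random mini-batches. `SimpleReplayBuffer` is a hypothetical toy class built on the standard library and does not mirror the interface of `d3rlpy.online.buffers.ReplayBuffer`.

```python
import random
from collections import deque


class SimpleReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling."""

    def __init__(self, maxlen):
        self.transitions = deque(maxlen=maxlen)

    def append(self, observation, action, reward, terminal):
        # once the buffer is full, the oldest transition is dropped
        self.transitions.append((observation, action, reward, terminal))

    def sample(self, batch_size):
        # uniform sampling without replacement
        return random.sample(list(self.transitions), batch_size)

    def __len__(self):
        return len(self.transitions)


toy_buffer = SimpleReplayBuffer(maxlen=3)
for step in range(5):
    toy_buffer.append(observation=step, action=0, reward=1.0, terminal=False)

toy_batch = toy_buffer.sample(batch_size=2)
print(len(toy_buffer), toy_buffer.transitions[0][0], len(toy_batch))
```

A large capacity, such as the one million transitions used above, keeps older experience available so that updates are not dominated by the most recent episodes.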

tutorials

Try a cartpole example on Google Colaboratory!

contributions

coding style

This library is fully formatted with yapf. You can format all scripts as follows:

$ ./scripts/format

test

Unit tests are provided wherever possible. This repository uses pytest-cov instead of plain pytest. You can run the entire test suite as follows:

$ ./scripts/test

If you pass the -p option, the performance tests with toy tasks are also run (this takes several minutes).

$ ./scripts/test -p

acknowledgement

This work is supported by the Information-technology Promotion Agency, Japan (IPA), under the Exploratory IT Human Resources Project (MITOU Program) in fiscal year 2020.
