
d3rlpy: A data-driven deep reinforcement learning library as an out-of-the-box tool


d3rlpy is a data-driven deep reinforcement learning library as an out-of-the-box tool.

from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import CQL

# MDPDataset takes arrays of state transitions
# (NumPy arrays of logged observations, actions, rewards and terminal flags)
dataset = MDPDataset(observations, actions, rewards, terminals)

# train data-driven deep RL
cql = CQL()
cql.fit(dataset.episodes)

# ready to control
actions = cql.predict(x)

Documentation: https://d3rlpy.readthedocs.io

key features

:zap: Designed for Data-Driven Deep Reinforcement Learning

d3rlpy is designed for data-driven deep reinforcement learning algorithms, which learn a good policy from a given fixed dataset. This makes it suitable for tasks where online interaction is not feasible. d3rlpy also supports the conventional online training paradigm to cover a wide range of use cases.
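Both paradigms are exposed through the same algorithm object. A minimal sketch (assuming `dataset` is an `MDPDataset` of logged experiences and `env` is a Gym environment; the calls mirror the examples later on this page):

```python
from d3rlpy.algos import SAC
from d3rlpy.online.buffers import ReplayBuffer

sac = SAC()

# data-driven (offline): learn purely from a logged dataset
sac.fit(dataset.episodes)

# online: the same algorithm object can also learn by interacting
# with an environment through a replay buffer
buffer = ReplayBuffer(maxlen=100000, env=env)
sac.fit_online(env, buffer, n_steps=100000)
```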

:beginner: Easy-To-Use API

d3rlpy provides state-of-the-art algorithms through scikit-learn style APIs without compromising flexibility: detailed configuration options remain available for professional users. Moreover, d3rlpy is not just designed like scikit-learn, but is also fully compatible with scikit-learn utilities.

:rocket: Beyond State-Of-The-Art

d3rlpy provides further tweaks to improve the performance of state-of-the-art algorithms, potentially beyond their original papers. Therefore, d3rlpy enables every user to achieve professional-level performance in just a few lines of code.

installation

d3rlpy supports Linux, macOS and Windows.

PyPI


$ pip install d3rlpy

Anaconda


$ conda install -c conda-forge d3rlpy

supported algorithms

algorithm                                              | discrete control   | continuous control | data-driven RL?
Behavior Cloning (supervised learning)                 | :white_check_mark: | :white_check_mark: |
Deep Q-Network (DQN)                                   | :white_check_mark: | :no_entry:         |
Double DQN                                             | :white_check_mark: | :no_entry:         |
Deep Deterministic Policy Gradients (DDPG)             | :no_entry:         | :white_check_mark: |
Twin Delayed Deep Deterministic Policy Gradients (TD3) | :no_entry:         | :white_check_mark: |
Soft Actor-Critic (SAC)                                | :white_check_mark: | :white_check_mark: |
Batch Constrained Q-learning (BCQ)                     | :white_check_mark: | :white_check_mark: | :white_check_mark:
Bootstrapping Error Accumulation Reduction (BEAR)      | :no_entry:         | :white_check_mark: | :white_check_mark:
Advantage-Weighted Regression (AWR)                    | :white_check_mark: | :white_check_mark: | :white_check_mark:
Conservative Q-Learning (CQL) (recommended)            | :white_check_mark: | :white_check_mark: | :white_check_mark:
Advantage Weighted Actor-Critic (AWAC)                 | :no_entry:         | :white_check_mark: | :white_check_mark:
Policy in Latent Action Space (PLAS)                   | :no_entry:         | :white_check_mark: | :white_check_mark:

supported Q functions
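The table for this section did not survive in this rendering. In the API, the Q function is selected per algorithm via the `q_func_factory` argument, as the Atari and PyBullet examples below do with `'qr'` (quantile regression). A hedged sketch (treating `'mean'` as the standard default, which is an assumption here):

```python
from d3rlpy.algos import DQN

# standard (mean) Q function -- assumed to be the default
dqn = DQN(q_func_factory='mean')

# distributional Q function via quantile regression, as used in the
# examples below
dqn_qr = DQN(q_func_factory='qr')
```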

other features

Basically, all features are available with every algorithm.

  • evaluation metrics in a scikit-learn scorer function style
  • export greedy-policy as TorchScript or ONNX
  • ensemble Q function
  • N-step TD calculation
  • parallel cross validation with multiple GPU
  • online training
  • data augmentation
  • model-based algorithm
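For example, the greedy-policy export listed above can be sketched as follows (assuming a trained algorithm object; `save_policy` and its `as_onnx` flag reflect the 0.x API):

```python
from d3rlpy.algos import CQL

cql = CQL()
cql.fit(dataset.episodes)  # assumes a pre-built MDPDataset

# export the greedy policy as TorchScript
cql.save_policy('policy.pt')

# export the greedy policy as ONNX
cql.save_policy('policy.onnx', as_onnx=True)
```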

examples

Atari 2600

from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_frames=4, q_func_factory='qr', scaler='pixel', use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=100,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

See more Atari datasets at d4rl-atari.

PyBullet

from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(q_func_factory='qr', use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=300,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

See more PyBullet datasets at d4rl-pybullet.

Online Training

import gym

from d3rlpy.algos import SAC
from d3rlpy.online.buffers import ReplayBuffer

# setup environment
env = gym.make('HopperBulletEnv-v0')
eval_env = gym.make('HopperBulletEnv-v0')

# setup algorithm
sac = SAC(use_gpu=True)

# setup replay buffer
buffer = ReplayBuffer(maxlen=1000000, env=env)

# start training
sac.fit_online(env, buffer, n_steps=1000000, eval_env=eval_env)

tutorials

Try a cartpole example on Google Colaboratory!


scikit-learn compatibility

This library is designed as if born from scikit-learn. You can fully utilize scikit-learn's utilities to increase your productivity.

from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={'td_error': td_error_scorer})

You can naturally perform cross-validation.

from sklearn.model_selection import cross_validate

scores = cross_validate(cql, dataset, scoring={'td_error': td_error_scorer})

And more.

from sklearn.model_selection import GridSearchCV

gscv = GridSearchCV(estimator=cql,
                    param_grid={'actor_learning_rate': [3e-3, 3e-4, 3e-5]},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)

MDPDataset

d3rlpy introduces MDPDataset, a convenient data structure for reinforcement learning. MDPDataset splits sequential data into transitions, each of which includes the tuple of data observed at time t and t+1 that is usually used for training.

import numpy as np

from d3rlpy.dataset import MDPDataset

# offline data
observations = np.random.random((100000, 100)) # 100-dim feature observations
actions = np.random.random((100000, 4)) # 4-dim continuous actions
rewards = np.random.random(100000)
terminals = np.random.randint(2, size=100000)

# builds MDPDataset from offline data
dataset = MDPDataset(observations, actions, rewards, terminals)

# splits offline data into episodes
dataset.episodes[0].observations

# splits episodes into transitions
dataset.episodes[0].transitions[0].observation
dataset.episodes[0].transitions[0].action
dataset.episodes[0].transitions[0].next_reward
dataset.episodes[0].transitions[0].next_observation
dataset.episodes[0].transitions[0].terminal

TransitionMiniBatch is a convenient class for building a mini-batch from sampled transitions. The memory copies done in TransitionMiniBatch are implemented in Cython, which makes batching extremely fast.

from random import sample
from d3rlpy.dataset import TransitionMiniBatch

transitions = sample(dataset.episodes[0].transitions, 100)

# fast batching up with efficient-memory copy
batch = TransitionMiniBatch(transitions)

batch.observations.shape == (100, 100)

One more interesting feature of the dataset structure is that each transition holds pointers to its next and previous transitions. This enables just-in-time frame stacking, as several works do with Atari tasks, and it is also implemented in Cython to reduce bottlenecks.

observations = np.random.randint(256, size=(100000, 1, 84, 84), dtype=np.uint8) # 1x84x84 pixel images
actions = np.random.randint(4, size=100000) # discrete actions with 4 options
rewards = np.random.random(100000)
terminals = np.random.randint(2, size=100000)

# builds MDPDataset from offline data
dataset = MDPDataset(observations, actions, rewards, terminals, discrete_action=True)

# samples transitions
transitions = sample(dataset.episodes[0].transitions, 32)

# makes mini-batch with frame stacking
batch = TransitionMiniBatch(transitions, n_frames=4)

batch.observations.shape == (32, 4, 84, 84)

Finally, TransitionMiniBatch also supports N-step TD backup, which is also efficiently done with Cython.

batch = TransitionMiniBatch(transitions, n_steps=1, gamma=0.99)

# the number of steps before next observations at each batch index
batch.n_steps.shape == (32, 1)

# N step after batch.observations
batch.next_observations

# N step after batch.actions
batch.next_actions

# N-step return
batch.next_rewards
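Conceptually, the N-step values above are discounted sums of the following rewards, with `next_observations` taken N steps ahead. A plain-Python illustration of what an N-step return computes (a hypothetical helper for exposition, not a d3rlpy API):

```python
def n_step_return(rewards, gamma=0.99, n_steps=3):
    # discounted sum of up to n_steps rewards:
    # r[0] + gamma * r[1] + ... + gamma**(n_steps - 1) * r[n_steps - 1]
    return sum((gamma ** i) * r for i, r in enumerate(rewards[:n_steps]))

# three rewards of 1.0 with gamma=0.99:
# 1.0 + 0.99 + 0.9801 = 2.9701
result = n_step_return([1.0, 1.0, 1.0], gamma=0.99, n_steps=3)
```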

contributions

coding style

This library is fully formatted with black and yapf. You can format the entire codebase as follows:

$ ./scripts/format

linter

This library is analyzed by mypy and pylint. You can run the static checks as follows:

$ ./scripts/lint

test

Unit tests are provided for as much of the code as possible. This repository uses pytest-cov instead of plain pytest. You can run the entire test suite as follows:

$ ./scripts/test

If you pass the -p option, performance tests with toy tasks are also run (this takes several minutes).

$ ./scripts/test -p

citation

@misc{seno2020d3rlpy,
  author = {Takuma Seno},
  title = {d3rlpy: A data-driven deep reinforcement learning library as an out-of-the-box tool},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/takuseno/d3rlpy}}
}

acknowledgement

This work is supported by Information-technology Promotion Agency, Japan (IPA), Exploratory IT Human Resources Project (MITOU Program) in the fiscal year 2020.



Download files

Download the file for your platform.

Source Distribution

  • d3rlpy-0.50.tar.gz (297.6 kB)

Built Distributions

  • d3rlpy-0.50-cp38-cp38-win_amd64.whl (298.2 kB, CPython 3.8, Windows x86-64)
  • d3rlpy-0.50-cp38-cp38-manylinux1_x86_64.whl (972.6 kB, CPython 3.8, Linux x86-64)
  • d3rlpy-0.50-cp38-cp38-macosx_10_14_x86_64.whl (336.9 kB, CPython 3.8, macOS 10.14+ x86-64)
  • d3rlpy-0.50-cp37-cp37m-win_amd64.whl (292.6 kB, CPython 3.7m, Windows x86-64)
  • d3rlpy-0.50-cp37-cp37m-manylinux1_x86_64.whl (892.2 kB, CPython 3.7m, Linux x86-64)
  • d3rlpy-0.50-cp37-cp37m-macosx_10_14_x86_64.whl (335.5 kB, CPython 3.7m, macOS 10.14+ x86-64)
  • d3rlpy-0.50-cp36-cp36m-win_amd64.whl (292.5 kB, CPython 3.6m, Windows x86-64)
  • d3rlpy-0.50-cp36-cp36m-manylinux1_x86_64.whl (893.6 kB, CPython 3.6m, Linux x86-64)
  • d3rlpy-0.50-cp36-cp36m-macosx_10_14_x86_64.whl (342.5 kB, CPython 3.6m, macOS 10.14+ x86-64)

