
Data-driven Deep Reinforcement Learning Library as an Out-of-the-box Tool


d3rlpy: A data-driven deep reinforcement learning library as an out-of-the-box tool


d3rlpy is a data-driven deep reinforcement learning library as an out-of-the-box tool.

from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import CQL

# MDPDataset takes arrays of state transitions
dataset = MDPDataset(observations, actions, rewards, terminals)

# train data-driven deep RL
cql = CQL()
cql.fit(dataset.episodes)

# ready to control (x is a batch of observations)
actions = cql.predict(x)

Documentation: https://d3rlpy.readthedocs.io

key features

:zap: Designed for Data-Driven Deep Reinforcement Learning

d3rlpy is designed for data-driven deep reinforcement learning algorithms, where the algorithm learns a good policy from a given dataset. This makes it suitable for tasks where online interaction is not feasible. d3rlpy also supports the conventional online training paradigm to cover any use case.

:beginner: Easy-To-Use API

d3rlpy provides state-of-the-art algorithms through scikit-learn style APIs without compromising flexibility, offering detailed configurations for professional users. Moreover, d3rlpy is not just designed like scikit-learn, but is also fully compatible with scikit-learn utilities.

:rocket: Beyond State-Of-The-Art

d3rlpy provides further tweaks to improve the performance of state-of-the-art algorithms, potentially beyond their original papers. Therefore, d3rlpy enables every user to achieve professional-level performance in just a few lines of code.

installation

d3rlpy supports Linux, macOS and Windows.

PyPI


$ pip install d3rlpy

Anaconda


$ conda install -c conda-forge d3rlpy

supported algorithms

algorithm | discrete control | continuous control | data-driven RL?
Behavior Cloning (supervised learning) | :white_check_mark: | :white_check_mark: |
Deep Q-Network (DQN) | :white_check_mark: | :no_entry: |
Double DQN | :white_check_mark: | :no_entry: |
Deep Deterministic Policy Gradients (DDPG) | :no_entry: | :white_check_mark: |
Twin Delayed Deep Deterministic Policy Gradients (TD3) | :no_entry: | :white_check_mark: |
Soft Actor-Critic (SAC) | :white_check_mark: | :white_check_mark: |
Batch Constrained Q-learning (BCQ) | :white_check_mark: | :white_check_mark: | :white_check_mark:
Bootstrapping Error Accumulation Reduction (BEAR) | :no_entry: | :white_check_mark: | :white_check_mark:
Advantage-Weighted Regression (AWR) | :white_check_mark: | :white_check_mark: | :white_check_mark:
Conservative Q-Learning (CQL) (recommended) | :white_check_mark: | :white_check_mark: | :white_check_mark:
Advantage Weighted Actor-Critic (AWAC) | :no_entry: | :white_check_mark: | :white_check_mark:
Policy in Latent Action Space (PLAS) | :no_entry: | :white_check_mark: | :white_check_mark:

supported Q functions

Multiple Q-function representations are supported and selected via the q_func_factory option (for example, 'qr' for Quantile Regression, as used in the examples below).

other features

All of the features below are available with every algorithm.

  • evaluation metrics in a scikit-learn scorer function style
  • export greedy-policy as TorchScript or ONNX
  • ensemble Q function
  • N-step TD calculation
  • parallel cross validation with multiple GPUs
  • online training
  • data augmentation
  • model-based algorithm

examples

Atari 2600

from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_frames=4, q_func_factory='qr', scaler='pixel', use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=100,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

See more Atari datasets at d4rl-atari.

PyBullet

from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(q_func_factory='qr', use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=300,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

See more PyBullet datasets at d4rl-pybullet.

Online Training

import gym

from d3rlpy.algos import SAC
from d3rlpy.online.buffers import ReplayBuffer

# setup environment
env = gym.make('HopperBulletEnv-v0')
eval_env = gym.make('HopperBulletEnv-v0')

# setup algorithm
sac = SAC(use_gpu=True)

# setup replay buffer
buffer = ReplayBuffer(maxlen=1000000, env=env)

# start training
sac.fit_online(env, buffer, n_steps=1000000, eval_env=eval_env)

tutorials

Try a CartPole example on Google Colaboratory!

Open In Colab

scikit-learn compatibility

This library is designed to feel as if it were born from scikit-learn, so you can fully utilize scikit-learn's utilities to increase your productivity.

from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={'td_error': td_error_scorer})

You can naturally perform cross-validation.

from sklearn.model_selection import cross_validate

scores = cross_validate(cql, dataset, scoring={'td_error': td_error_scorer})

And more.

from sklearn.model_selection import GridSearchCV

gscv = GridSearchCV(estimator=cql,
                    param_grid={'actor_learning_rate': [3e-3, 3e-4, 3e-5]},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)

MDPDataset

d3rlpy introduces MDPDataset, a convenient data structure for reinforcement learning. MDPDataset splits sequential data into transitions, each of which includes the data observed at time t and t+1, the form usually used for training.

import numpy as np

from d3rlpy.dataset import MDPDataset

# offline data
observations = np.random.random((100000, 100)) # 100-dim feature observations
actions = np.random.random((100000, 4)) # 4-dim continuous actions
rewards = np.random.random(100000)
terminals = np.random.randint(2, size=100000)

# builds MDPDataset from offline data
dataset = MDPDataset(observations, actions, rewards, terminals)

# splits offline data into episodes
dataset.episodes[0].observations

# splits episodes into transitions
dataset.episodes[0].transitions[0].observation
dataset.episodes[0].transitions[0].action
dataset.episodes[0].transitions[0].next_reward
dataset.episodes[0].transitions[0].next_observation
dataset.episodes[0].transitions[0].terminal
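
For illustration, here is a minimal sketch in plain NumPy (not d3rlpy internals) of how terminal flags partition a flat log into episodes, which is conceptually what MDPDataset does when it builds dataset.episodes:

```python
import numpy as np

# a flat log of 10 steps; terminals mark the last step of each episode
rewards = np.arange(10.0)
terminals = np.zeros(10)
terminals[[3, 9]] = 1  # two episodes: steps 0-3 and 4-9

# episode boundaries sit one past each terminal index
bounds = np.flatnonzero(terminals) + 1
episodes = np.split(rewards, bounds[:-1])

len(episodes)       # 2
episodes[0].shape   # (4,)
episodes[1].shape   # (6,)
```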

TransitionMiniBatch is a convenient class for building a mini-batch from sampled transitions. The memory copies performed in TransitionMiniBatch are implemented in Cython, which makes them very fast.

from random import sample
from d3rlpy.dataset import TransitionMiniBatch

transitions = sample(dataset.episodes[0].transitions, 100)

# fast batching with memory-efficient copies
batch = TransitionMiniBatch(transitions)

batch.observations.shape == (100, 100)
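
Conceptually, the batching amounts to stacking per-transition arrays into contiguous batch arrays. A hypothetical plain-NumPy sketch (not the actual Cython implementation):

```python
import numpy as np

# stand-in transitions: dicts instead of d3rlpy Transition objects
transitions = [
    {"observation": np.random.random(100), "action": np.random.random(4)}
    for _ in range(100)
]

# stack per-field arrays into contiguous batch arrays
observations = np.stack([t["observation"] for t in transitions])
actions = np.stack([t["action"] for t in transitions])

observations.shape  # (100, 100)
actions.shape       # (100, 4)
```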

One more interesting feature of the dataset structure is that each transition holds pointers to its next and previous transitions. This enables just-in-time frame stacking, as several works do for Atari tasks, which is also implemented in Cython to reduce bottlenecks.

observations = np.random.randint(256, size=(100000, 1, 84, 84), dtype=np.uint8) # 1x84x84 pixel images
actions = np.random.randint(4, size=100000) # discrete actions with 4 options
rewards = np.random.random(100000)
terminals = np.random.randint(2, size=100000)

# builds MDPDataset from offline data
dataset = MDPDataset(observations, actions, rewards, terminals, discrete_action=True)

# samples transitions
transitions = sample(dataset.episodes[0].transitions, 32)

# makes mini-batch with frame stacking
batch = TransitionMiniBatch(transitions, n_frames=4)

batch.observations.shape == (32, 4, 84, 84)
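
To make the pointer mechanism concrete, here is a conceptual sketch (hypothetical classes, not the library's Cython code): walk the previous-transition pointers to gather the last n_frames observations, zero-padding at the episode start.

```python
import numpy as np

class Transition:
    """Minimal stand-in for a linked transition."""
    def __init__(self, observation, prev=None):
        self.observation = observation        # shape (1, 84, 84)
        self.prev_transition = prev           # link to the previous step

def stack_frames(transition, n_frames):
    # collect current and previous observations, newest first
    frames, node = [], transition
    for _ in range(n_frames):
        if node is None:
            # before episode start: pad with zeros
            frames.append(np.zeros_like(transition.observation))
        else:
            frames.append(node.observation)
            node = node.prev_transition
    # reverse to oldest-first and stack along the channel axis
    return np.concatenate(frames[::-1], axis=0)

t0 = Transition(np.ones((1, 84, 84)))
t1 = Transition(np.full((1, 84, 84), 2.0), prev=t0)
stack_frames(t1, n_frames=4).shape  # (4, 84, 84)
```

The first two channels are zero padding because t1 only has one predecessor.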

Finally, TransitionMiniBatch also supports N-step TD backup, which is also efficiently done with Cython.

batch = TransitionMiniBatch(transitions, n_steps=1, gamma=0.99)

# the number of steps before next observations at each batch index
batch.n_steps.shape == (32, 1)

# N step after batch.observations
batch.next_observations

# N step after batch.actions
batch.next_actions

# N-step return
batch.next_rewards
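
As a worked sketch of the arithmetic behind an N-step return (assuming the usual definition R = sum over k of gamma^k * r_{t+k} for k = 0..N-1, with next_observations taken N steps ahead):

```python
# N-step return with gamma = 0.99 over three unit rewards
gamma, n_steps = 0.99, 3
rewards = [1.0, 1.0, 1.0]

n_step_return = sum(gamma**k * rewards[k] for k in range(n_steps))
# 1 + 0.99 + 0.9801 = 2.9701
```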

contributions

coding style

This library is fully formatted with black and yapf. You can format the entire codebase as follows:

$ ./scripts/format

linter

This library is analyzed with mypy and pylint. You can run the static checks as follows:

$ ./scripts/lint

test

Unit tests are provided wherever possible. This repository uses pytest-cov on top of pytest. You can run the entire test suite as follows:

$ ./scripts/test

If you pass the -p option, performance tests on toy tasks are also run (this takes several minutes).

$ ./scripts/test -p

citation

@misc{seno2020d3rlpy,
  author = {Takuma Seno},
  title = {d3rlpy: A data-driven deep reinforcement learning library as an out-of-the-box tool},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/takuseno/d3rlpy}}
}

acknowledgement

This work is supported by Information-technology Promotion Agency, Japan (IPA), Exploratory IT Human Resources Project (MITOU Program) in the fiscal year 2020.
