Data-driven Deep Reinforcement Learning Library as an Out-of-the-box Tool

License: MIT

d3rlpy

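import numpy as np
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import BEAR

# placeholder logged data: 1000 steps of 4-dim observations and 2-dim actions,
# with an episode boundary every 100 steps
observations = np.random.random((1000, 4)).astype('f4')
actions = np.random.random((1000, 2)).astype('f4')
rewards = np.random.random(1000).astype('f4')
terminals = np.zeros(1000)
terminals[99::100] = 1.0

# MDPDataset takes arrays of state transitions
dataset = MDPDataset(observations, actions, rewards, terminals)

# train data-driven deep RL
bear = BEAR()
bear.fit(dataset.episodes)

# ready to control: predict actions for a batch of observations
x = observations[:10]
actions = bear.predict(x)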
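MDPDataset receives flat, time-ordered arrays and exposes them as `dataset.episodes`. A minimal sketch of how such arrays can be cut into episodes at each terminal flag, using NumPy; `split_into_episodes` and the `Episode` tuple are hypothetical names for illustration, not d3rlpy's internals.

```python
from typing import List, NamedTuple

import numpy as np


class Episode(NamedTuple):
    # hypothetical container for one episode's transitions
    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray


def split_into_episodes(observations, actions, rewards, terminals) -> List[Episode]:
    """Cut time-ordered arrays into episodes, ending one at every terminal flag."""
    episodes = []
    start = 0
    for t, done in enumerate(terminals):
        if done:
            end = t + 1  # include the terminal step itself
            episodes.append(Episode(observations[start:end],
                                    actions[start:end],
                                    rewards[start:end]))
            start = end
    # trailing steps without a terminal flag are dropped in this sketch
    return episodes


# usage: 10 steps with terminals at steps 3 and 9 -> episodes of length 4 and 6
obs = np.arange(10, dtype=np.float32).reshape(10, 1)
acts = np.zeros((10, 1), dtype=np.float32)
rews = np.ones(10, dtype=np.float32)
terms = np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 1])
episodes = split_into_episodes(obs, acts, rews, terms)
print([len(e.rewards) for e in episodes])  # [4, 6]
```

Whether d3rlpy keeps or drops an unterminated tail is not stated here; the sketch simply drops it.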

d3rlpy provides state-of-the-art data-driven deep reinforcement learning algorithms through out-of-the-box scikit-learn-style APIs. Unlike many other RL libraries, d3rlpy adds several tweaks to the provided algorithms that can push their performance beyond the results reported in the original papers.

These are the design principles of d3rlpy:

  • d3rlpy is designed for practical projects, unlike many other RL libraries.
  • d3rlpy does not focus on reproducing RL papers.
  • d3rlpy adds techniques beyond those in the original implementations.

Documentation: https://d3rlpy.readthedocs.io

installation

$ pip install d3rlpy

scikit-learn compatibility

This library is designed as if it were part of scikit-learn, so you can use scikit-learn's utilities directly to increase your productivity.

from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

bear.fit(train_episodes,
         eval_episodes=test_episodes,
         scorers={'td_error': td_error_scorer})
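The example passes the dataset straight to `train_test_split`, which suggests the split is made over whole episodes rather than individual transitions, so no episode is shared between training and evaluation. A minimal pure-Python sketch of that idea; `split_episodes` is a hypothetical helper, and it assumes only that the dataset behaves like a sequence of episodes.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_episodes(episodes: Sequence[T], test_size: float = 0.25,
                   seed: int = 0) -> Tuple[List[T], List[T]]:
    """Shuffle whole episodes and hold out a fraction of them for evaluation."""
    indices = list(range(len(episodes)))
    random.Random(seed).shuffle(indices)
    n_test = max(1, round(len(episodes) * test_size))
    test = [episodes[i] for i in indices[:n_test]]
    train = [episodes[i] for i in indices[n_test:]]
    return train, test


# usage: 10 placeholder "episodes" -> 8 for training, 2 held out
episodes = [f"episode-{i}" for i in range(10)]
train, test = split_episodes(episodes, test_size=0.2)
print(len(train), len(test))  # 8 2
```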

You can naturally perform cross-validation.

from sklearn.model_selection import cross_validate

scores = cross_validate(bear, dataset, scoring={'td_error': td_error_scorer})
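`cross_validate` repeatedly refits the estimator on all folds but one and scores it on the held-out fold. A minimal sketch of episode-level folds of that kind; `k_fold_episodes` is a hypothetical helper (not scikit-learn's `KFold` itself) and does not shuffle.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def k_fold_episodes(episodes: Sequence[T], n_splits: int = 5
                    ) -> Iterator[Tuple[List[T], List[T]]]:
    """Yield (train, validation) episode lists; each episode is validated exactly once."""
    n = len(episodes)
    # spread the remainder over the first folds so fold sizes differ by at most one
    fold_sizes = [n // n_splits + (1 if i < n % n_splits else 0)
                  for i in range(n_splits)]
    start = 0
    for size in fold_sizes:
        stop = start + size
        validation = list(episodes[start:stop])
        train = list(episodes[:start]) + list(episodes[stop:])
        yield train, validation
        start = stop


# usage: 7 episodes, 3 folds -> validation fold sizes 3, 2, 2
folds = list(k_fold_episodes(list(range(7)), n_splits=3))
print([len(val) for _, val in folds])  # [3, 2, 2]
```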

Hyperparameter search with GridSearchCV works as well.

import numpy as np
from sklearn.model_selection import GridSearchCV

# search over 9 actor learning rates: 1e-3, 2e-3, ..., 9e-3
gscv = GridSearchCV(estimator=bear,
                    param_grid={'actor_learning_rate': np.arange(1, 10) * 1e-3},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)
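With this `param_grid`, `GridSearchCV` evaluates nine candidate actor learning rates. A minimal sketch of how a grid of this shape expands into candidate configurations via a Cartesian product; `expand_grid` is a hypothetical stand-in for scikit-learn's `ParameterGrid`, and the second key (`critic_learning_rate`) is added only to show how the number of candidates grows.

```python
import itertools
from typing import Any, Dict, List, Sequence


def expand_grid(param_grid: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Return one dict per combination of parameter values (Cartesian product)."""
    keys = sorted(param_grid)  # fixed key order for reproducible output
    return [dict(zip(keys, values))
            for values in itertools.product(*(param_grid[k] for k in keys))]


# the grid from the example above: 1e-3, 2e-3, ..., 9e-3
actor_lrs = [i * 1e-3 for i in range(1, 10)]
candidates = expand_grid({"actor_learning_rate": actor_lrs})
print(len(candidates))  # 9

# adding a second hyperparameter multiplies the number of candidates
both = expand_grid({"actor_learning_rate": actor_lrs,
                    "critic_learning_rate": [1e-4, 3e-4]})
print(len(both))  # 18
```

Since every candidate is refit once per cross-validation fold, the total cost scales with (number of candidates) × (number of folds), which matters when each offline RL training run is long.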

supported algorithms

| algorithm | discrete control | continuous control | data-driven RL? |
|---|---|---|---|
| Behavior Cloning (supervised learning) | ✅ | ✅ | |
| Deep Q-Network (DQN) | ✅ | ⛔ | |
| Double DQN | ✅ | ⛔ | |
| Deep Deterministic Policy Gradients (DDPG) | ⛔ | ✅ | |
| Twin Delayed Deep Deterministic Policy Gradients (TD3) | ⛔ | ✅ | |
| Soft Actor-Critic (SAC) | ⛔ | ✅ | |
| Random Ensemble Mixture (REM) | 🚧 | ⛔ | ✅ |
| Batch Constrained Q-learning (BCQ) | ✅ | ✅ | ✅ |
| Bootstrapping Error Accumulation Reduction (BEAR) | ⛔ | ✅ | ✅ |
| Advantage-Weighted Regression (AWR) | 🚧 | 🚧 | ✅ |
| Advantage-weighted Behavior Model (ABM) | 🚧 | 🚧 | ✅ |
| Conservative Q-Learning (CQL) (recommended) | ✅ | ✅ | ✅ |
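For discrete actions, the CQL paper (Kumar et al., 2020) adds a conservative term to the usual TD loss, α · E_s[ log Σ_a exp Q(s, a) − Q(s, a_data) ], which pushes Q-values down on all actions while pushing them up on actions that appear in the dataset. A NumPy sketch of that penalty alone over a batch of Q-values; the function names and the `alpha` default are illustrative assumptions, and this is not d3rlpy's `DiscreteCQL` code.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along an axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def cql_penalty(q_values: np.ndarray, data_actions: np.ndarray,
                alpha: float = 1.0) -> float:
    """Conservative penalty for discrete actions.

    q_values: (batch, n_actions) array of Q(s, .) for the states in the batch
    data_actions: (batch,) integer actions taken in the dataset
    """
    pushed_down = logsumexp(q_values, axis=1)  # soft maximum over all actions
    pushed_up = q_values[np.arange(len(q_values)), data_actions]  # Q(s, a_data)
    return float(alpha * np.mean(pushed_down - pushed_up))


q = np.array([[1.0, 2.0, 3.0],
              [0.0, 0.0, 0.0]])
a = np.array([2, 0])
print(cql_penalty(q, a))
```

Because log Σ_a exp Q(s, a) is never smaller than any single Q(s, a), the penalty is non-negative and shrinks as the dataset action dominates the others.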

supported Q functions

other features

Basically, all features are available with every algorithm.

examples

Atari 2600

from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_epochs=100,
                  n_critics=3,
                  bootstrap=True,
                  q_func_type='qr',
                  scaler='pixel',
                  use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

[Figure: performance demo on Breakout]

See more Atari datasets at d4rl-atari.
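Both examples set `q_func_type='qr'`, which selects a quantile-regression Q-function in the style of QR-DQN (Dabney et al., 2018): instead of a single expected return, the network predicts N quantiles of the return and is trained with the quantile Huber loss. A NumPy sketch of that loss for one state-action pair, following the paper's definition; the function names, `kappa=1.0`, and the midpoint quantile levels are assumptions, and d3rlpy's implementation may differ in batching and reduction.

```python
import numpy as np


def huber(u: np.ndarray, kappa: float = 1.0) -> np.ndarray:
    """Huber loss L_kappa(u): quadratic near zero, linear in the tails."""
    abs_u = np.abs(u)
    return np.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))


def quantile_huber_loss(pred_quantiles: np.ndarray,
                        target_quantiles: np.ndarray,
                        kappa: float = 1.0) -> float:
    """QR-DQN loss: sum over predicted quantiles, mean over target samples."""
    n = len(pred_quantiles)
    taus = (2 * np.arange(n) + 1) / (2.0 * n)  # quantile midpoints (i + 0.5) / N
    # pairwise TD errors u[i, j] = target_j - pred_i
    u = target_quantiles[None, :] - pred_quantiles[:, None]
    # asymmetric weight |tau - 1{u < 0}| pulls each output toward its own quantile level
    weight = np.abs(taus[:, None] - (u < 0).astype(np.float64))
    rho = weight * huber(u, kappa) / kappa
    return float(rho.mean(axis=1).sum())


# matching predictions and targets give zero loss
print(quantile_huber_loss(np.zeros(4), np.zeros(4)))  # 0.0
```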

PyBullet

from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(n_epochs=300,
          actor_learning_rate=1e-3,
          critic_learning_rate=1e-3,
          temp_learning_rate=1e-3,
          alpha_learning_rate=1e-3,
          n_critics=10,
          bootstrap=True,
          update_actor_interval=2,
          q_func_type='qr',
          use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

[Figure: performance demo on Hopper]

See more PyBullet datasets at d4rl-pybullet.
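The PyBullet example sets `update_actor_interval=2`, which presumably works like TD3's delayed policy updates: the critics are updated every step while the actor is updated only every second step, letting the value estimates settle before the policy follows them. A minimal sketch of that schedule; `update_critic` and `update_actor` are hypothetical placeholders that only count calls, not d3rlpy methods.

```python
from typing import Dict


def train(n_steps: int, update_actor_interval: int = 2) -> Dict[str, int]:
    """Run a training schedule and count how often each component is updated."""
    counts = {"critic": 0, "actor": 0}

    def update_critic() -> None:
        # placeholder for one gradient step on the critic ensemble
        counts["critic"] += 1

    def update_actor() -> None:
        # placeholder for one gradient step on the policy
        counts["actor"] += 1

    for step in range(n_steps):
        update_critic()  # critics are updated on every step
        if step % update_actor_interval == 0:
            update_actor()  # actor is updated once per `update_actor_interval` steps
    return counts


print(train(10, update_actor_interval=2))  # {'critic': 10, 'actor': 5}
```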

tutorials

Try a cartpole example on Google Colaboratory!

contributions

coding style

The whole codebase is formatted with yapf. You can format all scripts as follows:

$ ./scripts/format

test

Unit tests are provided wherever possible. This repository runs them with pytest-cov (the coverage plugin for pytest) rather than plain pytest. You can run the full test suite as follows:

$ ./scripts/test

If you pass the -p option, performance tests on toy tasks are also run (this takes several minutes).

$ ./scripts/test -p

acknowledgement

This work is supported by the Information-technology Promotion Agency, Japan (IPA), Exploratory IT Human Resources Project (MITOU Program) in fiscal year 2020.

Download files

Source Distribution

d3rlpy-0.2.tar.gz (54.9 kB)

Uploaded Source

Built Distribution

d3rlpy-0.2-py3-none-any.whl (84.6 kB)

Uploaded Python 3

File details

Details for the file d3rlpy-0.2.tar.gz.

File metadata

  • Download URL: d3rlpy-0.2.tar.gz
  • Size: 54.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.6.1

File hashes

Hashes for d3rlpy-0.2.tar.gz
Algorithm Hash digest
SHA256 3258864d65af1b9640699d2fb97f5c4aa64fcf42dd2a5d1c4a91a85f18af2693
MD5 69a361aa8694f2218f8c3e31dd7501f7
BLAKE2b-256 83471ed95fb7790503ff2b756ab35bb92608bf9ad0e41cd4af15e309887dd711
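To check a downloaded archive against the SHA256 digest listed above, hash the file locally and compare. A sketch using Python's standard `hashlib`; the helper names are hypothetical, and the file name in the usage note assumes the archive was saved under its original name.

```python
import hashlib


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: str, expected: str) -> bool:
    """Compare a file's SHA256 digest with an expected hex string."""
    return sha256_of(path) == expected.lower()


# SHA256 listed above for the source distribution
EXPECTED = "3258864d65af1b9640699d2fb97f5c4aa64fcf42dd2a5d1c4a91a85f18af2693"
```

For example, `matches("d3rlpy-0.2.tar.gz", EXPECTED)` should return True for an intact download saved under that name.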

File details

Details for the file d3rlpy-0.2-py3-none-any.whl.

File metadata

  • Download URL: d3rlpy-0.2-py3-none-any.whl
  • Size: 84.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.47.0 CPython/3.6.1

File hashes

Hashes for d3rlpy-0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 b7d71c513058acaded2793532b9f07d34322e42c606075b19ac13a4c5793773c
MD5 fe699aa19d2036f285157552fcdbc8be
BLAKE2b-256 7ab948b797e35ab48a9b22d29e14cb502c3dcfcef948f85c167c4fdce7e0527d
