d3rlpy

Data-driven Deep Reinforcement Learning Library as an Out-of-the-box Tool.

from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import BEAR

# MDPDataset takes arrays of state transitions
dataset = MDPDataset(observations, actions, rewards, terminals)

# train data-driven deep RL
bear = BEAR()
bear.fit(dataset.episodes)

# ready to control
actions = bear.predict(x)
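
The snippet above passes flat arrays plus terminal flags and then trains on `dataset.episodes`. As a rough illustration of that grouping step, the sketch below splits time-ordered transitions into episodes wherever a terminal flag is set. It is not d3rlpy's implementation: `split_into_episodes` is a hypothetical helper, plain lists stand in for arrays, and the handling of trailing steps is an assumption made for this example.

```python
def split_into_episodes(observations, actions, rewards, terminals):
    """Group flat, time-ordered transitions into episodes.

    Assumption for this sketch: a truthy terminal flag marks the last
    step of an episode.
    """
    episodes = []
    current = []
    for obs, act, rew, done in zip(observations, actions, rewards, terminals):
        current.append((obs, act, rew))
        if done:
            episodes.append(current)
            current = []
    if current:
        # Trailing steps without a terminal flag become a final, truncated episode.
        episodes.append(current)
    return episodes


# Toy data: five steps with 2-dimensional observations and discrete actions.
observations = [[0.0, 0.1], [0.2, 0.3], [0.4, 0.5], [0.6, 0.7], [0.8, 0.9]]
actions = [0, 1, 1, 0, 1]
rewards = [0.0, 0.0, 1.0, 0.0, 1.0]
terminals = [0, 0, 1, 0, 1]

episodes = split_into_episodes(observations, actions, rewards, terminals)
print([len(episode) for episode in episodes])
```

With this toy data the first three steps form one episode and the last two form another.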

d3rlpy provides state-of-the-art data-driven deep reinforcement learning algorithms through out-of-the-box scikit-learn-style APIs. Unlike other RL libraries, the provided algorithms include several tweaks that can push performance beyond what the original papers report.

These are the design principles of d3rlpy:

  • d3rlpy is designed for practical projects, unlike many other RL libraries.
  • d3rlpy does not focus on reproducing RL papers.
  • d3rlpy adds more techniques than the original implementations.

installation

$ pip install d3rlpy

scikit-learn compatibility

This library is designed to feel like a native part of scikit-learn. You can use scikit-learn's utilities directly to increase your productivity.

from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

bear.fit(train_episodes,
         eval_episodes=test_episodes,
         scorers={'td_error': td_error_scorer})
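
The `scorers` argument maps a name to a callable. The sketch below shows what a user-written scorer might look like, assuming the callable receives the algorithm and a list of episodes and returns a single float where larger is better (the scikit-learn convention). `StubEpisode`, `StubAlgo` and `action_gap_scorer` are hypothetical stand-ins invented for this example, not part of d3rlpy.

```python
class StubEpisode:
    """Hypothetical stand-in for an episode: parallel lists of observations and actions."""

    def __init__(self, observations, actions):
        self.observations = observations
        self.actions = actions


class StubAlgo:
    """Hypothetical stand-in for a trained algorithm that always predicts action 0.0."""

    def predict(self, observations):
        return [0.0 for _ in observations]


def action_gap_scorer(algo, episodes):
    """Return the negated mean absolute gap between predicted and logged actions.

    The value is negated so that a larger score means a closer match.
    """
    total = 0.0
    count = 0
    for episode in episodes:
        predicted = algo.predict(episode.observations)
        for p, a in zip(predicted, episode.actions):
            total += abs(p - a)
            count += 1
    if count == 0:
        raise ValueError("no transitions to score")
    return -total / count


episodes = [
    StubEpisode([[0.0], [1.0]], [1.0, -1.0]),
    StubEpisode([[2.0]], [0.5]),
]
score = action_gap_scorer(StubAlgo(), episodes)
print(score)
```

Under the stated assumption, a function like this could be passed in the same dictionary as `td_error_scorer`.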

You can naturally perform cross-validation.

from sklearn.model_selection import cross_validate

scores = cross_validate(bear, dataset, scoring={'td_error': td_error_scorer})
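
To make the splitting concrete, here is a small sketch of contiguous, unshuffled K-fold partitioning over episode indices, assuming each episode in the dataset counts as one sample. `kfold_indices` is a hypothetical helper written for this illustration; in practice scikit-learn performs the split internally.

```python
def kfold_indices(n_samples, n_splits=5):
    """Yield (train, test) index lists for contiguous, unshuffled K-fold splits."""
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds receive one additional sample each.
        size = base + (1 if fold < extra else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


# Seven episodes split into three folds.
folds = list(kfold_indices(7, n_splits=3))
for train, test in folds:
    print("train:", train, "test:", test)
```

Each episode index appears in exactly one test fold, so every episode is used for evaluation once.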

Grid search over hyperparameters works as well.

import numpy as np
from sklearn.model_selection import GridSearchCV

# search actor learning rates from 1e-3 to 9e-3
gscv = GridSearchCV(estimator=bear,
                    param_grid={'actor_learning_rate': np.arange(1, 10) * 1e-3},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)

supported algorithms

| algorithm | discrete control | continuous control | data-driven RL? |
|---|---|---|---|
| Behavior Cloning (supervised learning) | :white_check_mark: | :white_check_mark: | |
| Deep Q-Network (DQN) | :white_check_mark: | :no_entry: | |
| Double DQN | :white_check_mark: | :no_entry: | |
| Deep Deterministic Policy Gradients (DDPG) | :no_entry: | :white_check_mark: | |
| Twin Delayed Deep Deterministic Policy Gradients (TD3) | :no_entry: | :white_check_mark: | |
| Soft Actor-Critic (SAC) | :no_entry: | :white_check_mark: | |
| Random Ensemble Mixture (REM) | :construction: | :no_entry: | :white_check_mark: |
| Batch Constrained Q-learning (BCQ) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Bootstrapping Error Accumulation Reduction (BEAR) | :no_entry: | :white_check_mark: | :white_check_mark: |
| Advantage-Weighted Regression (AWR) | :construction: | :construction: | :white_check_mark: |
| Advantage-weighted Behavior Model (ABM) | :construction: | :construction: | :white_check_mark: |
| Conservative Q-Learning (CQL) (recommended) | :white_check_mark: | :white_check_mark: | :white_check_mark: |

supported Q functions

other features

Basically, all features are available with every algorithm.

  • evaluation metrics in a scikit-learn scorer function style
  • embedded preprocessors
  • ensemble Q function with bootstrapping
  • delayed policy updates
  • parallel cross-validation with multiple GPUs
  • online training
  • Model-based Offline Policy Optimization (experimental)
  • user-defined custom networks
  • automatic image augmentation
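
As an illustration of one item in the list, delayed policy updates, the sketch below counts updates in a toy training loop where the critic is updated every step and the actor only once every few steps. It assumes that an interval of N means "one actor update per N critic updates"; `run_training_steps` is a hypothetical function, and the counters stand in for real gradient steps.

```python
def run_training_steps(n_steps, update_actor_interval=2):
    """Count critic and actor updates under a delayed-update schedule."""
    critic_updates = 0
    actor_updates = 0
    for step in range(1, n_steps + 1):
        # Placeholder for a critic (Q function) gradient step, done every step.
        critic_updates += 1
        if step % update_actor_interval == 0:
            # Placeholder for an actor (policy) gradient step, done less often.
            actor_updates += 1
    return critic_updates, actor_updates


counts = run_training_steps(10, update_actor_interval=2)
print(counts)
```

Updating the actor less often lets it train against a critic that has had more steps to settle, which is the motivation usually given for this technique.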

examples

Atari 2600

from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_epochs=100,
                  n_critics=3,
                  bootstrap=True,
                  q_func_type='qr',
                  scaler='pixel',
                  use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

[Figures omitted: performance and demo for breakout]

See more Atari datasets at d4rl-atari.
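
The Atari example passes `scaler='pixel'`. A plausible reading, stated here as an assumption rather than as the library's documented behavior, is that raw 8-bit frames are rescaled to the range [0, 1] before they reach the network. The sketch below shows that preprocessing on a tiny nested-list "frame"; `scale_pixels` is a hypothetical helper.

```python
def scale_pixels(frame):
    """Map 8-bit pixel intensities (0 to 255) to floats in [0.0, 1.0]."""
    return [[value / 255.0 for value in row] for row in frame]


# A 2x2 grayscale "frame" with raw intensities.
frame = [[0, 51], [102, 255]]
scaled = scale_pixels(frame)
print(scaled)
```

Keeping the stored dataset as 8-bit integers and scaling only at training time keeps image datasets compact in memory.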

PyBullet

from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(n_epochs=300,
          n_critics=10,
          bootstrap=True,
          update_actor_interval=2,
          q_func_type='qr',
          use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })

[Figures omitted: performance and demo for hopper]

See more PyBullet datasets at d4rl-pybullet.

contributions

coding style

This library is fully formatted with yapf. You can format all the scripts as follows:

$ ./scripts/format

test

Unit tests are provided wherever possible. This repository uses pytest-cov rather than plain pytest. You can run the whole test suite as follows:

$ ./scripts/test

If you pass the -p option, the performance tests on toy tasks are also run (this takes several minutes).

$ ./scripts/test -p

acknowledgement

This work is supported by the Information-technology Promotion Agency, Japan (IPA), under the Exploratory IT Human Resources Project (MITOU Program) in fiscal year 2020.
