d3rlpy
Data-driven Deep Reinforcement Learning Library as an Out-of-the-box Tool.
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import BEAR

# MDPDataset takes NumPy arrays of state transitions
dataset = MDPDataset(observations, actions, rewards, terminals)

# train a data-driven deep RL algorithm
bear = BEAR()
bear.fit(dataset.episodes)

# ready to control: predict actions for a batch of observations
actions = bear.predict(x)
d3rlpy provides state-of-the-art data-driven deep reinforcement learning algorithms through out-of-the-box scikit-learn-style APIs. Unlike many other RL libraries, d3rlpy's algorithms can exceed the performance reported in the original papers thanks to a number of implementation tweaks.
These are the design principles of d3rlpy:
- d3rlpy is designed for practical projects, unlike many other RL libraries.
- d3rlpy does not focus on reproducing RL papers.
- d3rlpy adds more techniques than the original implementations.
installation
$ pip install d3rlpy
scikit-learn compatibility
This library is designed as if it were a natural extension of scikit-learn. You can fully utilize scikit-learn's utilities to increase your productivity.
from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

bear.fit(train_episodes,
         eval_episodes=test_episodes,
         scorers={'td_error': td_error_scorer})
You can naturally perform cross-validation.
from sklearn.model_selection import cross_validate
scores = cross_validate(bear, dataset, scoring={'td_error': td_error_scorer})
And more.
import numpy as np
from sklearn.model_selection import GridSearchCV

gscv = GridSearchCV(estimator=bear,
                    param_grid={'actor_learning_rate': np.arange(1, 10) * 1e-3},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)
supported algorithms
algorithm | discrete control | continuous control | data-driven RL? |
---|---|---|---|
Behavior Cloning (supervised learning) | :white_check_mark: | :white_check_mark: | |
Deep Q-Network (DQN) | :white_check_mark: | :no_entry: | |
Double DQN | :white_check_mark: | :no_entry: | |
Deep Deterministic Policy Gradients (DDPG) | :no_entry: | :white_check_mark: | |
Twin Delayed Deep Deterministic Policy Gradients (TD3) | :no_entry: | :white_check_mark: | |
Soft Actor-Critic (SAC) | :no_entry: | :white_check_mark: | |
Random Ensemble Mixture (REM) | :construction: | :no_entry: | :white_check_mark: |
Batch Constrained Q-learning (BCQ) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
Bootstrapping Error Accumulation Reduction (BEAR) | :no_entry: | :white_check_mark: | :white_check_mark: |
Advantage-Weighted Regression (AWR) | :construction: | :construction: | :white_check_mark: |
Advantage-weighted Behavior Model (ABM) | :construction: | :construction: | :white_check_mark: |
Conservative Q-Learning (CQL) (recommended) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
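Each algorithm is a class in d3rlpy.algos. As the examples below show, discrete-control variants carry a Discrete prefix (e.g. DiscreteCQL), while the plain class name (e.g. CQL) handles continuous control. A minimal sketch of picking an algorithm from the table above:

from d3rlpy.algos import CQL, DiscreteCQL, BEAR

# continuous control (e.g. robotic locomotion)
cql = CQL()

# discrete control (e.g. Atari 2600)
discrete_cql = DiscreteCQL()

# BEAR supports continuous control only (see the table)
bear = BEAR()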
supported Q functions
- standard Q function
- Quantile Regression
- Implicit Quantile Network
- Fully parametrized Quantile Function (experimental)
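The Q function is selected with the q_func_type argument of each algorithm; 'qr' appears in the examples below, while the other identifiers in this sketch ('mean', 'iqn', 'fqf') are assumed for illustration:

from d3rlpy.algos import CQL

# standard Q function (assumed identifier: 'mean')
cql = CQL(q_func_type='mean')

# Quantile Regression (used in the examples below)
cql = CQL(q_func_type='qr')

# Implicit Quantile Network (assumed identifier: 'iqn')
cql = CQL(q_func_type='iqn')

# Fully parametrized Quantile Function (assumed identifier: 'fqf'; experimental)
cql = CQL(q_func_type='fqf')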
other features
In principle, every feature is available with every algorithm. Several of them are toggled directly through constructor arguments, as sketched after this list.
- evaluation metrics in a scikit-learn scorer function style
- embedded preprocessors
- ensemble Q function with bootstrapping
- delayed policy updates
- parallel cross validation with multiple GPUs
- online training
- Model-based Offline Policy Optimization (experimental)
- user-defined custom network
- automatic image augmentation
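A minimal sketch of enabling a few of these features through constructor arguments, as the full examples below also demonstrate (the 'standard' scaler identifier is an assumption; only 'pixel' appears in this document):

from d3rlpy.algos import CQL

cql = CQL(n_critics=10,             # ensemble Q function
          bootstrap=True,           # with bootstrapping
          update_actor_interval=2,  # delayed policy updates
          scaler='standard',        # embedded preprocessor (assumed identifier)
          use_gpu=True)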
examples
Atari 2600
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_epochs=100,
                  n_critics=3,
                  bootstrap=True,
                  q_func_type='qr',
                  scaler='pixel',
                  use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })
(performance and demo figures omitted)
See more Atari datasets at d4rl-atari.
PyBullet
from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(n_epochs=300,
          n_critics=10,
          bootstrap=True,
          update_actor_interval=2,
          q_func_type='qr',
          use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })
(performance and demo figures omitted)
See more PyBullet datasets at d4rl-pybullet.
contributions
coding style
This library is fully formatted with yapf. You can format the entire codebase as follows:
$ ./scripts/format
test
Unit tests are provided for as much of the code as possible. This repository uses pytest-cov instead of pytest.
You can run the entire test suite as follows:
$ ./scripts/test
If you pass the -p option, performance tests with toy tasks are also run (this takes several minutes).
$ ./scripts/test -p
acknowledgement
This work is supported by the Information-technology Promotion Agency, Japan (IPA), Exploratory IT Human Resources Project (MITOU Program) in fiscal year 2020.