d3rlpy
Data-driven Deep Reinforcement Learning Library as an Out-of-the-box Tool.
```python
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import BEAR

# MDPDataset takes arrays of state transitions
dataset = MDPDataset(observations, actions, rewards, terminals)

# train a data-driven deep RL algorithm
bear = BEAR()
bear.fit(dataset.episodes)

# ready to control
actions = bear.predict(x)
```
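MDPDataset expects the transitions as flat arrays. Below is a minimal sketch of building one from dummy NumPy data; the array shapes are purely illustrative assumptions, not requirements of the library.

```python
import numpy as np
from d3rlpy.dataset import MDPDataset

# 1000 transitions with 4-dimensional observations and 2-dimensional actions
# (shapes chosen purely for illustration)
observations = np.random.random((1000, 4)).astype(np.float32)
actions = np.random.random((1000, 2)).astype(np.float32)
rewards = np.random.random(1000).astype(np.float32)
terminals = np.random.randint(2, size=1000)  # 1 marks the end of an episode

dataset = MDPDataset(observations, actions, rewards, terminals)
```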
d3rlpy provides state-of-the-art data-driven deep reinforcement learning algorithms through out-of-the-box scikit-learn-style APIs. Unlike other RL libraries, the provided algorithms can achieve performance beyond the original papers thanks to several additional tweaks.
These are the design principles of d3rlpy:
- d3rlpy is designed for practical projects, unlike many other RL libraries.
- d3rlpy does not focus on reproducing RL papers.
- d3rlpy adds more techniques than the original implementations.
Documentation: https://d3rlpy.readthedocs.io
installation
```
$ pip install d3rlpy
```
scikit-learn compatibility
This library is designed as if it were born from scikit-learn, so you can fully leverage scikit-learn's utilities to increase your productivity.
```python
from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import td_error_scorer

train_episodes, test_episodes = train_test_split(dataset)

bear.fit(train_episodes,
         eval_episodes=test_episodes,
         scorers={'td_error': td_error_scorer})
```
You can naturally perform cross-validation.
```python
from sklearn.model_selection import cross_validate

scores = cross_validate(bear, dataset, scoring={'td_error': td_error_scorer})
```
And more.
```python
import numpy as np
from sklearn.model_selection import GridSearchCV

gscv = GridSearchCV(estimator=bear,
                    param_grid={'actor_learning_rate': np.arange(1, 10) * 1e-3},
                    scoring={'td_error': td_error_scorer},
                    refit=False)
gscv.fit(train_episodes)
```
supported algorithms
| algorithm | discrete control | continuous control | data-driven RL? |
|---|---|---|---|
| Behavior Cloning (supervised learning) | :white_check_mark: | :white_check_mark: | |
| Deep Q-Network (DQN) | :white_check_mark: | :no_entry: | |
| Double DQN | :white_check_mark: | :no_entry: | |
| Deep Deterministic Policy Gradients (DDPG) | :no_entry: | :white_check_mark: | |
| Twin Delayed Deep Deterministic Policy Gradients (TD3) | :no_entry: | :white_check_mark: | |
| Soft Actor-Critic (SAC) | :no_entry: | :white_check_mark: | |
| Random Ensemble Mixture (REM) | :construction: | :no_entry: | :white_check_mark: |
| Batch Constrained Q-learning (BCQ) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Bootstrapping Error Accumulation Reduction (BEAR) | :no_entry: | :white_check_mark: | :white_check_mark: |
| Advantage-Weighted Regression (AWR) | :construction: | :construction: | :white_check_mark: |
| Advantage-weighted Behavior Model (ABM) | :construction: | :construction: | :white_check_mark: |
| Conservative Q-Learning (CQL) (recommended) | :white_check_mark: | :white_check_mark: | :white_check_mark: |
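Algorithms that handle both action spaces are provided as separate classes per action space, following the Discrete* naming that DiscreteCQL uses in the examples below. A minimal sketch (the DiscreteBCQ and BCQ class names are assumptions here):

```python
from d3rlpy.algos import DiscreteBCQ, BCQ

# discrete control, e.g. Atari 2600 (class name assumed)
discrete_bcq = DiscreteBCQ()

# continuous control, e.g. PyBullet locomotion (class name assumed)
continuous_bcq = BCQ()
```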
supported Q functions
- standard Q function
- Quantile Regression
- Implicit Quantile Network
- Fully parametrized Quantile Function (experimental)
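These are selected per algorithm through the q_func_type argument, as in the Atari and PyBullet examples below where 'qr' stands for Quantile Regression; the 'iqn' identifier used here for Implicit Quantile Network is an assumption. A minimal sketch:

```python
from d3rlpy.algos import CQL

# Quantile Regression Q function, as in the examples below
cql_qr = CQL(q_func_type='qr')

# Implicit Quantile Network ('iqn' identifier assumed)
cql_iqn = CQL(q_func_type='iqn')
```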
other features
Basically, all features are available with every algorithm.
- evaluation metrics in a scikit-learn scorer function style
- embedded preprocessors
- export the greedy policy as a TorchScript model (see the sketch after this list)
- ensemble Q function with bootstrapping
- delayed policy updates
- parallel cross-validation with multiple GPUs
- online training
- data augmentation
- Model-based Offline Policy Optimization
- user-defined custom network
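For the TorchScript export mentioned above, here is a minimal sketch given a trained algorithm object (called cql here, as in the examples below); the save_policy method name is an assumption, so check the documentation for the exact API:

```python
import torch

# export the greedy policy as TorchScript (save_policy is assumed here)
cql.save_policy('policy.pt')

# the exported model can then be loaded and run without d3rlpy
policy = torch.jit.load('policy.pt')
action = policy(torch.rand(1, 4))  # 4-dimensional observation, for illustration
```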
examples
Atari 2600
```python
from d3rlpy.datasets import get_atari
from d3rlpy.algos import DiscreteCQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = DiscreteCQL(n_epochs=100,
                  n_critics=3,
                  bootstrap=True,
                  q_func_type='qr',
                  scaler='pixel',
                  use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })
```
See more Atari datasets at d4rl-atari.
PyBullet
```python
from d3rlpy.datasets import get_pybullet
from d3rlpy.algos import CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from sklearn.model_selection import train_test_split

# get data-driven RL dataset
dataset, env = get_pybullet('hopper-bullet-mixed-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

# setup algorithm
cql = CQL(n_epochs=300,
          actor_learning_rate=1e-3,
          critic_learning_rate=1e-3,
          temp_learning_rate=1e-3,
          alpha_learning_rate=1e-3,
          n_critics=10,
          bootstrap=True,
          update_actor_interval=2,
          q_func_type='qr',
          use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'environment': evaluate_on_environment(env),
            'advantage': discounted_sum_of_advantage_scorer
        })
```
See more PyBullet datasets at d4rl-pybullet.
tutorials
Try a cartpole example on Google Colaboratory!
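The tutorial follows the same flow as the examples above; a minimal sketch, assuming a get_cartpole dataset helper in d3rlpy.datasets (the helper name is an assumption here; the notebook provides the actual setup):

```python
from sklearn.model_selection import train_test_split
from d3rlpy.datasets import get_cartpole  # assumed helper name
from d3rlpy.algos import DQN
from d3rlpy.metrics.scorer import evaluate_on_environment

dataset, env = get_cartpole()
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

dqn = DQN(n_epochs=1, use_gpu=False)
dqn.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={'environment': evaluate_on_environment(env)})
```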
contributions
coding style
This library is fully formatted with yapf. You can format the entire codebase as follows:
```
$ ./scripts/format
```
test
Unit tests are provided for as much of the code as possible.
This repository uses pytest-cov instead of pytest.
You can run the entire test suite as follows:
```
$ ./scripts/test
```
If you pass the -p option, performance tests with toy tasks are also run (this takes several minutes).
```
$ ./scripts/test -p
```
acknowledgement
This work is supported by Information-technology Promotion Agency, Japan (IPA), Exploratory IT Human Resources Project (MITOU Program) in the fiscal year 2020.