An offline deep reinforcement learning library
Deep Reinforcement Learning
Shyamal H Anadkat | Fall '21
Background
Hello! This is the repository for my AIPI530 Deep RL final project. The goal is to build a pipeline for offline RL. The starter code has been forked from d3rlpy (see the citation at the bottom). Offline reinforcement learning (RL) is the task of learning a policy from a fixed batch of previously collected data, without further interaction with the environment.
Before diving in, I would recommend getting familiar with basic reinforcement learning. Here is a link to my blog post on reinforcement learning to get you started: RL Primer
The blog post briefly covers the following:
- What is reinforcement learning?
- What are the pros and cons of reinforcement learning?
- When should we consider applying reinforcement learning (and when should we not)?
- What is the difference between supervised learning and reinforcement learning?
- What is offline reinforcement learning? What are its pros and cons?
- When should we consider applying offline reinforcement learning (and when should we not)?
- An example of offline reinforcement learning in the real world
Getting Started
(please read carefully)
This project is customized for training CQL on a custom dataset in d3rlpy, and for training an OPE method (FQE) to evaluate the trained policy (a sketch of this pipeline is shown after the list below). Important scripts:
- `cql_train.py`: at the root of the project; the main script, used to train CQL and get evaluation scores
- `plot_helper.py`: utility script to help produce the required plots
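At a high level, the two-stage pipeline (train CQL offline, then evaluate the learned policy with FQE) looks roughly like the sketch below, assuming the d3rlpy 1.x API. This is illustrative only and not a copy of `cql_train.py`; the epoch counts and scorer choices are assumptions.

```python
import d3rlpy
from d3rlpy.ope import FQE
from d3rlpy.metrics.scorer import (
    evaluate_on_environment,
    initial_state_value_estimation_scorer,
)

# offline dataset + environment (the project's default dataset)
dataset, env = d3rlpy.datasets.get_pybullet('hopper-bullet-mixed-v0')

# stage 1: train CQL on the fixed batch of data
cql = d3rlpy.algos.CQL(use_gpu=False)
cql.fit(dataset,
        eval_episodes=dataset,
        n_epochs=10,
        scorers={'environment': evaluate_on_environment(env)})

# stage 2: off-policy evaluation of the learned policy with FQE
fqe = FQE(algo=cql)
fqe.fit(dataset.episodes,
        eval_episodes=dataset.episodes,
        n_epochs=10,
        scorers={'init_value': initial_state_value_estimation_scorer})
```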
How do I install & run this project?
1. Clone this repository
git clone https://github.com/shyamal-anadkat/offlinerl
2. Install d4rl-pybullet from source:
pip install git+https://github.com/takuseno/d4rl-pybullet
3. Install requirements:
pip install Cython numpy
pip install -e .
4. Execute `cql_train.py`, found at the root of the project.
   - Default dataset is `hopper-bullet-mixed-v0`.
   - Default number of `epochs` is `10`. You can change this via the custom args `--epochs_cql` & `--epochs_fqe`.
   - For example, to run for 10 epochs: `python cql_train.py --epochs_cql 10 --epochs_fqe 10` (see the Colab example below for more clarity).
5. Important logs:
   - Estimated Q values vs. training steps (CQL): `d3rlpy_logs/CQL_hopper-bullet-mixed-v0_1/init_value.csv`
   - Average reward vs. training steps (CQL): `d3rlpy_logs/CQL_hopper-bullet-mixed-v0_1/environment.csv`
   - True Q values vs. training steps (CQL): `d3rlpy_logs/CQL_hopper-bullet-mixed-v0_1/true_q_value.csv`
   - True Q & estimated Q values vs. training steps (FQE): `d3rlpy_logs/FQE_hopper-bullet-mixed-v0_1/..`
   - Note: I created my own scorer to calculate the true Q values. See `scorer.py` (`true_q_value_scorer`) for implementation details; a hedged sketch of such a scorer follows this list.
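For context, here is a minimal sketch of what a "true Q value" scorer can look like, assuming the d3rlpy 1.x scorer convention of a callable taking `(algo, episodes)` and returning a float. The discounting and averaging details are illustrative assumptions; the actual implementation lives in `scorer.py`.

```python
# Illustrative sketch only -- the real implementation is in scorer.py.
import numpy as np

def true_q_value_scorer(algo, episodes, gamma=0.99):
    """Average discounted Monte-Carlo return over the evaluation episodes,
    used as a 'true' Q value to compare against the critic's estimates."""
    returns = []
    for episode in episodes:
        ret = 0.0
        # accumulate the discounted return backwards through the episode
        for transition in reversed(episode.transitions):
            ret = transition.reward + gamma * ret
        returns.append(ret)
    return float(np.mean(returns))
```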
6. For plotting, I wrote a utility script (at the root of the project) which can be executed like so: `python plot_helper.py`
   - Note: you can provide arguments that correspond to the paths of the log files, or it will use the defaults. A hedged sketch of such a plotting helper is shown below.
- If you're curious, here's the benchmark/reproduction.
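As an illustration of what such a plotting helper can do (the column layout and helper function below are assumptions, not the actual `plot_helper.py` interface), the logged CSVs can be plotted roughly like this, assuming each d3rlpy metric CSV row has the form `epoch,step,value` with no header:

```python
# Illustrative sketch only -- the real script is plot_helper.py at the repo root.
import pandas as pd
import matplotlib.pyplot as plt

def plot_metric(csv_path, label):
    # read a d3rlpy-style metric log and plot its value against training steps
    df = pd.read_csv(csv_path, header=None, names=["epoch", "step", "value"])
    plt.plot(df["step"], df["value"], label=label)

plot_metric("d3rlpy_logs/CQL_hopper-bullet-mixed-v0_1/init_value.csv",
            "Estimated Q (CQL)")
plot_metric("d3rlpy_logs/CQL_hopper-bullet-mixed-v0_1/true_q_value.csv",
            "True Q (CQL)")
plt.xlabel("training steps")
plt.ylabel("Q value")
plt.legend()
plt.savefig("q_values_vs_steps.png")
```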
Other scripts:
- Format:
./scripts/format
- Linting:
./scripts/lint
Sample Plots (with 100 epochs):
Note: logs can be found in /d3rlpy_logs
Background on d3rlpy
d3rlpy is an offline deep reinforcement learning library for practitioners and researchers.
- Documentation: https://d3rlpy.readthedocs.io
- Paper: https://arxiv.org/abs/2111.03788
How do I install d3rlpy?
d3rlpy supports Linux, macOS, and Windows. d3rlpy is not only easy to use, but also fully compatible with the scikit-learn API, which means you can maximize your productivity with scikit-learn's useful utilities.
PyPI (recommended)
$ pip install d3rlpy
More examples around d3rlpy usage

```python
import d3rlpy

dataset, env = d3rlpy.datasets.get_dataset("hopper-medium-v0")

# prepare algorithm
sac = d3rlpy.algos.SAC()

# train offline
sac.fit(dataset, n_steps=1000000)

# train online
sac.fit_online(env, n_steps=1000000)

# ready to control
actions = sac.predict(x)
```
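In the last line above, `x` stands for a batch of observations and is not defined in the snippet itself. One way to build such a batch (an illustrative assumption, using the classic Gym API where `reset()` returns only the observation):

```python
import numpy as np

# a batch containing a single observation from the environment
x = np.expand_dims(env.reset(), axis=0)
actions = sac.predict(x)
```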
MuJoCo

```python
import d3rlpy

# prepare dataset
dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')

# prepare algorithm
cql = d3rlpy.algos.CQL(use_gpu=True)

# train
cql.fit(dataset,
        eval_episodes=dataset,
        n_epochs=100,
        scorers={
            'environment': d3rlpy.metrics.evaluate_on_environment(env),
            'td_error': d3rlpy.metrics.td_error_scorer
        })
```
See more datasets at d4rl.
Atari 2600

```python
import d3rlpy
from sklearn.model_selection import train_test_split

# prepare dataset
dataset, env = d3rlpy.datasets.get_atari('breakout-expert-v0')

# split dataset
train_episodes, test_episodes = train_test_split(dataset, test_size=0.1)

# prepare algorithm
cql = d3rlpy.algos.DiscreteCQL(n_frames=4, q_func_factory='qr', scaler='pixel', use_gpu=True)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=100,
        scorers={
            'environment': d3rlpy.metrics.evaluate_on_environment(env),
            'td_error': d3rlpy.metrics.td_error_scorer
        })
```
See more Atari datasets at d4rl-atari.
PyBullet

```python
import d3rlpy

# prepare dataset
dataset, env = d3rlpy.datasets.get_pybullet('hopper-bullet-mixed-v0')

# prepare algorithm
cql = d3rlpy.algos.CQL(use_gpu=True)

# start training
cql.fit(dataset,
        eval_episodes=dataset,
        n_epochs=100,
        scorers={
            'environment': d3rlpy.metrics.evaluate_on_environment(env),
            'td_error': d3rlpy.metrics.td_error_scorer
        })
```
See more PyBullet datasets at d4rl-pybullet.
How about some tutorials?
Try a cartpole example on Google Colaboratory:
Citation
Thanks to Takuma Seno for his work on d3rlpy; this project wouldn't have been possible without it.
Seno, T., & Imai, M. (2021). d3rlpy: An Offline Deep Reinforcement Learning Library. 35th Conference on Neural Information Processing Systems (NeurIPS 2021), Offline Reinforcement Learning Workshop.
@InProceedings{seno2021d3rlpy,
  author = {Takuma Seno and Michita Imai},
  title = {d3rlpy: An Offline Deep Reinforcement Learning Library},
  booktitle = {NeurIPS 2021 Offline Reinforcement Learning Workshop},
  month = {December},
  year = {2021}
}