RL baselines implemented in TensorFlow 2.0
Unstable Baselines (Early Access)
A Deep Reinforcement Learning codebase in TensorFlow 2.0 with a unified, flexible, and highly customizable structure for fast prototyping.
Features | Unstable Baselines | Stable-Baselines3 | OpenAI Baselines |
---|---|---|---|
State of the art RL methods | :heavy_minus_sign: (1) | :heavy_check_mark: | :heavy_check_mark: |
Documentation | :x: | :heavy_check_mark: | :x: |
Custom callback (2) | :x: | :vomiting_face: | :heavy_minus_sign: |
TensorFlow 2.0 support | :heavy_check_mark: | :x: | :x: |
Clean, elegant code | :heavy_check_mark: | :x: | :x: |
Easy to trace, customize | :heavy_check_mark: | :x: (3) | :x: (3) |
Standalone implementations | :heavy_check_mark: | :heavy_minus_sign: | :x: (4) |
(1) Currently supports only DQN, C51, PPO, TD3, etc. We are still working on other algorithms.
(2) For example, in Stable-Baselines, you need to write this disgusting custom callback to save the best-performing model :vomiting_face:, while in Unstable Baselines, the best models are saved automatically.
(3) If you have traced Stable-Baselines or OpenAI/baselines once, you'll never do that again.
(4) Many cross-dependencies across all algos make the code very hard to trace, for example baselines/common/policies.py, baselines/a2c/a2c.py.... Great job, OpenAI! :cat:
Documentation
We don't have any documentation yet.
Installation
Basic requirements:
- Python >= 3.6
- TensorFlow (CPU/GPU) >= 2.3.0
You can install from PyPI:
$ pip install unstable_baselines
Or install the latest version from this repository:
$ pip install git+https://github.com/Ending2015a/unstable_baselines.git@master
Done! Now, you can
- Go through the Quick Start section
- Or run the example code in the example folder.
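If an install misbehaves, it can help to confirm that the interpreter and TensorFlow meet the basic requirements listed above. The following is a minimal sketch, not part of Unstable Baselines: the helper names `parse_version`, `meets_minimum`, and `check_requirements` are made up for illustration, and the version parsing only looks at the leading numeric components.

```python
import re
import sys


def parse_version(version):
    """Return the first three numeric components of a version string.

    Suffixes such as "rc1" are ignored, and missing components are
    padded with zeros, so "2.4.0rc1" -> (2, 4, 0) and "2.3" -> (2, 3, 0).
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(version, minimum):
    """True if `version` is at least `minimum` (numeric comparison)."""
    return parse_version(version) >= parse_version(minimum)


def check_requirements():
    """Check Python >= 3.6 and TensorFlow >= 2.3.0.

    Returns a (python_ok, tensorflow_ok) pair; tensorflow_ok is False
    when TensorFlow is not installed.
    """
    python_ok = sys.version_info >= (3, 6)
    try:
        import tensorflow as tf  # imported lazily so the check itself never crashes
        tensorflow_ok = meets_minimum(tf.__version__, "2.3.0")
    except ImportError:
        tensorflow_ok = False
    return python_ok, tensorflow_ok


if __name__ == "__main__":
    print(check_requirements())
```

Running the script prints a pair of booleans; if either is `False`, fix the environment before installing the package.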
Algorithms
Model-free RL
Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
---|---|---|---|---|
DQN | :x: | :heavy_check_mark: | :x: | :x: |
PPO | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
TD3 | :heavy_check_mark: | :x: | :x: | :x: |
SD3 | :heavy_check_mark: | :x: | :x: | :x: |
- 2021.09.17: DQN supports
  - Multi-step learning
  - Prioritized experience replay: arXiv:1511.05952
  - Dueling network: arXiv:1511.06581
- 2021.04.19: Implemented DQN
  - From paper: arXiv:1509.06461
- 2021.03.27: PPO supports continuous (Box) action spaces
- 2021.03.23: Implemented SD3
  - From paper: arXiv:2010.09177
- 2021.03.20: Implemented TD3
  - From paper: arXiv:1802.09477
- 2021.03.10: Implemented PPO
  - From paper: arXiv:1707.06347
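For readers unfamiliar with the dueling network listed in the DQN entry above, the core idea is to combine a state value with per-action advantages. The snippet below is a plain-Python sketch of the mean-subtracted aggregation described in arXiv:1511.06581; it is not taken from this codebase, and the function name `dueling_q_values` is hypothetical.

```python
def dueling_q_values(state_value, advantages):
    """Combine V(s) and A(s, a) into Q(s, a).

    Uses the mean-subtracted form Q(s, a) = V(s) + A(s, a) - mean_a A(s, a),
    which keeps the value and advantage streams identifiable.
    """
    mean_advantage = sum(advantages) / len(advantages)
    return [state_value + adv - mean_advantage for adv in advantages]


# Example: one state with three discrete actions.
q_values = dueling_q_values(1.0, [1.0, 2.0, 3.0])
print(q_values)
```

In a real network the two inputs would be the outputs of separate value and advantage heads sharing a common feature extractor.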
Distributional RL
Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
---|---|---|---|---|
C51 | :x: | :heavy_check_mark: | :x: | :x: |
QRDQN | :x: | :heavy_check_mark: | :x: | :x: |
IQN | :x: | :heavy_check_mark: | :x: | :x: |
- 2021.04.28: Implemented IQN
  - From paper: arXiv:1806.06923
- 2021.04.21: Implemented QRDQN
  - From paper: arXiv:1710.10044
- 2021.04.20: Implemented C51
  - From paper: arXiv:1707.06887
Quick Start
This example shows how to train a PPO agent to play CartPole-v0. You can find the full script in example/cartpole/train_ppo.py.
First, import dependencies
import gym
import unstable_baselines as ub
from unstable_baselines.algo.ppo import PPO
Create environments for training and evaluation
# create environments
env = ub.envs.VecEnv([gym.make('CartPole-v0') for _ in range(10)])
eval_env = gym.make('CartPole-v0')
Create a PPO model and train it
model = PPO(
    env,
    learning_rate=1e-3,
    gamma=0.8,
    batch_size=128,
    n_steps=500
).learn(  # train for 20000 steps
    20000,
    verbose=1
)
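The call above uses `gamma=0.8`, a fairly aggressive discount. As a rough illustration of what that implies, the sketch below computes a discounted return in plain Python; it is independent of Unstable Baselines, and the helper names `discounted_return` and `effective_horizon` are made up for this example. CartPole-v0 gives a reward of 1 per step and caps episodes at 200 steps, so a full episode is modeled as 200 rewards of 1.0.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**k * rewards[k], accumulated from the last reward backwards."""
    total = 0.0
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


def effective_horizon(gamma):
    """Rule-of-thumb horizon 1 / (1 - gamma) for a discount factor below 1."""
    return 1.0 / (1.0 - gamma)


# A full 200-step CartPole-v0 episode with reward 1.0 at every step.
full_episode = [1.0] * 200
print(discounted_return(full_episode, 0.8))
print(effective_horizon(0.8))
```

With `gamma=0.8` the discounted return of a full episode saturates near 5, so the agent effectively optimizes over only the next handful of steps. That is usually enough for CartPole, but longer-horizon tasks may need a value closer to 1.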
Save and load the trained model
model.save('./my_ppo_model')
model = PPO.load('./my_ppo_model')
Evaluate the training results
model.eval(eval_env, 20, 200, render=True)
# don't forget to close the environments!
env.close()
eval_env.close()
Update Logs
- 2021.05.22: Add benchmarks
- 2021.04.27: Update to framework v2: supports saving/loading the best-performing checkpoints.