# Unstable Baselines (Early Access)
A Deep Reinforcement Learning codebase in TensorFlow 2.0 with a unified, flexible, and highly customizable structure for fast prototyping.
Features | Unstable Baselines | Stable-Baselines3 | OpenAI Baselines |
---|---|---|---|
State of the art RL methods | :heavy_minus_sign: (1) | :heavy_check_mark: | :heavy_check_mark: |
Documentation | :x: | :heavy_check_mark: | :x: |
Custom callback (2) | :x: | :vomiting_face: | :heavy_minus_sign: |
TensorFlow 2.0 support | :heavy_check_mark: | :x: | :x: |
Clean, elegant code | :heavy_check_mark: | :x: | :x: |
Easy to trace, customize | :heavy_check_mark: | :x: (3) | :x: (3) |
Standalone implementations | :heavy_check_mark: | :heavy_minus_sign: | :x: (4) |
(1) Currently supports only DQN, C51, PPO, TD3, etc. We are still working on other algorithms.
(2) For example, in Stable-Baselines you need to write a disgusting custom callback to save the best-performing model :vomiting_face:, while in Unstable Baselines the best checkpoints are saved automatically.
(3) If you have traced through Stable-Baselines or OpenAI/baselines once, you'll never want to do it again.
(4) Many cross-dependencies across all of the algorithms make the code very hard to trace, for example baselines/common/policies.py, baselines/a2c/a2c.py... Great job, OpenAI! :cat:
## Documentation
We don't have any documentation yet.
## Installation
Basic requirements:
- Python >= 3.6
- TensorFlow (CPU/GPU) >= 2.3.0
You can install it from PyPI:

```shell
pip install unstable_baselines
```

Or you can install the latest version from this repository:

```shell
pip install git+https://github.com/Ending2015a/unstable_baselines.git@master
```
Done! Now you can:
- Go through the Quick Start section
- Or run the example code in the example folder
## Algorithms
### Model-free RL
Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
---|---|---|---|---|
DQN | :x: | :heavy_check_mark: | :x: | :x: |
PPO | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
TD3 | :heavy_check_mark: | :x: | :x: | :x: |
SD3 | :heavy_check_mark: | :x: | :x: | :x: |
- 2021.09.17: DQN supports
- Multi-step learning
- Prioritized experience replay: arXiv:1511.05952
- Dueling network: arXiv:1511.06581
- 2021.04.19: Implemented DQN
- From paper: arXiv:1509.06461
- 2021.03.27: PPO supports continuous (Box) action spaces
- 2021.03.23: Implemented SD3
- From paper: arXiv:2010.09177
- 2021.03.20: Implemented TD3
- From paper: arXiv:1802.09477
- 2021.03.10: Implemented PPO
- From paper: arXiv:1707.06347
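The multi-step learning feature added to DQN replaces the one-step TD target with an n-step bootstrapped return. A minimal sketch of that computation in plain Python (a hypothetical helper for illustration, not the library's API):

```python
def n_step_return(rewards, bootstrap_value, gamma=0.99):
    """n-step return: G = r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1} + gamma^n * V(s_n).

    `rewards` are the n rewards collected along the trajectory segment;
    `bootstrap_value` is the value estimate at the final state.
    """
    g = bootstrap_value
    for r in reversed(rewards):  # fold from the last reward back to the first
        g = r + gamma * g
    return g
```

Larger n propagates reward information faster at the cost of higher variance in the target.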
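The heart of PPO (arXiv:1707.06347) is the clipped surrogate objective. A per-sample sketch in plain Python, purely for illustration (the codebase's actual loss would be a batched TensorFlow op):

```python
def ppo_clip_loss(ratio, advantage, clip_range=0.2):
    """PPO clipped surrogate: maximize min(r*A, clip(r, 1-eps, 1+eps)*A).

    `ratio` is pi_new(a|s) / pi_old(a|s); the result is negated because
    optimizers minimize.
    """
    clipped_ratio = max(min(ratio, 1.0 + clip_range), 1.0 - clip_range)
    return -min(ratio * advantage, clipped_ratio * advantage)
```

Clipping removes the incentive to move the policy ratio outside [1-eps, 1+eps], which keeps each update close to the data-collecting policy.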
### Distributional RL
Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
---|---|---|---|---|
C51 | :x: | :heavy_check_mark: | :x: | :x: |
QRDQN | :x: | :heavy_check_mark: | :x: | :x: |
IQN | :x: | :heavy_check_mark: | :x: | :x: |
- 2021.04.28: Implemented IQN
- From paper: arXiv:1806.06923
- 2021.04.21: Implemented QRDQN
- From paper: arXiv:1710.10044
- 2021.04.20: Implemented C51
- From paper: arXiv:1707.06887
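QRDQN (and IQN) regress a set of quantiles of the return distribution using the quantile Huber loss from arXiv:1710.10044. A scalar-target sketch in plain Python (illustration only, not the library's implementation):

```python
def quantile_huber_loss(quantiles, taus, target, kappa=1.0):
    """Quantile-regression Huber loss for predicted quantiles vs. a scalar target.

    `quantiles` are the predicted quantile values, `taus` the quantile
    fractions they correspond to (e.g. midpoints 0.25, 0.75 for 2 quantiles).
    """
    total = 0.0
    for q, tau in zip(quantiles, taus):
        u = target - q                       # TD error for this quantile
        if abs(u) <= kappa:                  # Huber: quadratic near zero...
            huber = 0.5 * u * u
        else:                                # ...linear in the tails
            huber = kappa * (abs(u) - 0.5 * kappa)
        weight = abs(tau - (1.0 if u < 0 else 0.0))  # asymmetric quantile weight
        total += weight * huber
    return total / len(quantiles)
```

The asymmetric weight penalizes over- and under-estimation differently per quantile, which is what makes the minimizer the tau-th quantile rather than the mean.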
## Quick Start
This example shows how to train a PPO agent to play CartPole-v0. You can find the full script in example/cartpole/train_ppo.py.
First, import the dependencies:

```python
import gym
import unstable_baselines as ub
from unstable_baselines.algo.ppo import PPO
```
Create environments for training and evaluation:

```python
# create environments
env = ub.envs.VecEnv([gym.make('CartPole-v0') for _ in range(10)])
eval_env = gym.make('CartPole-v0')
```
Create a PPO model and train it
model = PPO(
env,
learning_rate=1e-3,
gamma=0.8,
batch_size=128,
n_steps=500
).learn( # train for 20000 steps
20000,
verbose=1
)
Save and load the trained model:

```python
model.save('./my_ppo_model')
model = PPO.load('./my_ppo_model')
```
Evaluate the training results:

```python
model.eval(eval_env, 20, 200, render=True)
# don't forget to close the environments!
env.close()
eval_env.close()
```
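The `ub.envs.VecEnv` used above batches several environments and steps them in lockstep, auto-resetting any that finish. A toy sketch of that idea in plain Python (with a made-up `ToyEnv`; this is conceptual, not the library's implementation):

```python
class ToyEnv:
    """Stand-in environment: counts steps, episode ends after 3."""
    def reset(self):
        self.t = 0
        return self.t
    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3, {}  # obs, reward, done, info

class SimpleVecEnv:
    """Step several envs in lockstep, auto-resetting finished ones."""
    def __init__(self, envs):
        self.envs = envs
    def reset(self):
        return [env.reset() for env in self.envs]
    def step(self, actions):
        obs, rews, dones = [], [], []
        for env, action in zip(self.envs, actions):
            o, r, d, _ = env.step(action)
            if d:            # auto-reset so the batch never stalls
                o = env.reset()
            obs.append(o); rews.append(r); dones.append(d)
        return obs, rews, dones
```

Batching environments this way is what lets PPO collect `n_steps` of experience from 10 CartPole instances in parallel per rollout.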
More examples can be found in the example folder.
## Update Logs
- 2021.05.22: Add benchmarks
- 2021.04.27: Update to framework v2: supports saving/loading the best-performing checkpoints.