
RL baselines implemented in TensorFlow 2.0


Unstable Baselines (Early Access)

A Deep Reinforcement Learning codebase in TensorFlow 2.0 with a unified, flexible, and highly customizable structure for fast prototyping.

Features                     Unstable Baselines      Stable-Baselines3    OpenAI Baselines
State of the art RL methods  :heavy_minus_sign: (1)  :heavy_check_mark:   :heavy_check_mark:
Documentation                :x:                     :heavy_check_mark:   :x:
Custom callback (2)          :x:                     :vomiting_face:      :heavy_minus_sign:
TensorFlow 2.0 support       :heavy_check_mark:      :x:                  :x:
Clean, elegant code          :heavy_check_mark:      :x:                  :x:
Easy to trace, customize     :heavy_check_mark:      :x: (3)              :x: (3)
Standalone implementations   :heavy_check_mark:      :heavy_minus_sign:   :x: (4)

(1) Currently supports only DQN, C51, PPO, TD3, etc. We are still working on other algorithms.
(2) For example, in Stable-Baselines you need to write a disgusting custom callback just to save the best-performing model :vomiting_face:, while in Unstable Baselines the best checkpoints are saved automatically (see the sketch after these notes).
(3) If you have traced Stable-Baselines or OpenAI/baselines once, you'll never do that again.
(4) Many cross-dependencies across all algos make the code very hard to trace, for example baselines/common/policies.py, baselines/a2c/a2c.py, ... Great job, OpenAI! :cat:
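
To make footnote (2) concrete, here is a rough sketch of the kind of custom callback Stable-Baselines3 expects you to write just to keep the best checkpoint. The class name and evaluation details are our own illustration (built on its BaseCallback/evaluate_policy API), not code from either library; Unstable Baselines handles this bookkeeping for you.

# Illustrative only: roughly what a "save the best model" callback looks like
# when written against Stable-Baselines3's BaseCallback interface.
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy

class SaveBestModelCallback(BaseCallback):
    def __init__(self, eval_env, save_path, eval_freq=1000, verbose=0):
        super().__init__(verbose)
        self.eval_env = eval_env
        self.save_path = save_path
        self.eval_freq = eval_freq
        self.best_mean_reward = float('-inf')

    def _on_step(self) -> bool:
        # every `eval_freq` calls, evaluate the current policy and keep the
        # checkpoint only if it beats the best mean reward seen so far
        if self.n_calls % self.eval_freq == 0:
            mean_reward, _ = evaluate_policy(
                self.model, self.eval_env, n_eval_episodes=5)
            if mean_reward > self.best_mean_reward:
                self.best_mean_reward = mean_reward
                self.model.save(self.save_path)
        return True  # returning False would abort training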

Documentation

We don't have any documentation yet.

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow (CPU/GPU) >= 2.3.0

You can install the package from PyPI

$ pip install unstable_baselines

Alternatively, you can install the latest version from this repository

$ pip install git+https://github.com/Ending2015a/unstable_baselines.git@master
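
A quick sanity check that the install worked (a plain import, nothing package-specific assumed):

$ python -c "import unstable_baselines as ub; print(ub)"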

Done! Now you can start training your own agents (see the Quick Start section below).

Algorithms

Model-free RL

Algorithm  Box                 Discrete            MultiDiscrete  MultiBinary
DQN        :x:                 :heavy_check_mark:  :x:            :x:
PPO        :heavy_check_mark:  :heavy_check_mark:  :x:            :x:
TD3        :heavy_check_mark:  :x:                 :x:            :x:
SD3        :heavy_check_mark:  :x:                 :x:            :x:

Distributional RL

Algorithm  Box  Discrete            MultiDiscrete  MultiBinary
C51        :x:  :heavy_check_mark:  :x:            :x:
QRDQN      :x:  :heavy_check_mark:  :x:            :x:
IQN        :x:  :heavy_check_mark:  :x:            :x:
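
The Box / Discrete / MultiDiscrete / MultiBinary columns refer to the gym action-space type an algorithm supports. A quick way to see which column applies to your environment (plain gym usage, nothing specific to Unstable Baselines):

import gym

# CartPole uses a Discrete action space -> DQN, PPO, C51, QRDQN, IQN
print(gym.make('CartPole-v0').action_space)   # Discrete(2)
# Pendulum uses a Box (continuous) action space -> PPO, TD3, SD3
print(gym.make('Pendulum-v0').action_space)   # Box(...)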

Quick Start

This example shows how to train a PPO agent to play CartPole-v0. You can find the full script in example/cartpole/train_ppo.py.

First, import dependencies

import gym
import unstable_baselines as ub
from unstable_baselines.algo.ppo import PPO

Create environments for training and evaluation

# create environments
env = ub.envs.VecEnv([gym.make('CartPole-v0') for _ in range(10)])
eval_env = gym.make('CartPole-v0')

Create a PPO model and train it

model = PPO(
    env,
    learning_rate=1e-3,
    gamma=0.8,
    batch_size=128,
    n_steps=500
).learn(  # train for 20000 steps
    20000,
    verbose=1
)

Save and load the trained model

model.save('./my_ppo_model')
model = PPO.load('./my_ppo_model')

Evaluate the training results

model.eval(eval_env, 20, 200, render=True)  # 20 episodes, up to 200 steps each
# don't forget to close the environments!
env.close()
eval_env.close()

More examples:

Update Logs

  • 2021.05.22: Add benchmarks
  • 2021.04.27: Update to framework v2: supports saving/loading the best-performing checkpoints.
