RL baselines implemented in TensorFlow 2.0

Unstable Baselines (Early Access)

A Deep Reinforcement Learning codebase in TensorFlow 2.0 with a unified, flexible, and highly customizable structure for fast prototyping.

| Features | Unstable Baselines | Stable-Baselines3 | OpenAI Baselines |
| --- | --- | --- | --- |
| State of the art RL methods | :heavy_minus_sign: (1) | :heavy_check_mark: | :heavy_check_mark: |
| Documentation | :x: | :heavy_check_mark: | :x: |
| Custom callback (2) | :x: | :vomiting_face: | :heavy_minus_sign: |
| TensorFlow 2.0 support | :heavy_check_mark: | :x: | :x: |
| Clean, elegant code | :heavy_check_mark: | :x: | :x: |
| Easy to trace, customize | :heavy_check_mark: | :x: (3) | :x: (3) |
| Standalone implementations | :heavy_check_mark: | :heavy_minus_sign: | :x: (4) |

(1) Currently only DQN, C51, PPO, TD3, etc. are supported. We are still working on other algorithms.
(2) For example, in Stable-Baselines you need to write this disgusting custom callback to save the best-performing model :vomiting_face:, while in Unstable Baselines the best models are saved automatically.
(3) If you have traced Stable-Baselines or OpenAI/baselines once, you'll never do that again.
(4) The many cross-dependencies between algorithms make the code very hard to trace, for example baselines/common/policies.py, baselines/a2c/a2c.py.... Great job, OpenAI! :cat:

Documentation

We don't have any documentation yet.

Installation

Basic requirements:

  • Python >= 3.6
  • TensorFlow (CPU/GPU) >= 2.3.0
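
Before installing, it can help to confirm that the interpreter and any existing TensorFlow install meet these minimums. The following is a minimal sketch and not part of Unstable Baselines: the helper names `parse_version`, `meets_minimum`, and `check_requirements` are hypothetical, and the parser only understands plain dotted numeric versions.

```python
import sys


def parse_version(version):
    """Turn a dotted version such as '2.3.0' into a 3-tuple of ints.

    Hypothetical helper: non-numeric suffixes (e.g. 'rc1') are dropped
    and missing components are treated as 0.
    """
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(version, minimum):
    """Return True if `version` is at least `minimum`."""
    return parse_version(version) >= parse_version(minimum)


def check_requirements():
    """Report whether Python >= 3.6 and TensorFlow >= 2.3.0 are available."""
    python_ok = sys.version_info >= (3, 6)
    try:
        import tensorflow as tf  # may not be installed yet
        tf_ok = meets_minimum(tf.__version__, "2.3.0")
    except ImportError:
        tf_ok = False
    return {"python": python_ok, "tensorflow": tf_ok}
```

Calling `check_requirements()` returns a small dict of booleans; a `False` for `tensorflow` only means that TensorFlow is missing or older than 2.3.0 in the current environment.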

You can install from PyPI:

$ pip install unstable_baselines

Or install the latest version from this repository:

$ pip install git+https://github.com/Ending2015a/unstable_baselines.git@master

Done!

Algorithms

Model-free RL

| Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
| --- | --- | --- | --- | --- |
| DQN | :x: | :heavy_check_mark: | :x: | :x: |
| PPO | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
| TD3 | :heavy_check_mark: | :x: | :x: | :x: |
| SD3 | :heavy_check_mark: | :x: | :x: | :x: |

Distributional RL

| Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
| --- | --- | --- | --- | --- |
| C51 | :x: | :heavy_check_mark: | :x: | :x: |
| QRDQN | :x: | :heavy_check_mark: | :x: | :x: |
| IQN | :x: | :heavy_check_mark: | :x: | :x: |
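
To pick an algorithm for a given environment, the tables above can be read as a lookup from space type to algorithm. The sketch below only transcribes those tables into a dict; `SUPPORTED_SPACES` and `algorithms_for` are hypothetical names, not part of the library, and it assumes the columns refer to the environment's action space type.

```python
# Support matrix transcribed from the tables above.
SUPPORTED_SPACES = {
    "DQN": {"Discrete"},
    "PPO": {"Box", "Discrete"},
    "TD3": {"Box"},
    "SD3": {"Box"},
    "C51": {"Discrete"},
    "QRDQN": {"Discrete"},
    "IQN": {"Discrete"},
}


def algorithms_for(space_name):
    """List algorithms whose table row marks `space_name` as supported."""
    return sorted(
        name for name, spaces in SUPPORTED_SPACES.items() if space_name in spaces
    )


print(algorithms_for("Box"))            # ['PPO', 'SD3', 'TD3']
print(algorithms_for("Discrete"))       # ['C51', 'DQN', 'IQN', 'PPO', 'QRDQN']
print(algorithms_for("MultiDiscrete"))  # []
```

With a Gym environment at hand, the space name could be obtained with `type(env.action_space).__name__`.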

Quick Start

This example shows how to train a PPO agent to play CartPole-v0. You can find the full script in example/cartpole/train_ppo.py.

First, import dependencies

import gym
import unstable_baselines as ub
from unstable_baselines.algo.ppo import PPO

Create environments for training and evaluation

# create environments
env = ub.envs.VecEnv([gym.make('CartPole-v0') for _ in range(10)])
eval_env = gym.make('CartPole-v0')

Create a PPO model and train it

model = PPO(
    env,
    learning_rate=1e-3,
    gamma=0.8,
    batch_size=128,
    n_steps=500
).learn(  # train for 20000 steps
    20000,
    verbose=1
)
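
For intuition about the `gamma` argument: in standard RL notation it is the reward discount factor, and this sketch assumes it has the same meaning here. CartPole gives a reward of 1 per step, so the discounted return shows how far ahead the agent effectively looks. The function `discounted_return` is a hypothetical helper for illustration, not part of the library.

```python
def discounted_return(rewards, gamma):
    """Compute sum_t gamma**t * rewards[t] by accumulating backwards."""
    total = 0.0
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


# CartPole gives a reward of 1 per step; CartPole-v0 caps episodes at 200 steps.
rewards = [1.0] * 200
print(discounted_return(rewards, gamma=0.8))   # approaches 1 / (1 - 0.8) = 5
print(discounted_return(rewards, gamma=0.99))  # much larger, roughly 86.6
```

With gamma=0.8, rewards more than a couple of dozen steps ahead contribute almost nothing, which keeps the value targets small and short-sighted; a gamma closer to 1 weights the whole episode.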

Save and load the trained model

model.save('./my_ppo_model')
model = PPO.load('./my_ppo_model')

Evaluate the training results

model.eval(eval_env, 20, 200, render=True)
# don't forget to close the environments!
env.close()
eval_env.close()
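
Conceptually, an evaluation call like the one above can be thought of as a loop over episodes with a per-episode step cap. The sketch below is not the library's implementation. It assumes the two positional arguments are the number of episodes and the maximum steps per episode, uses a hypothetical `policy` callable in place of the trained model, and a toy `CountdownEnv` that mimics the classic Gym `reset()`/`step()` interface so the block runs without Gym installed.

```python
class CountdownEnv:
    """Toy stand-in for a Gym env: the episode ends after `length` steps."""

    def __init__(self, length=5):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # observation

    def step(self, action):
        self.t += 1
        done = self.t >= self.length
        return self.t, 1.0, done, {}  # obs, reward, done, info


def evaluate(env, policy, n_episodes, max_steps):
    """Run `policy` for `n_episodes` and return each episode's total reward."""
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        total = 0.0
        for _ in range(max_steps):
            obs, reward, done, _info = env.step(policy(obs))
            total += reward
            if done:
                break
        returns.append(total)
    return returns


returns = evaluate(CountdownEnv(length=5), policy=lambda obs: 0,
                   n_episodes=3, max_steps=200)
print(sum(returns) / len(returns))  # mean episode return
```

The step cap matters for environments that never signal `done` on their own: without it, a single episode could run forever.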

More examples:

Update Logs

  • 2021.05.22: Add benchmarks
  • 2021.04.27: Update to framework v2: supports saving/loading the best-performing checkpoints.
