
Baseline implementation of MuZero agent



MuZero General

A commented and documented implementation of MuZero based on the Google DeepMind paper (Nov 2019) and the associated pseudocode. It is designed to be easily adaptable to any game or reinforcement learning environment (like Gym). You only need to add a game file with the hyperparameters and the game class. Please refer to the documentation and the example.

MuZero is a state-of-the-art RL algorithm for board games (Chess, Go, ...) and Atari games. It is the successor to AlphaZero, but it works without any knowledge of the environment's underlying dynamics. MuZero learns a model of the environment and uses an internal representation that contains only the information useful for predicting the reward, value, policy and transitions. MuZero is also close to Value Prediction Networks. See "How it works".
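To make the idea of planning with a learned model more concrete, here is a toy sketch of the data flow. In the MuZero paper the model consists of a representation function h, a dynamics function g and a prediction function f. The plain-Python stand-ins below are hypothetical toy rules, not this library's PyTorch networks; only the flow is meant to match: the real observation is encoded once, and every later step is imagined with the dynamics function instead of calling the environment.

```python
from typing import List, Optional, Tuple

Hidden = List[float]
Step = Tuple[Optional[float], List[float], float]  # (reward, policy, value)


def representation(observation: List[float]) -> Hidden:
    # h: observation -> hidden state (toy rule: scale the input)
    return [0.5 * x for x in observation]


def dynamics(hidden: Hidden, action: int) -> Tuple[Hidden, float]:
    # g: (hidden state, action) -> (next hidden state, predicted reward)
    # Toy rule: bump the coordinate of the chosen action, reward = sum
    next_hidden = [x + (1.0 if i == action else 0.0) for i, x in enumerate(hidden)]
    return next_hidden, float(sum(next_hidden))


def prediction(hidden: Hidden) -> Tuple[List[float], float]:
    # f: hidden state -> (policy over actions, value estimate)
    total = sum(abs(x) for x in hidden) or 1.0
    policy = [abs(x) / total for x in hidden]
    value = sum(hidden) / len(hidden)
    return policy, value


def unroll(observation: List[float], actions: List[int]) -> List[Step]:
    """Imagine a trajectory purely in latent space; no environment calls."""
    hidden = representation(observation)
    policy, value = prediction(hidden)
    trajectory: List[Step] = [(None, policy, value)]  # no reward at the root
    for action in actions:
        hidden, reward = dynamics(hidden, action)
        policy, value = prediction(hidden)
        trajectory.append((reward, policy, value))
    return trajectory
```

For example, `unroll([1.0, 0.0, 2.0, 0.0], actions=[0, 1])` returns three (reward, policy, value) entries: the root plus one per imagined action. In MuZero, the search uses exactly this kind of imagined rollout, which is why the model only needs to predict reward, value and policy rather than reconstruct observations.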

Disclaimer

This repository is a fork of the base MuZero implementation. The main goal of the fork is to allow greater customization and simple usage as a library, more similar to OpenAI Stable Baselines.

Getting started

Installation

pip install muzero-baseline

Prepare game and configuration

import gym

from muzero_baseline.games.abstract_game import AbstractGame

# Create config for agent and network

class MuZeroConfig:
    def __init__(self):
        self.seed = 0  # Seed for numpy, torch and the game
        self.max_num_gpus = None  # Maximum number of GPUs to use. A single GPU (set it to 1) is usually faster if it has enough memory. None uses all available GPUs

        ### Game
        self.observation_shape = (1, 1, 4)  # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, reshape it to (1, 1, length of array)
        self.action_space = list(range(2))  # Fixed list of all possible actions. You should only edit the length
        self.players = list(range(1))  # List of players. You should only edit the length
        self.stacked_observations = 0  # Number of previous observations and previous actions to add to the current observation

        # ...

class Game(AbstractGame):
    """
    Game wrapper.
    """

    def __init__(self, seed=None):
        self.env = gym.make("CartPole-v1")

        if seed is not None:
            # env.seed() belongs to the older gym API; newer releases seed through env.reset(seed=...)
            self.env.seed(seed)

    # ...
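The `# ...` above elides the rest of the game class. In the upstream muzero-general project, a game exposes methods such as `step` (returning observation, reward and done), `legal_actions`, `to_play`, `reset` and `render`. The standalone sketch below assumes that shape; it does not import the library, so `muzero_baseline`'s actual `AbstractGame` methods may differ. `ToyLineGame` is a hypothetical game chosen to match the config above: observations of shape (1, 1, 4), two actions and a single player. It also shows the reshape of a 1D array to (1, 1, length) that the config comment asks for.

```python
from typing import List, Tuple

Observation = List[List[List[float]]]  # (channel, height, width) = (1, 1, 4)


class ToyLineGame:
    """Hypothetical single-player game: walk right along 4 cells to the goal."""

    def __init__(self, seed=None):
        self.size = 4
        self.position = 0
        self.seed = seed  # unused here; a real game would seed its environment

    def _observation(self) -> Observation:
        one_hot = [1.0 if i == self.position else 0.0 for i in range(self.size)]
        # Reshape the 1D array to (1, 1, length), as the config comment requires
        return [[one_hot]]

    def step(self, action: int) -> Tuple[Observation, float, bool]:
        if action not in self.legal_actions():
            raise ValueError(f"illegal action {action}")
        # Action 0 moves left, action 1 moves right, clamped to the board
        delta = 1 if action == 1 else -1
        self.position = min(max(self.position + delta, 0), self.size - 1)
        done = self.position == self.size - 1
        reward = 1.0 if done else 0.0
        return self._observation(), reward, done

    def legal_actions(self) -> List[int]:
        return list(range(2))  # must match config.action_space

    def to_play(self) -> int:
        return 0  # single player, matching config.players

    def reset(self) -> Observation:
        self.position = 0
        return self._observation()

    def render(self) -> None:
        print("".join("A" if i == self.position else "." for i in range(self.size)))
```

Keeping `legal_actions` consistent with `action_space` and the observation consistent with `observation_shape` is the main thing to check when writing a new game file.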

More examples of configs and games can be found in the games folder; you can adapt them to your needs.

More information is also available in the wiki.
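When adapting an existing config, a few invariants follow from the comments in `MuZeroConfig`: `observation_shape` must be 3D, and `action_space` and `players` are `list(range(n))` where only the length should change. The sketch below overrides a base config in a subclass and checks those invariants. `CartPoleLikeConfig`, `GridConfig` (with made-up values) and `check_config` are hypothetical helpers, not part of `muzero_baseline`.

```python
class CartPoleLikeConfig:
    # Hypothetical minimal stand-in for the MuZeroConfig shown above
    def __init__(self):
        self.seed = 0
        self.observation_shape = (1, 1, 4)
        self.action_space = list(range(2))
        self.players = list(range(1))
        self.stacked_observations = 0


class GridConfig(CartPoleLikeConfig):
    # Adapting an existing config: override only what differs (made-up values)
    def __init__(self):
        super().__init__()
        self.observation_shape = (3, 3, 3)
        self.action_space = list(range(4))


def check_config(config) -> None:
    """Hypothetical sanity checks derived from the comments in MuZeroConfig."""
    if len(config.observation_shape) != 3:
        raise ValueError("observation_shape must be 3D (channel, height, width)")
    for name in ("action_space", "players"):
        values = getattr(config, name)
        if values != list(range(len(values))):
            raise ValueError(f"{name} must be list(range(n)); only edit the length")
    if config.stacked_observations < 0:
        raise ValueError("stacked_observations must be >= 0")
```

Subclassing keeps every inherited default in one place, so a new game only states what differs from the example it started from.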

Initialize MuZero instance

from muzero_baseline.muzero import MuZero

# Initialize config
config = MuZeroConfig()
# Pass the game class, not an instance: a Game object is initialized in each thread separately
mz = MuZero(Game, config)
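The comment above explains why `MuZero` receives the game class rather than an instance: each worker builds its own game object, so environment state is never shared between workers. The sketch below illustrates that pattern with Python's `threading` module as a stand-in; the library itself distributes work with Ray (see Features). `CounterGame`, `run_workers` and the per-worker seed offset `base_seed + worker_id` are hypothetical, not the library's actual scheme.

```python
import threading
from typing import Callable, Dict


class CounterGame:
    """Hypothetical game with mutable state that must not be shared."""

    def __init__(self, seed=None):
        self.seed = seed
        self.steps = 0

    def step(self, action: int) -> None:
        self.steps += 1


def run_workers(game_factory: Callable[..., CounterGame], base_seed: int,
                num_workers: int, steps: int) -> Dict[int, CounterGame]:
    games: Dict[int, CounterGame] = {}
    lock = threading.Lock()

    def worker(worker_id: int) -> None:
        # Each worker builds its own instance from the class it was given
        game = game_factory(seed=base_seed + worker_id)
        for _ in range(steps):
            game.step(0)
        with lock:  # guard the shared result dict, not the games
            games[worker_id] = game

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return games
```

Passing a single shared instance instead would let workers interleave `step` calls on the same environment, which is the failure mode a class (or factory) avoids.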

Train agent

mz.train()

During training, the agent saves metrics and checkpoints of the network and the replay buffer in the results folder.

Metrics can be accessed through TensorBoard (the commands below are Jupyter notebook magics):

%load_ext tensorboard
%tensorboard --logdir ./results 

Test agent

mz.test()

To test in the same thread:

mz.test_direct()

Load existing model

mz.load_model(
    checkpoint_path='results/2021-07-15--16-06-15/model.checkpoint',
    replay_buffer_path='results/2021-07-15--16-06-15/replay_buffer.pkl',
)
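Run folders in `results` appear to be named with a timestamp, as in `results/2021-07-15--16-06-15`. Assuming that naming pattern and the `model.checkpoint` / `replay_buffer.pkl` file names from the example, a small hypothetical helper can locate the most recent run to pass to `load_model`. It only builds the paths; it does not check that the files exist.

```python
import os
from datetime import datetime
from typing import Optional, Tuple

# Assumed from the example path 'results/2021-07-15--16-06-15'
RUN_NAME_FORMAT = "%Y-%m-%d--%H-%M-%S"


def latest_run_paths(results_dir: str = "results") -> Optional[Tuple[str, str]]:
    """Return (checkpoint_path, replay_buffer_path) for the newest run, or None."""
    runs = []
    for name in os.listdir(results_dir):
        try:
            stamp = datetime.strptime(name, RUN_NAME_FORMAT)
        except ValueError:
            continue  # skip entries that are not timestamped run folders
        if os.path.isdir(os.path.join(results_dir, name)):
            runs.append((stamp, name))
    if not runs:
        return None
    _, newest = max(runs)  # latest timestamp wins
    run_dir = os.path.join(results_dir, newest)
    return (os.path.join(run_dir, "model.checkpoint"),
            os.path.join(run_dir, "replay_buffer.pkl"))
```

With `mz` created as above, `checkpoint, buffer = latest_run_paths()` followed by `mz.load_model(checkpoint_path=checkpoint, replay_buffer_path=buffer)` would reload the latest run, provided at least one run folder exists (otherwise the helper returns `None`).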

Features

  • Residual Network and Fully connected network in PyTorch
  • Multi-Threaded/Asynchronous/Cluster with Ray
  • Multi-GPU support for training and self-play
  • TensorBoard real-time monitoring
  • Model weights automatically saved at checkpoints
  • Single- and two-player modes
  • Commented and documented
  • Easily adaptable for new games
  • Examples of board games, Gym and Atari games (See list of implemented games)
  • Pretrained weights available
  • Windows support (Experimental / Workaround: Use the notebook in Google Colab)

Further improvements

These improvements are active research; they are personal ideas that go beyond the MuZero paper. We are open to contributions and other ideas.

Demo

All performance metrics are tracked and displayed in real time in TensorBoard:

(Image: CartPole training summary)

Testing Lunar Lander:

(Image: Lunar Lander training preview)

Games already implemented

  • Cartpole (Tested with the fully connected network)
  • Lunar Lander (Tested in deterministic mode with the fully connected network)
  • Gridworld (Tested with the fully connected network)
  • Tic-tac-toe (Tested with the fully connected network and the residual network)
  • Connect4 (Slightly tested with the residual network)
  • Gomoku
  • Twenty-One / Blackjack (Tested with the residual network)
  • Atari Breakout

Tests are done on Ubuntu with 16 GB RAM / Intel i7 / GTX 1050Ti Max-Q. We make sure that each agent makes measurable progress and reaches a level showing that it has learned, but we do not systematically reach human level. For certain environments, we notice a regression after some amount of training. The proposed configurations are certainly not optimal, and we are not currently focusing on hyperparameter optimization. Any help is welcome.

Code structure

(Diagram: code structure)

Network summary:

Authors

Please use this bibtex if you want to cite this repository (master branch) in your publications:

@misc{muzero-general,
  author       = {Werner Duvaud and Aurèle Hainaut},
  title        = {MuZero General: Open Reimplementation of MuZero},
  year         = {2019},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/werner-duvaud/muzero-general}},
}

Getting involved



Download files


Source Distribution

muzero-baseline-0.4.0.tar.gz (45.9 kB)

Uploaded Source

Built Distribution


muzero_baseline-0.4.0-py3-none-any.whl (80.3 kB)

Uploaded Python 3

File details

Details for the file muzero-baseline-0.4.0.tar.gz.

File metadata

  • Download URL: muzero-baseline-0.4.0.tar.gz
  • Upload date:
  • Size: 45.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.1 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.7.10

File hashes

Hashes for muzero-baseline-0.4.0.tar.gz
Algorithm    Hash digest
SHA256       4fe1bbb50665ad71d6ca41b38c9f4f553d4c0d5d2bb041a0cbd5a3ee3fc3f4f4
MD5          5e3efe52a0a5c89fbab9a34c7dcb39c8
BLAKE2b-256  25595913a774cb60ae85e45a5333a0cdf6edec272642e357d82291d13efeded1
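The digests listed here can be checked locally after downloading a file. A minimal sketch using Python's standard `hashlib` module, reading the file in chunks so a large archive does not need to fit in memory; the function names are illustrative, and the expected value is the SHA256 digest listed above for the source distribution.

```python
import hashlib

# SHA256 digest listed above for muzero-baseline-0.4.0.tar.gz
EXPECTED_SDIST_SHA256 = "4fe1bbb50665ad71d6ca41b38c9f4f553d4c0d5d2bb041a0cbd5a3ee3fc3f4f4"


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file incrementally and return the lowercase hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected: str) -> bool:
    # hexdigest() is lowercase, so normalize the expected value
    return sha256_of_file(path) == expected.lower()
```

For example, `verify_sha256("muzero-baseline-0.4.0.tar.gz", EXPECTED_SDIST_SHA256)` should return `True` for an unmodified download. pip can also enforce hashes itself through its hash-checking mode (`--require-hashes`).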


File details

Details for the file muzero_baseline-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: muzero_baseline-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 80.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.1 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.7.10

File hashes

Hashes for muzero_baseline-0.4.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       d5c938d9f388729c561bd61182dc04ae53f18729dbc6f9a0658a06ba1b7c7b31
MD5          b86a00902a5b10fe1e886cf7d8cbbd5e
BLAKE2b-256  b0e431c498b33492fe8b471ccd73a6a97f2c64a3823e281fcc4196ced0890e02

