
A lightweight, standalone, and modular Gumbel MCTS implementation


License: MIT

gumbel-mcts

A lightweight, Numba-accelerated Gumbel MCTS implementation.

Optimized for speed: it runs hundreds of thousands of simulations per second.

Gumbel principle
Improving MuZero using the Gumbel top-k trick, by Xavier O'Keefe

Description

Gumbel sampling has brought tremendous progress to MCTS, but efficient standalone implementations of Gumbel MCTS are missing.
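The core of Gumbel sampling at the root is the Gumbel-Top-k trick: adding independent Gumbel(0, 1) noise to the policy logits and keeping the k largest perturbed scores yields k actions sampled without replacement from softmax(logits). Below is a minimal NumPy sketch of that trick; `gumbel_top_k` is a hypothetical helper for illustration, not part of the gumbel-mcts API.

```python
import numpy as np


def gumbel_top_k(logits, k, rng):
    """Sample k distinct actions without replacement from softmax(logits).

    Gumbel-Top-k trick: add i.i.d. Gumbel(0, 1) noise to every logit and keep
    the k largest perturbed scores. The noise is returned too, because the
    Gumbel MuZero root search reuses it when ranking the sampled actions.
    """
    logits = np.asarray(logits, dtype=np.float64)
    g = rng.gumbel(size=logits.shape)   # one Gumbel(0, 1) draw per action
    scores = logits + g                 # perturbed logits (-inf stays -inf)
    top = np.argsort(-scores)[:k]       # indices of the k largest scores
    return top, g


rng = np.random.default_rng(0)
logits = np.log([0.5, 0.3, 0.15, 0.05])
logits[3] = -np.inf                     # mask an illegal action: never selected
actions, g = gumbel_top_k(logits, k=2, rng=rng)
print(actions)                          # two distinct legal action indices
```

Masking illegal moves with -inf keeps them at the bottom of the ranking, so they can never be sampled.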

We provide three MCTS implementations:

  • puct.py: an efficient implementation of PUCT MCTS. It produces exactly the same output as the reference mcts_v2.py, with a 2-20x speedup on both Mac and NVIDIA GPUs.

  • gumbel_dense.py: an implementation of "Policy improvement by planning with Gumbel" (Danihelka et al., 2022), offering large gains in learning efficiency when the simulation budget is low.

  • gumbel_sparse.py: a sparse implementation of Gumbel MCTS, particularly useful for games with large action spaces (e.g. chess).
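For reference, a common AlphaZero-style form of the PUCT selection rule picks the child maximizing Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a)). The sketch below implements that form in NumPy; the function name and the value of `c_puct` are illustrative assumptions, and the exact variant used by puct.py and mcts_v2.py may differ.

```python
import numpy as np


def puct_select(q, prior, visits, c_puct=1.5):
    """Pick the child maximizing Q(s, a) + U(s, a), AlphaZero-style PUCT.

    q, prior, visits: per-child mean values, prior probabilities and visit
    counts. c_puct is an illustrative exploration constant.
    """
    q = np.asarray(q, dtype=np.float64)
    prior = np.asarray(prior, dtype=np.float64)
    visits = np.asarray(visits, dtype=np.float64)
    # Exploration bonus: grows with the parent's total visits and shrinks
    # as a child accumulates its own visits.
    u = c_puct * prior * np.sqrt(visits.sum()) / (1.0 + visits)
    return int(np.argmax(q + u))


idx = puct_select(q=[0.1, 0.0, 0.0], prior=[0.1, 0.5, 0.4], visits=[5, 1, 0])
print(idx)  # 2: the unvisited child with a high prior gets the largest bonus
```

Because the bonus is driven by the prior, PUCT with a small budget tends to pile visits onto whatever the network already favors, which is the behavior the Gumbel variants are designed to improve on.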

Our Gumbel implementation offers both simulation efficiency and speed.
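In "Policy improvement by planning with Gumbel", simulation efficiency also comes from the training target: the network logits are combined with monotonically transformed, "completed" Q-values, pi' = softmax(logits + sigma(completedQ)) with sigma(q) = (c_visit + max_b N(b)) * c_scale * q. The sketch below follows that form. The function name and constants are illustrative assumptions, and filling unvisited actions with a single root value (instead of the paper's mixed value estimate) is a simplification.

```python
import numpy as np


def improved_policy(logits, q, visits, value, c_visit=50.0, c_scale=0.1):
    """Sketch of a Gumbel-style improved policy target.

    pi' = softmax(logits + sigma(completed_q)), where
    sigma(q) = (c_visit + max_b N(b)) * c_scale * q.
    Unvisited actions have no Q estimate, so their Q is "completed" with
    `value` (here simply a root value estimate, a simplification).
    """
    logits = np.asarray(logits, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    visits = np.asarray(visits, dtype=np.float64)
    completed_q = np.where(visits > 0, q, value)
    sigma = (c_visit + visits.max()) * c_scale * completed_q
    z = logits + sigma
    z = z - z.max()                     # stabilize the softmax
    p = np.exp(z)
    return p / p.sum()


pi = improved_policy(np.zeros(3), q=[1.0, 0.0, 0.0], visits=[4, 4, 0], value=0.0)
print(pi.round(3))  # most mass moves to action 0, the best-valued visited action
```

Every action keeps a defined target even if it was never visited, which is what lets the policy learn from very small simulation budgets.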

See gumbel-mcts-benchmark for the full benchmarks and validation against a gold-standard MCTS.

Usage

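import numpy as np

# TicTacToeLogic, TinyModel and GumbelSparse come from gumbel-mcts and its
# examples; their import paths are not shown on this page.

def play_game():
    logic = TicTacToeLogic()
    model = TinyModel()
    model.eval()

    board = np.zeros((3, 3), dtype=np.int8)
    player = 1  # 1 plays "X", 2 plays "O", 0 is an empty cell
    symbols = {0: ".", 1: "X", 2: "O"}

    while True:
        # Fresh single-game tree rooted at the current position
        tree = GumbelSparse(n_games=1, max_nodes=500, device="cpu", logic=logic)
        tree.initialize_roots([0], board.ravel()[None], np.array([player]))
        # Run 50 simulations for game 0 and take its selected move
        move = tree.run_simulation_batch(model, [0], num_simulations=50)
        action = move[0]

        _, winner, done, board = logic.fast_step(board, action, player)

        for row in np.asarray(board).reshape(3, 3):
            print(" ".join(symbols[int(cell)] for cell in row))
        print()

        if done:
            return winner
        player = 3 - player  # assumed: alternate between players 1 and 2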

Illustration

With a random model, Gumbel wins at low simulation budgets, but PUCT catches up as the budget grows. As soon as the model is better than random, Gumbel wins.

[Figures: PUCT vs Gumbel]

Gumbel MCTS makes much better use of its simulation budget. With only 8 simulations on Gomoku, Gumbel finds the strategic moves, while PUCT concentrates its visit counts in the wrong places.
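To make the 8-simulation setting concrete: "Policy improvement by planning with Gumbel" allocates the root budget with Sequential Halving over the Gumbel-sampled actions. Below is a minimal sketch of that budget split and of the halving step, assuming 4 sampled actions. The helper names are hypothetical, and gumbel_dense.py may split its budget differently (for example in how leftover simulations are spent).

```python
import math

import numpy as np


def sequential_halving_schedule(n_sims, m):
    """Return [(num_remaining_actions, visits_per_action), ...] per phase.

    Budget split of Sequential Halving at the root: ceil(log2 m) phases,
    each remaining action gets floor(n_sims / (phases * remaining)) visits
    (at least 1), then the better half (rounded up) survives.
    """
    phases = max(1, math.ceil(math.log2(m)))
    schedule, remaining = [], m
    for _ in range(phases):
        per_action = max(1, n_sims // (phases * remaining))
        schedule.append((remaining, per_action))
        remaining = (remaining + 1) // 2
    return schedule


def halve(actions, scores):
    """Keep the better half (rounded up) of `actions`, ranked by `scores`.

    In Gumbel MuZero the score of each action is g + logits + sigma(q).
    """
    order = np.argsort(-np.asarray(scores, dtype=np.float64))
    keep = (len(actions) + 1) // 2
    return [actions[i] for i in order[:keep]]


for n_actions, visits in sequential_halving_schedule(n_sims=8, m=4):
    print(f"{n_actions} actions x {visits} visit(s)")
# 4 actions x 1 visit(s)
# 2 actions x 2 visit(s)
```

Every sampled action gets at least one visit before any are discarded, and the remaining budget concentrates on the most promising candidates instead of following the prior alone.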

[Figure: Gomoku 9x9 heatmap, PUCT vs Gumbel Dense]



Download files


Source Distribution

gumbel_mcts-0.1.1.tar.gz (32.4 kB)

Built Distribution


gumbel_mcts-0.1.1-py3-none-any.whl (30.7 kB)

File details

Details for the file gumbel_mcts-0.1.1.tar.gz.

File metadata

  • Download URL: gumbel_mcts-0.1.1.tar.gz
  • Size: 32.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for gumbel_mcts-0.1.1.tar.gz
Algorithm Hash digest
SHA256 cb34ec32fd39c9247795b28a9bebe617320bb354efc837cf5493ba865041a9ec
MD5 2f83c090885dfb519a774d8ed6fd3abb
BLAKE2b-256 118d9946e7d79d5d7bc747e0b312155f2406706e97b946911f02c6fba8337962


File details

Details for the file gumbel_mcts-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: gumbel_mcts-0.1.1-py3-none-any.whl
  • Size: 30.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for gumbel_mcts-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 8059e0bc08640a4de4905667dbf2a62665bd9d6a95d6614fbf4363c857f4c49b
MD5 20d7fee77c90596cf41a23eab7faffb5
BLAKE2b-256 42a54f3aeebbebe934555e47671418dfc6f96aadb3e39bda9ed15ab0d9fbce78
