A lightweight, standalone, and modular Gumbel MCTS implementation
gumbel-mcts
A lightweight and numba-accelerated Gumbel MCTS implementation.
Optimized for speed: it runs hundreds of thousands of simulations per second.
Improving MuZero using the Gumbel top-k trick, by Xavier O'Keefe
Description
Gumbel sampling has substantially improved MCTS, especially at low simulation budgets, but efficient standalone implementations of Gumbel MCTS have been missing.
We provide three MCTS implementations:
- puct.py: an efficient implementation of PUCT MCTS. It produces exactly the same output as a reference implementation (mcts_v2.py) but runs 2-20x faster on both Mac and NVIDIA GPUs.
- gumbel_dense.py: an implementation of "Policy improvement by planning with Gumbel" (Danihelka et al., 2022), which offers substantially better learning efficiency when the simulation budget is low.
- gumbel_sparse.py: a sparse implementation of Gumbel MCTS, particularly useful for games with large action spaces (e.g. chess).
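The root decision rule behind the Gumbel implementations can be sketched in a few lines of NumPy. This is a minimal sketch based on the paper, not the package's code: the function names are hypothetical, Q-values are assumed to be normalized to [0, 1], and c_visit=50, c_scale=1.0 are assumed defaults that the package may set differently. Gumbel-Top-k picks which actions to search; after the search, the candidate maximizing g(a) + logits(a) + sigma(q(a)) is played.

```python
import numpy as np


def sigma(q, max_visit, c_visit=50.0, c_scale=1.0):
    # Monotone transform of (normalized) Q-values, as described in the paper:
    # sigma(q) = (c_visit + max_b N(b)) * c_scale * q. Constants are assumptions.
    return (c_visit + max_visit) * c_scale * q


def gumbel_top_k_candidates(logits, k, rng):
    # Gumbel-Top-k trick: perturb logits with i.i.d. Gumbel(0, 1) noise and keep
    # the k highest-scoring actions. This samples k actions without replacement
    # from softmax(logits).
    g = rng.gumbel(size=logits.shape)
    candidates = np.argsort(-(g + logits))[:k]
    return g, candidates


def select_root_action(logits, q, visits, candidates, g):
    # Reuse the same Gumbel noise and pick the candidate maximizing
    # g(a) + logits(a) + sigma(q(a)).
    scores = g[candidates] + logits[candidates] + sigma(q[candidates], visits.max())
    return int(candidates[np.argmax(scores)])


rng = np.random.default_rng(0)
logits = np.log(np.array([0.1, 0.6, 0.2, 0.1]))
g, candidates = gumbel_top_k_candidates(logits, k=2, rng=rng)
q = np.array([0.2, 0.4, 0.9, 0.1])        # hypothetical search results
visits = np.array([0, 3, 3, 0])           # hypothetical visit counts
print(candidates, select_root_action(logits, q, visits, candidates, g))
```

Because the Gumbel noise is sampled once and reused for the final decision, the played action is an exact sample from the policy when the Q-values are uninformative (constant), and it shifts toward higher-value candidates as the search gathers evidence.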
Our Gumbel implementation offers both simulation efficiency and speed.
See gumbel-mcts-benchmark for the full benchmarks and validation against a gold-standard MCTS implementation.
Usage
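```python
import numpy as np

# GumbelSparse, TicTacToeLogic and TinyModel come from the package and its
# examples; their import paths are not shown here.

def play_game():
    logic = TicTacToeLogic()
    model = TinyModel()
    model.eval()
    board = np.zeros((3, 3), dtype=np.int8)
    player = 1
    symbols = {0: ".", 1: "X", 2: "O"}
    while True:
        # Build a fresh search tree for the current position (a single game, index 0).
        tree = GumbelSparse(n_games=1, max_nodes=500, device="cpu", logic=logic)
        tree.initialize_roots([0], board.ravel()[None], np.array([player]))
        # Run 50 simulations and take the action chosen for game 0.
        move = tree.run_simulation_batch(model, [0], num_simulations=50)
        action = move[0]
        _, winner, done, board = logic.fast_step(board, action, player)
        # Print the board using the symbol table above.
        rows = np.asarray(board).reshape(3, 3)
        print("\n".join(" ".join(symbols[int(c)] for c in row) for row in rows) + "\n")
        if done:
            return winner
        player = 3 - player  # alternate between player 1 (X) and player 2 (O)
```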
Illustration
With a random model, Gumbel wins at low simulation budgets, but PUCT catches up as the budget grows. As soon as the model is better than random, Gumbel wins.
Gumbel MCTS makes much better use of its simulation budget. With 8 simulations on Gomoku, Gumbel finds the strategic moves, while PUCT concentrates its visit counts in the wrong places.
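One mechanism behind this, in the paper the Gumbel implementations follow, is the Sequential Halving schedule at the root: instead of letting visits pile up on one action, the budget is split evenly across phases and the candidate set is halved after each phase. The sketch below computes that schedule for a small budget. `sequential_halving_schedule` is a hypothetical helper, it assumes the number of candidates m is at most the budget, and the package's exact rounding may differ.

```python
import math


def sequential_halving_schedule(n_sims, m):
    """Hypothetical helper: list of (actions_remaining, visits_per_action) per phase.

    Follows the root schedule described in "Policy improvement by planning
    with Gumbel": ceil(log2(m)) phases; in each phase every remaining action
    receives floor(n_sims / (ceil(log2(m)) * remaining)) visits (at least 1,
    an assumption for tiny budgets), then the lower-scoring half is discarded.
    """
    phases = max(1, math.ceil(math.log2(m)))
    schedule = []
    remaining = m
    for _ in range(phases):
        per_action = max(1, n_sims // (phases * remaining))
        schedule.append((remaining, per_action))
        remaining = max(1, remaining // 2)  # keep the better half for the next phase
    return schedule


for actions, visits in sequential_halving_schedule(8, 4):
    print(f"{actions} candidate actions x {visits} visit(s) each")
```

With 8 simulations and 4 candidates there are two phases: each of the 4 actions is visited once, then the 2 survivors are visited twice each, which uses exactly 8 simulations and guarantees every candidate is evaluated at least once.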
Project details
File details
Details for the file gumbel_mcts-0.1.1.tar.gz.
File metadata
- Download URL: gumbel_mcts-0.1.1.tar.gz
- Upload date:
- Size: 32.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cb34ec32fd39c9247795b28a9bebe617320bb354efc837cf5493ba865041a9ec |
| MD5 | 2f83c090885dfb519a774d8ed6fd3abb |
| BLAKE2b-256 | 118d9946e7d79d5d7bc747e0b312155f2406706e97b946911f02c6fba8337962 |
File details
Details for the file gumbel_mcts-0.1.1-py3-none-any.whl.
File metadata
- Download URL: gumbel_mcts-0.1.1-py3-none-any.whl
- Upload date:
- Size: 30.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8059e0bc08640a4de4905667dbf2a62665bd9d6a95d6614fbf4363c857f4c49b |
| MD5 | 20d7fee77c90596cf41a23eab7faffb5 |
| BLAKE2b-256 | 42a54f3aeebbebe934555e47671418dfc6f96aadb3e39bda9ed15ab0d9fbce78 |