GPU-accelerated vectorized mahjong simulators for reinforcement learning
MahJax
A GPU-Accelerated Mahjong Simulator for Reinforcement Learning in JAX
[!NOTE] Japanese Riichi Mahjong is a challenging multi-agent RL environment with imperfect information, stochastic dynamics, more than two players, and high-dimensional observations. MahJax aims to make Mahjong research accessible to the broader RL community. Newcomers should start with our basic introduction and the bilingual visualization.
Overview
- 🚀 Vectorized Environment: Extremely fast (approx. 1.6M steps/sec on 8x A100 GPUs).
- 🎨 Rich Visualization: SVG-based visualization with bilingual support for those unfamiliar with Kanji.
- 🎮 Playable Interface: A web-based UI allows you to play directly against the agents you train.
- 📚 RL Examples: Simple examples for Behavior Cloning + PPO in examples/.
For more details, please refer to the Documentation.
Quick Start
Install
Mahjax is available on PyPI. Please make sure that your Python environment has jax and jaxlib installed, depending on your hardware setup.
pip install mahjax
📣 Mahjax is currently under active development. If you prefer to use the latest codebase with the newest features, please clone the repository and install it in editable mode:
git clone https://github.com/nissymori/mahjax.git
cd mahjax
pip install -e .
[!NOTE] The current API is still provisional and under active development, so it may change in future releases.
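Because the right jax and jaxlib build depends on your hardware, it can help to confirm which backend JAX actually picked up before running the simulator. The short check below uses only standard JAX calls and makes no assumptions about MahJax itself.

```python
import jax

# List the devices JAX can see; a CPU-only install reports only CPU devices.
devices = jax.devices()
backend = jax.default_backend()
print(f"backend={backend}, devices={devices}")
```

If the backend printed here is `cpu` on a GPU machine, the installed jaxlib build most likely lacks accelerator support.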
Basic Usage
The API largely follows the Pgx design.
import jax
import jax.numpy as jnp
import mahjax
batch_size = 10
rng = jax.random.PRNGKey(0)
# Initialize environment
env = mahjax.make(
    "red_mahjong",
    round_mode="single",  # "single", "east" (tonpuusen), or "half" (hanchan)
    observe_type="dict",  # "dict" for Transformer, "2D" for CNN
    order_points=[30, 10, -10, -30],  # Final score bonuses (uma)
)
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))
obs_fn = jax.jit(jax.vmap(env.observe))
# Initialize state
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
state = init_fn(rngs)
# Step
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
action = jnp.zeros((batch_size,), dtype=jnp.int8)
state = step_fn(state, action, rngs)
# Get observation
obs = obs_fn(state)
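The all-zero action above is only a placeholder and need not be legal in every state. Since the API follows Pgx, one common pattern is to mask policy logits with a per-state legal-action mask before sampling. The sketch below uses stand-in arrays so that it runs without the environment. The field name `legal_action_mask` is an assumption carried over from Pgx, and `NUM_ACTIONS` is a hypothetical value, not MahJax's real action-space size.

```python
import jax
import jax.numpy as jnp

NUM_ACTIONS = 8  # hypothetical; not MahJax's real action count
batch_size = 4


def sample_legal_actions(rng, logits, legal_action_mask):
    """Sample one action per batch element, never picking a masked-out action."""
    # Illegal actions get the most negative finite logit, so they are never the argmax
    # of the Gumbel-perturbed logits that categorical sampling uses internally.
    masked_logits = jnp.where(legal_action_mask, logits, jnp.finfo(logits.dtype).min)
    return jax.random.categorical(rng, masked_logits, axis=-1)


rng = jax.random.PRNGKey(0)
rng, logit_rng, sample_rng = jax.random.split(rng, 3)

# Stand-ins for policy outputs and for `state.legal_action_mask` (assumed name).
logits = jax.random.normal(logit_rng, (batch_size, NUM_ACTIONS))
legal_action_mask = jnp.array(
    [
        [1, 0, 0, 1, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ],
    dtype=bool,
)

action = sample_legal_actions(sample_rng, logits, legal_action_mask)
```

With a real environment, the stand-in mask would be replaced by whatever legal-action field the batched state exposes, and the resulting `action` array would be passed to `step_fn`.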
User interface
MahJax includes a web-based UI (FastAPI + JS) that allows you to play against built-in or custom agents directly in your browser.
Running the UI
Install dependencies and start the server:
pip install mahjax
uvicorn mahjax.ui.app:create_app --factory --host 0.0.0.0 --port 8000
Open http://localhost:8000 to start playing. The default agents are the random and rule_based ones.
Playing Against Your Agent
You can register your trained agent to appear in the UI's agent selector.
Create a Python script (e.g., my_app.py) and register your agent's act function:
# my_app.py
from pathlib import Path
from mahjax.ui.app import create_app
app = create_app()
# Load your custom agent
app.state.manager.registry.load_callable_from_path(
    file_path=Path("path/to/my_agent.py"),
    attribute="act",  # The function name to call: act(state, rng) -> action_id
    description="My Custom Agent",
)
Then run uvicorn my_app:app --port 8000.
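The registered attribute is called as act(state, rng) -> action_id. As a rough illustration of what the agent file could contain, the sketch below picks a uniformly random legal action. It assumes that the UI passes a single (unbatched) state exposing a Pgx-style `legal_action_mask` and a JAX PRNG key as `rng`; neither is confirmed here, so check the documentation before relying on it.

```python
# my_agent.py (the file referenced by file_path in the registration call)
import jax
import jax.numpy as jnp


def act(state, rng):
    """Return a uniformly random legal action id for one state."""
    # Assumption: `state.legal_action_mask` is a 1-D boolean array, as in Pgx.
    mask = jnp.asarray(state.legal_action_mask, dtype=jnp.float32)
    probs = mask / mask.sum()
    # Entries with probability zero (illegal actions) are never drawn.
    action = jax.random.choice(rng, mask.shape[0], p=probs)
    return int(action)
```

A trained policy would replace the uniform `probs` with its own masked action distribution while keeping the same function signature.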
Supported Rules
Currently, MahJax supports the following rules:
| Rule | id | Status | Code | Speed (steps/sec) |
|---|---|---|---|---|
| No-Red Mahjong | no_red_mahjong | ✅ | no_red_mahjong | ~1.6M |
| Red Mahjong | red_mahjong | ✅ | red_mahjong | ~9M |
| Selective Rules | - | 🚧 | - | - |
| 3-player Mahjong | - | 🚧 | - | - |
red_mahjong implements standard 4-player riichi mahjong with red fives.
Its rules are designed to follow Tenhou, one of the most widely used online mahjong platforms in Japan, and we validate the implementation against downloaded Tenhou game logs.
For the detailed rule specification, see the official Tenhou rules.
no_red_mahjong implements 4-player riichi mahjong without red fives.
This variant is intentionally simplified for speed, and excludes some rules such as abortive draws (特殊流局), pao, and double ron.
If throughput is your priority, no_red_mahjong is the recommended option.
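To compare throughput on your own hardware, a simple timing loop is usually enough. The sketch below is one way to do it and is not necessarily how the figures in the table were measured. `measure_steps_per_sec` and `toy_step` are hypothetical names, and the toy step function stands in for a vmapped `env.step`, which would also take actions and PRNG keys.

```python
import time

import jax
import jax.numpy as jnp


def measure_steps_per_sec(step_fn, state, batch_size, num_steps=100):
    """Time `num_steps` calls of a jitted, batched step function."""
    # Warm-up call so that compilation time is excluded from the measurement.
    state = jax.block_until_ready(step_fn(state))
    start = time.perf_counter()
    for _ in range(num_steps):
        state = step_fn(state)
    # JAX dispatches work asynchronously; wait for the result before stopping the clock.
    jax.block_until_ready(state)
    elapsed = time.perf_counter() - start
    return batch_size * num_steps / elapsed


# Stand-in for a vmapped env.step, used only to make the sketch runnable.
batch_size = 1024
toy_step = jax.jit(jax.vmap(lambda s: s + 1))
steps_per_sec = measure_steps_per_sec(toy_step, jnp.zeros((batch_size,)), batch_size)
print(f"{steps_per_sec:,.0f} steps/sec")
```

The warm-up call and the final `block_until_ready` matter: without them the timing would include JIT compilation or stop before the device has finished its work.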
You can configure the environment with:
- id: the rule set, such as red_mahjong or no_red_mahjong
- round_mode: single for a single round, east for tonpuusen (East-only), or half for hanchan (East-South)
- observe_type: dict for transformer-style inputs or 2D for CNN-style inputs
- order_points: final placement bonuses (uma), for example [30, 10, -10, -30]
env = mahjax.make(
    id="red_mahjong",
    round_mode="single",
    observe_type="dict",
    order_points=[30, 10, -10, -30],
)
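To make the meaning of order_points concrete, the following worked example maps final scores to placement bonuses. It illustrates the idea of uma only: the helper `placement_bonuses` is hypothetical, the scores are made up, and the tie-break by seat index is an assumption rather than MahJax's documented behavior.

```python
def placement_bonuses(final_scores, order_points):
    """Map each seat's final score to its placement bonus (uma)."""
    # Rank seats from highest to lowest score. Ties go to the lower seat index
    # here, which is an assumption and not MahJax's documented tie-break.
    ranking = sorted(
        range(len(final_scores)),
        key=lambda seat: (-final_scores[seat], seat),
    )
    bonuses = [0] * len(final_scores)
    for place, seat in enumerate(ranking):
        bonuses[seat] = order_points[place]
    return bonuses


final_scores = [32000, 25000, 28000, 15000]  # illustrative scores for seats 0..3
bonuses = placement_bonuses(final_scores, [30, 10, -10, -30])
print(bonuses)  # [30, -10, 10, -30]
```

Seat 0 finishes first and receives +30, seat 2 is second with +10, seat 1 is third with -10, and seat 3 is last with -30.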
[!NOTE] The observation features are not yet finalized (though the current version suffices for RL with BC; see examples/).
See also
JAX-based environments
- Pgx: Board game environments such as Go, Chess, and Shogi.
- Brax: Robotics control.
- Gymnax: Popular small-scale RL environments such as CartPole or bsuite.
- Jumanji: A diverse suite of RL environments (packing, routing, etc.).
- Craftax: A JAX version of Crafter + Nethack.
- JaxMARL: Multi-agent environments such as Hanabi.
- Navix: A JAX version of MiniGrid.
Cite us
Paper coming soon.
Acknowledgement
- sotetsuk: For general advice on the development of mahjax based on his experience developing pgx.
- habara-k: For developing core JAX components such as shanten and Yaku calculation.
- OkanoShinri: For the initial implementation of MahJax and its SVG visualization.
- easonyu0203: For advice on PPO implementation in a multi-player imperfect-information game.