GPU-accelerated vectorized mahjong simulators for reinforcement learning

MahJax

A GPU-Accelerated Japanese Riichi Mahjong Simulator for RL in JAX

Japanese Riichi Mahjong is a complex board game that combines imperfect information, multi-player (>2) competition, stochastic dynamics, and high-dimensional inputs. MahJax is heavily inspired by Pgx, which offers vectorized simulators for a diverse set of board games. While Pgx includes imperfect-information games (such as miniature poker and mahjong), its primary emphasis is on deterministic perfect-information games like Go, Chess, and Shogi. We aim to complement this by offering a full-scale Japanese Riichi Mahjong environment written entirely in JAX.

Overview

  • Vectorized Environment: Fully JIT-compilable and extremely fast (approx. 1.6M steps/sec on 8x A100 GPUs).
  • Beautiful Visualization: Like Pgx, we offer SVG-based game visualization. We also provide an English tile version for those unfamiliar with Chinese characters (Kanji).
  • Playable Interface: A web-based UI allows you to play directly against the agents you train.
  • RL Examples: We provide simple examples for Behavior Cloning and Reinforcement Learning in the examples/ directory.

For more details, please refer to the Documentation (TODO links).
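To check throughput on your own hardware, a jitted and vmapped step function can be timed roughly as follows. This is a minimal sketch, not the procedure behind the figure quoted above: `steps_per_second` is a hypothetical helper, and `toy_step` is a stand-in for a real environment step so the snippet runs without MahJax.

```python
import time

import jax
import jax.numpy as jnp


def steps_per_second(step_fn, state, batch_size, num_iters=100):
    """Roughly estimate environment steps per second for a batched step_fn."""
    # Warm-up call so that JIT compilation time is excluded from the timing.
    state = step_fn(state)
    jax.block_until_ready(state)

    start = time.perf_counter()
    for _ in range(num_iters):
        state = step_fn(state)
    # JAX dispatches asynchronously; wait for the final result before stopping the clock.
    jax.block_until_ready(state)
    elapsed = time.perf_counter() - start
    return batch_size * num_iters / elapsed


# Toy stand-in for a vmapped environment step (assumption: not a MahJax function).
toy_step = jax.jit(jax.vmap(lambda x: x + 1.0))
sps = steps_per_second(toy_step, jnp.zeros((1024,)), batch_size=1024)
print(f"{sps:.0f} steps/sec")
```

The measured number depends heavily on batch size, device count, and how actions are generated, so it is only comparable across runs with the same setup.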

Quick Start

Install

MahJax is available on PyPI. Make sure your Python environment has the jax and jaxlib builds that match your hardware (CPU, GPU, or TPU).

pip install mahjax

Basic Usage

The API largely follows the Pgx design.

import jax
import jax.numpy as jnp
import mahjax

batch_size = 10
rng = jax.random.PRNGKey(0)

# Initialize environment
env = mahjax.make(
    "no_red_mahjong",
    one_round=True,      # True: Single round, False: Hanchan (East-South game)
    observe_type="dict", # "dict" for Transformer, "2D" for CNN
    order_points=[30, 10, -10, -30] # Final score bonuses (uma)
)

init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))
obs_fn = jax.jit(jax.vmap(env.observe))

# Initialize state
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
state = init_fn(rngs)

# Step
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
action = jnp.zeros((batch_size,), dtype=jnp.int8)
state = step_fn(state, action, rngs)

# Get observation
obs = obs_fn(state)
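The step above passes an all-zero action as a placeholder. If the state exposes a Pgx-style boolean `legal_action_mask` of shape (batch, num_actions), which is an assumption based on the Pgx API and is not confirmed here, random legal actions could be sampled as in this minimal sketch. `sample_legal_actions` is a hypothetical helper, and the mask below is toy data.

```python
import jax
import jax.numpy as jnp


def sample_legal_actions(rng, legal_action_mask):
    """Sample one uniformly random legal action per batch element."""
    # Illegal actions get a logit of -inf, so they are never selected.
    logits = jnp.where(legal_action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(rng, logits, axis=-1)


# Toy mask for a batch of 2 states with 4 actions each (not real MahJax data).
mask = jnp.array([
    [True, False, True, False],
    [False, False, False, True],
])
actions = sample_legal_actions(jax.random.PRNGKey(0), mask)
```

With a real batched state, this would be called with the state's mask in place of the toy one, and the result cast to whatever action dtype the environment expects.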

On the rules of Japanese Riichi Mahjong

There are several variants of Japanese Riichi Mahjong. The most significant distinction is the inclusion of "Red 5" tiles (aka-dora).

  • Current Support: Standard 4-player rules without red tiles.
  • Future Plans: We plan to incorporate popular variants, including Red 5 tiles and 3-player Mahjong (Sanma).
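Rule sets also differ in the final placement bonus (uma), which is what the `order_points` argument in Basic Usage configures. The sketch below shows how such a bonus is conventionally assigned by final placement. It is an illustration only: `uma_bonus` is a hypothetical helper, the example scores are made up, and how MahJax handles ties and score units internally is not covered here.

```python
import jax.numpy as jnp


def uma_bonus(final_points, order_points):
    """Return the placement bonus for each seat, given final points per seat."""
    # Seat indices ordered from first place to last place.
    order = jnp.argsort(-final_points)
    # ranks[i] is the placement of seat i (0 = first place).
    ranks = jnp.argsort(order)
    return jnp.asarray(order_points)[ranks]


# Made-up final points for seats 0..3, with no ties.
final_points = jnp.array([28000, 35000, 12000, 25000])
bonus = uma_bonus(final_points, [30, 10, -10, -30])
print(bonus)  # seat 1 finished first, seat 2 finished last
```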

User interface

MahJax includes a web-based UI (FastAPI + JS) that allows you to play against built-in or custom agents directly in your browser.

Running the UI

Install dependencies and start the server:

pip install mahjax
uvicorn mahjax.ui.app:create_app --factory --host 0.0.0.0 --port 8000

Open http://localhost:8000 to start playing. The default agents are a random agent (random) and a rule-based agent (rule_based).

Playing Against Your Agent

You can register your trained agent so that it appears in the UI's agent selector. Create a Python script (e.g., my_app.py) and register your agent's act function:

# my_app.py
from pathlib import Path
from mahjax.ui.app import create_app

app = create_app()

# Load your custom agent
app.state.manager.registry.load_callable_from_path(
    file_path=Path("path/to/my_agent.py"),
    attribute="act", # The function name to call: act(state, rng) -> action_id
    description="My Custom Agent",
)

Then run uvicorn my_app:app --port 8000.
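For reference, the agent file loaded above might look like the following minimal sketch. Only the `act(state, rng) -> action_id` signature comes from the registration comment. The unbatched state, the Pgx-style `legal_action_mask` field, and the return type (a JAX integer scalar) are assumptions, and the random scores are a placeholder for a trained model's logits.

```python
# my_agent.py (hypothetical example)
import jax
import jax.numpy as jnp


def act(state, rng):
    """Pick the highest-scoring legal action for a single (unbatched) state."""
    # Assumption: the state carries a boolean mask over action ids.
    mask = jnp.asarray(state.legal_action_mask, dtype=bool)
    # Placeholder policy: random scores. Replace with your model's logits.
    scores = jax.random.normal(rng, mask.shape)
    # Illegal actions are masked to -inf so argmax can never select them.
    masked_scores = jnp.where(mask, scores, -jnp.inf)
    return jnp.argmax(masked_scores)
```

Masking before the argmax keeps the agent from ever proposing an illegal move, whatever the underlying model outputs.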

See also

JAX-based environments

  • Pgx: Board game environments such as Go, Chess, and Shogi.
  • Brax: Robotics control.
  • Gymnax: Popular small-scale RL environments such as CartPole or bsuite.
  • Jumanji: A diverse suite of RL environments (packing, routing, etc.).
  • Craftax: JAX version of (Crafter + NetHack).
  • JaxMARL: Multi-agent environments such as Hanabi.
  • Navix: JAX-version of MiniGrid.

Cite us

Paper coming soon.

Acknowledgements

  • sotetsuk: For general advice on the development of MahJax, based on his experience developing Pgx.
  • habara-k: For developing core JAX components such as shanten and yaku calculation.
  • OkanoShinri: For the initial implementation of MahJax and its SVG visualization.
  • easonyu0203: For advice on implementing PPO in multi-player imperfect-information games.
