Counterfactual Regret Minimization in JAX

cfrx: Counterfactual Regret Minimization in JAX.

cfrx is an open-source library designed for efficient implementation of counterfactual regret minimization (CFR) algorithms using JAX. It focuses on computational speed and easy parallelization on hardware accelerators like GPUs and TPUs.

Key Features:

  • JIT Compilation for Speed: cfrx uses JAX's just-in-time (JIT) compilation to minimize runtime overhead and maximize computational speed.

  • Hardware Accelerator Support: It supports parallelization on GPUs and TPUs, enabling efficient scaling of computations for large-scale problems.

  • Python/JAX Ease of Use: cfrx provides a Pythonic interface built on JAX, offering simplicity and accessibility compared to traditional C++ implementations or prohibitively slow pure-Python code.
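
To illustrate the kind of computation that benefits from JIT compilation and batching, here is a minimal sketch of regret matching, the strategy update at the core of CFR. It is not taken from cfrx; the function name `regret_matching` and the array layout (one row of cumulative regrets per information state) are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def regret_matching(cumulative_regrets: jnp.ndarray) -> jnp.ndarray:
    """Map cumulative regrets of shape (..., n_actions) to a strategy.

    Hypothetical helper for illustration, not part of the cfrx API.
    """
    # Only actions with positive cumulative regret receive probability mass.
    positive = jnp.maximum(cumulative_regrets, 0.0)
    total = positive.sum(axis=-1, keepdims=True)
    n_actions = cumulative_regrets.shape[-1]
    uniform = jnp.full_like(positive, 1.0 / n_actions)
    # Guard the division, then fall back to uniform where no regret is positive.
    safe_total = jnp.where(total > 0, total, 1.0)
    return jnp.where(total > 0, positive / safe_total, uniform)


# Three information states with two actions each, processed in one call.
regrets = jnp.array([[2.0, -1.0], [0.0, 0.0], [1.0, 3.0]])
strategy = regret_matching(regrets)
print(strategy)
```

Because the function is written with array operations only, the same compiled code handles a single information state or a whole batch, on CPU, GPU, or TPU.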

Installation

pip install cfrx

Getting started

An example notebook is available here.

The snippet below trains a policy with outcome-sampling MCCFR on Kuhn Poker.

import jax

from cfrx.envs.kuhn_poker.env import KuhnPoker
from cfrx.policy import TabularPolicy
from cfrx.trainers.mccfr import MCCFRTrainer

# Game environment.
env = KuhnPoker()

# Tabular policy, indexed by information state.
policy = TabularPolicy(
    n_actions=env.n_actions,
    exploration_factor=0.6,
    info_state_idx_fn=env.info_state_idx,
)

# JAX uses explicit PRNG keys; a fixed seed makes the run reproducible.
random_key = jax.random.PRNGKey(0)

trainer = MCCFRTrainer(env=env, policy=policy)

# Run 100,000 training iterations with a metrics period of 5,000.
training_state = trainer.train(
    random_key=random_key, n_iterations=100_000, metrics_period=5_000
)
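
The `exploration_factor` argument above is not documented here. In standard outcome-sampling MCCFR, a parameter of this kind mixes the current strategy with a uniform distribution when sampling actions. The sketch below shows that mixture under the assumption that cfrx follows the same convention; `sampling_policy` is a hypothetical helper, not a cfrx function.

```python
import jax
import jax.numpy as jnp


def sampling_policy(strategy: jnp.ndarray, exploration_factor: float) -> jnp.ndarray:
    """Mix a strategy with the uniform distribution (epsilon-on-policy sampling).

    Hypothetical helper; assumes exploration_factor plays the role of epsilon.
    """
    n_actions = strategy.shape[-1]
    uniform = jnp.ones_like(strategy) / n_actions
    return exploration_factor * uniform + (1.0 - exploration_factor) * strategy


strategy = jnp.array([0.9, 0.1])
probs = sampling_policy(strategy, exploration_factor=0.6)

# Split the key before each random draw so that keys are never reused.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
action = jax.random.choice(subkey, strategy.shape[-1], p=probs)
print(probs, action)
```

With a factor of 0.6, even an action the current strategy rarely plays is sampled with probability of at least 0.3 in a two-action game, which keeps the importance-sampling weights bounded.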

Implemented and upcoming features

Algorithms

  • MCCFR (outcome-sampling): implemented
  • MCCFR (other variants): not yet implemented
  • Vanilla CFR: not yet implemented
  • Deep CFR: not yet implemented

Metrics

  • Exploitability: implemented
  • Local Best Response: not yet implemented

Environments

  • Kuhn Poker: implemented
  • Leduc Poker: implemented
  • Larger games: not yet implemented
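
For readers unfamiliar with the exploitability metric listed above, the following sketch computes it for a two-player zero-sum matrix game, where it reduces to the average gain each player could obtain by switching to a best response. This is an illustration on a toy game (rock-paper-scissors), not cfrx's implementation, which has to handle sequential games such as Kuhn and Leduc Poker; the function name and signature are assumptions.

```python
import jax.numpy as jnp


def exploitability(payoffs, row_strategy, col_strategy):
    """Exploitability of a strategy profile in a zero-sum matrix game.

    payoffs[i, j] is the row player's payoff; the column player receives
    the negative. Hypothetical helper for illustration only.
    """
    # Best the row player can do against the column player's strategy.
    best_row_value = jnp.max(payoffs @ col_strategy)
    # Best the column player can do against the row player's strategy.
    best_col_value = jnp.max(-(row_strategy @ payoffs))
    # The game is zero-sum, so the profile's own values cancel out.
    return (best_row_value + best_col_value) / 2.0


# Rock-paper-scissors, actions ordered (rock, paper, scissors).
rps = jnp.array([[0.0, -1.0, 1.0], [1.0, 0.0, -1.0], [-1.0, 1.0, 0.0]])
uniform = jnp.ones(3) / 3.0
always_rock = jnp.array([1.0, 0.0, 0.0])

at_equilibrium = exploitability(rps, uniform, uniform)
off_equilibrium = exploitability(rps, always_rock, uniform)
print(at_equilibrium, off_equilibrium)
```

The uniform profile is the equilibrium of rock-paper-scissors, so its exploitability is zero, while always playing rock can be exploited by paper.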

Performance

Below is a small benchmark against open_spiel for MCCFR-outcome-sampling on Kuhn Poker and Leduc Poker. Compared to the Python API of open_spiel, cfrx has faster runtime and demonstrates similar convergence.

[Figure: benchmark of cfrx against open_spiel]
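
When reproducing runtime comparisons of this kind, two JAX behaviours matter: the first call to a jitted function includes compilation time, and calls return before the computation has finished because dispatch is asynchronous. The sketch below shows one way to account for both. The function `step` is a hypothetical stand-in workload, not part of cfrx, and the script is not the benchmark used for the figure.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x: jnp.ndarray) -> jnp.ndarray:
    # Stand-in workload; replace with the computation being measured.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))

# Warm-up call: triggers compilation so it is excluded from the timing.
step(x).block_until_ready()

n_calls = 100
start = time.perf_counter()
for _ in range(n_calls):
    out = step(x)
# Wait for the device to finish before stopping the clock.
out.block_until_ready()
elapsed = time.perf_counter() - start

print(f"{elapsed / n_calls * 1e6:.1f} microseconds per call")
```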

See also

cfrx is heavily inspired by the amazing google-deepmind/open_spiel library, as well as by many projects from the JAX ecosystem, especially sotetsuk/pgx and google-deepmind/mctx.

Contributing

Contributions are welcome; please refer to the contribution guidelines.

