
Multi-Agent Reinforcement Learning with JAX


JaxMARL



Multi-Agent Reinforcement Learning in JAX

JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine.

For more details, take a look at our blog post or our Colab notebook, which walks through the basic usage.

Environments 🌍

Environment Reference README Summary
🔴 MPE Paper Source Communication-oriented tasks in a multi-agent particle world
🍲 Overcooked Paper Source Fully-cooperative human-AI coordination tasks based on the video game of the same name
🥘 OvercookedV2 Paper Source Partially observable, stochastic, fully-cooperative extension of Overcooked
🦾 Multi-Agent Brax Paper Source Continuous multi-agent robotic control based on Brax, analogous to Multi-Agent MuJoCo
🎆 Hanabi Paper Source Fully-cooperative partially-observable multiplayer card game
👾 SMAX Novel Source Simplified cooperative StarCraft micro-management environment
🧮 STORM: Spatial-Temporal Representations of Matrix Games Paper Source Matrix games represented as grid world scenarios
🧭 JaxNav Paper Source 2D geometric navigation for differential drive robots
🪙 Coin Game Paper Source Two-player grid world environment which emulates social dilemmas
💡 Switch Riddle Paper Source Simple cooperative communication game included for debugging

Baseline Algorithms 🦉

We follow CleanRL's philosophy of providing single-file implementations, which can be found in the baselines directory. We use Hydra to manage our config files, with specifics explained in each algorithm's README. Most files include wandb logging code; this is disabled by default but can be enabled in the file's config.

Algorithm Reference README
IPPO Paper Source
MAPPO Paper Source
IQL Paper Source
VDN Paper Source
QMIX Paper Source
TransfQMIX Paper Source
SHAQ Paper Source
PQN-VDN Paper Source

Installation 🧗

Environments - Before installing, ensure you have the correct JAX installation for your hardware accelerator. We have tested up to JAX version 0.4.36. The JaxMARL environments can be installed directly from PyPI:

pip install jaxmarl 
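The README states only the highest JAX version that has been tested. A small check like the following can flag a newer JAX before you rely on it. This is a sketch using the standard library's `importlib.metadata`. The `TESTED_UP_TO` constant mirrors the 0.4.36 figure above. The helper names `parse_version` and `jax_version_status` are hypothetical and are not part of JaxMARL.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Highest JAX version the README reports as tested (0.4.36).
TESTED_UP_TO = (0, 4, 36)


def parse_version(text: str) -> tuple:
    """Turn '0.4.36' (or '0.4.36rc1', '0.4.36.dev1') into a tuple of ints."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def jax_version_status() -> str:
    """Report whether the installed JAX is within the tested range."""
    try:
        installed = version("jax")
    except PackageNotFoundError:
        return "jax is not installed"
    if parse_version(installed) > TESTED_UP_TO:
        return f"jax {installed} is newer than the tested 0.4.36; proceed with care"
    return f"jax {installed} is within the tested range"


if __name__ == "__main__":
    print(jax_version_status())
```

This only checks the version number. To confirm that JAX can see your accelerator, `jax.devices()` lists the devices available to it.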

Algorithms - If you would also like to run the algorithms, install from source as follows:

  1. Clone the repository:
    git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
    
  2. Install requirements:
    pip install -e .[algs]
    export PYTHONPATH=./JaxMARL:$PYTHONPATH
    
  3. For the fastest start, we recommend using our Dockerfile, the usage of which is outlined below.

Development - To run our test suite, clone the repository and install the additional dependencies with: pip install -e .[dev]

Quick Start 🚀

We take inspiration from the PettingZoo and Gymnax interfaces. You can try out training an agent in our Colab notebook. Further introductory scripts can be found here.

Basic JaxMARL API Usage 🖥️

Actions, observations, rewards and done values are passed as dictionaries keyed by agent name, allowing for differing action and observation spaces. The done dictionary contains an additional "__all__" key, specifying whether the episode has ended. We follow a parallel structure, with each agent passing an action at each timestep. For asynchronous games, such as Hanabi, a dummy action is passed for agents not acting at a given timestep.
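In a turn-based game, each step's action dictionary must still contain one entry per agent. As a plain-Python illustration of the dummy-action convention, the sketch below fills in a placeholder for every agent that is not acting. The agent names, `NOOP_ACTION` and `turn_based_actions` are hypothetical. Each real environment (Hanabi included) defines its own dummy action, so check that environment's source rather than assuming index 0.

```python
from typing import Dict, Sequence

# Hypothetical placeholder for "no action". Real environments define their own
# dummy action, so check the environment's source before relying on this value.
NOOP_ACTION = 0


def turn_based_actions(
    agents: Sequence[str],
    acting_agent: str,
    chosen_action: int,
    noop_action: int = NOOP_ACTION,
) -> Dict[str, int]:
    """Build a parallel action dict in which only `acting_agent` really acts.

    Every agent still gets an entry, because the parallel interface expects one
    action per agent per timestep; the other agents receive `noop_action`.
    """
    if acting_agent not in agents:
        raise ValueError(f"unknown agent: {acting_agent!r}")
    return {
        agent: chosen_action if agent == acting_agent else noop_action
        for agent in agents
    }


actions = turn_based_actions(["agent_0", "agent_1"], acting_agent="agent_1", chosen_action=7)
print(actions)  # {'agent_0': 0, 'agent_1': 7}
```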

import jax
from jaxmarl import make

key = jax.random.PRNGKey(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Initialise environment.
env = make('MPE_simple_world_comm_v3')

# Reset the environment.
obs, state = env.reset(key_reset)

# Sample random actions.
key_act = jax.random.split(key_act, env.num_agents)
actions = {agent: env.action_space(agent).sample(key_act[i]) for i, agent in enumerate(env.agents)}

# Perform the step transition.
obs, state, reward, done, infos = env.step(key_step, state, actions)
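The example above performs a single step. The sketch below runs a full episode, accumulating per-agent rewards from the reward dictionary and stopping once `done["__all__"]` is true. To keep it self-contained, it uses a hypothetical `ToyCountdownEnv` that only mimics the `reset`/`step` call shape shown above. It is not a JaxMARL environment, and its keys are ignored. A seeded `random.Random` stands in for splitting a `jax.random` key.

```python
import random
from typing import Dict, Tuple


class ToyCountdownEnv:
    """Hypothetical stand-in, NOT part of JaxMARL. It only mimics the call
    shape used above: reset(key) -> (obs, state) and
    step(key, state, actions) -> (obs, state, reward, done, infos)."""

    def __init__(self, horizon: int = 5):
        self.agents = ["agent_0", "agent_1"]
        self.num_agents = len(self.agents)
        self.horizon = horizon

    def reset(self, key) -> Tuple[Dict[str, int], int]:
        state = 0  # the toy state is just the number of steps taken
        return {a: state for a in self.agents}, state

    def step(self, key, state, actions):
        state += 1
        obs = {a: state for a in self.agents}
        reward = {a: float(actions[a]) for a in self.agents}
        finished = state >= self.horizon
        done = {a: finished for a in self.agents}
        done["__all__"] = finished  # episode-level flag, as described above
        return obs, state, reward, done, {}


def rollout(env, seed: int = 0):
    """Run one episode with random actions until done["__all__"] is True."""
    rng = random.Random(seed)  # stands in for splitting a jax.random key
    obs, state = env.reset(None)
    returns = {a: 0.0 for a in env.agents}
    steps = 0
    while True:
        actions = {a: rng.randint(0, 1) for a in env.agents}
        obs, state, reward, done, infos = env.step(None, state, actions)
        for a in env.agents:
            returns[a] += reward[a]
        steps += 1
        if done["__all__"]:
            return steps, returns


steps, returns = rollout(ToyCountdownEnv(horizon=5))
print(steps)  # 5
```

A Python `while` loop like this runs eagerly. In jitted JAX training code, the step call is usually wrapped in `jax.lax.scan` over a fixed number of steps instead.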

Dockerfile 🐋

To help get experiments up and running, we include a Dockerfile and its corresponding Makefile. With Docker and the NVIDIA Container Toolkit installed, the container can be built with:

make build

The built container can then be run:

make run

Contributing 🔨

Please contribute! See our contributing guide for how to add an environment or algorithm, or how to submit a bug report. If you're looking for a project, we also have a few suggestions listed under the roadmap :)

Citing JaxMARL 📜

If you use JaxMARL in your work, please cite us as follows:
@inproceedings{
    flair2024jaxmarl,
    title={JaxMARL: Multi-Agent RL Environments and Algorithms in JAX},
    author={Alexander Rutherford and Benjamin Ellis and Matteo Gallici and Jonathan Cook and Andrei Lupu and Gar{\dh}ar Ingvarsson and Timon Willi and Ravi Hammond and Akbir Khan and Christian Schroeder de Witt and Alexandra Souly and Saptarashmi Bandyopadhyay and Mikayel Samvelyan and Minqi Jiang and Robert Tjarko Lange and Shimon Whiteson and Bruno Lacerda and Nick Hawes and Tim Rockt{\"a}schel and Chris Lu and Jakob Nicolaus Foerster},
    booktitle={The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
    year={2024},
}

See Also 🙌

A number of other libraries inspired this work, and we encourage you to take a look!

JAX-native algorithms:

  • Mava: JAX implementations of popular MARL algorithms.
  • PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.

JAX-native environments:

  • Gymnax: Implementations of classic RL tasks including classic control, bsuite and MinAtar.
  • Jumanji: A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • Brax: A fully differentiable physics engine written in JAX that features continuous control tasks.
  • XLand-MiniGrid: Meta-RL gridworld environments inspired by XLand and MiniGrid.
  • Craftax: (Crafter + NetHack) in JAX.



Download files


Source Distribution

jaxmarl-0.1.0.tar.gz (207.2 kB)

Built Distribution

jaxmarl-0.1.0-py3-none-any.whl (243.7 kB)

File details

Details for the file jaxmarl-0.1.0.tar.gz.

File metadata

  • Download URL: jaxmarl-0.1.0.tar.gz
  • Upload date:
  • Size: 207.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for jaxmarl-0.1.0.tar.gz
Algorithm Hash digest
SHA256 0b6ec67ddf2fc3f91451546ca26c223bd91efb12d44355eafcecbb4143943028
MD5 73a4f713f7e435d5881dc03aa73a093a
BLAKE2b-256 9fc3cd9b494bb18a05fee68559b3b5dd9c03734e22e5c839264df19d3bce9664


File details

Details for the file jaxmarl-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxmarl-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 243.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for jaxmarl-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6f5ee922f7647af039a593292467b2335799b0321dbe51c55146e8c00a33d1d8
MD5 7b5c35014ce4650273ae46ef87d8d7d3
BLAKE2b-256 8e5852f32160ab60f730d4bcb05ed8e98a0d386d0679a8d40b939acee62aac63

