
JAX-accelerated meta-reinforcement learning environments inspired by XLand and MiniGrid

Project description

XLand-MiniGrid



Meta-Reinforcement Learning in JAX

🥳 XLand-MiniGrid was accepted to the Intrinsically Motivated Open-ended Learning workshop at NeurIPS 2023. We look forward to seeing everyone at the poster session!

XLand-MiniGrid is a suite of tools, grid-world environments and benchmarks for meta-reinforcement learning research, inspired by the diversity and depth of XLand and the simplicity and minimalism of MiniGrid. Despite these similarities, XLand-MiniGrid is written from scratch in JAX and designed to be highly scalable, democratizing large-scale experimentation with limited resources. Ever wanted to reproduce a DeepMind AdA agent? Now you can, and not in months, but in days!

Features

  • 🔮 System of rules and goals that can be combined in arbitrary ways to produce diverse task distributions
  • 🔧 Simple to extend and modify, comes with example environments ported from the original MiniGrid
  • 🪄 Fully compatible with all JAX transformations, can run on CPU, GPU and TPU
  • 📈 Easily scales to $2^{16}$ parallel environments and millions of steps per second on a single GPU
  • 🔥 Multi-GPU PPO baselines in the PureJaxRL style, which can reach 1 trillion environment steps in under two days
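As a rough sanity check of what the last two bullets imply, the arithmetic below (plain Python, no library calls) computes the aggregate throughput needed to reach 1 trillion steps in two days. The per-environment figure additionally assumes, purely for illustration, that all of that throughput comes from $2^{16}$ environments; the actual multi-GPU baseline setup may use a different number.

```python
# Back-of-the-envelope check of the throughput implied by the feature list.
SECONDS_PER_DAY = 24 * 60 * 60

total_steps = 10**12      # "1 trillion environment steps"
wall_clock_days = 2       # "under two days" (used as an upper bound)
num_envs = 2**16          # illustrative assumption: all steps come from 2^16 envs

# Minimum aggregate throughput needed to finish within two days.
required_sps = total_steps / (wall_clock_days * SECONDS_PER_DAY)

# Average steps per second each environment would have to contribute.
per_env_sps = required_sps / num_envs

print(f"required throughput: {required_sps:,.0f} steps/s")
print(f"per environment:     {per_env_sps:,.1f} steps/s")
```

In other words, the trillion-step claim corresponds to roughly 5.8 million steps per second sustained over the whole run.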

How cool is that? For more details, take a look at the technical paper (soon) or the examples, which walk you through the basics and through training your own adaptive agents in minutes!


Installation 🎁

⚠️ XLand-MiniGrid is currently in alpha stage, so expect breaking changes! ⚠️

The latest release of XLand-MiniGrid can be installed directly from PyPI:

pip install xminigrid
# or, from github directly
pip install "xminigrid @ git+https://github.com/corl-team/xland-minigrid.git"

Alternatively, if you want to install the latest development version from GitHub and run the provided algorithms or scripts, install from source as follows:

git clone git@github.com:corl-team/xland-minigrid.git
cd xland-minigrid
# additional dependencies for baselines
pip install -e ".[dev,benchmark]"

Note that the installation of JAX may differ depending on your hardware accelerator! We advise users to explicitly install the correct JAX version (see the official installation guide).

Basic Usage 🕹️

Most users who are familiar with other popular JAX-based environments (such as gymnax or jumanji) will find that the interface is very similar. At a high level, the current API combines the dm_env and gymnax interfaces.

import jax
import xminigrid

key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)

# to list available benchmarks: xminigrid.registered_benchmarks()
benchmark = xminigrid.load_benchmark(name="trivial-1m")
# choosing ruleset, see section on rules and goals
ruleset = benchmark.sample_ruleset(ruleset_key)

# to list available environments: xminigrid.registered_environments()
env, env_params = xminigrid.make("XLand-MiniGrid-R9-25x25")
env_params = env_params.replace(ruleset=ruleset)

# fully jit-compatible step and reset methods
timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)

# optionally render the state
env.render(env_params, timestep)

Similar to Gymnasium or Jumanji, users can register new environment variants with register and then conveniently create them with make. timestep is a dataclass containing step_type, reward, discount and observation, as well as the internal environment state.
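To make the timestep description more concrete, here is a minimal pure-Python stand-in. The field names come from the sentence above and the StepType values follow the dm_env convention (FIRST, MID, LAST); the class itself is hypothetical and is not the library's implementation, which presumably holds JAX arrays rather than Python scalars.

```python
from dataclasses import dataclass, replace
from enum import IntEnum
from typing import Any


class StepType(IntEnum):
    # dm_env convention: FIRST right after reset, MID in between, LAST at the end.
    FIRST = 0
    MID = 1
    LAST = 2


@dataclass(frozen=True)
class TimeStep:
    # Hypothetical stand-in; field names follow the description of timestep.
    state: Any            # internal environment state, needed to keep stepping
    step_type: StepType
    reward: float
    discount: float
    observation: Any

    def first(self) -> bool:
        return self.step_type == StepType.FIRST

    def last(self) -> bool:
        return self.step_type == StepType.LAST


ts = TimeStep(state={"t": 0}, step_type=StepType.FIRST, reward=0.0, discount=1.0, observation=(0, 0))
# On termination the step type becomes LAST and the discount drops to zero.
terminal = replace(ts, step_type=StepType.LAST, reward=1.0, discount=0.0)
print(ts.first(), terminal.last())  # True True
```

Keeping the state inside the same immutable object is what lets a pure step function take one timestep in and return the next one.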

For a more advanced introduction, see the provided walkthrough notebook.

On environment interface

Many new JAX-based environments are currently appearing, each offering its own variant of the API. Initially, we tried to reuse Jumanji, but it turned out that its design is not suitable for meta-learning. The Gymnax design appeared more appropriate, but unfortunately it is not actively supported and often departs from the idea that parameters should be contained only in env_params. Furthermore, splitting timestep into multiple entities seems suboptimal to us, as it complicates many things, such as envpool- or dm_env-style auto reset, where the reset happens on the next step (which requires access to the done flag of the previous step).
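The auto-reset argument can be illustrated with a toy environment. In this sketch (plain Python, hypothetical names), step first checks whether the previous timestep was terminal and, if so, returns a fresh reset instead of advancing. Because done lives inside the single timestep object, step needs no extra bookkeeping; in JAX the branch would typically be written with jax.lax.cond or jnp.where so that it stays jit-compatible.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class TimeStep:
    # Minimal hypothetical timestep: a step counter stands in for the full state.
    t: int
    done: bool


class CountdownEnv:
    """Toy environment whose episodes last exactly `horizon` steps."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon

    def reset(self) -> TimeStep:
        return TimeStep(t=0, done=False)

    def step(self, ts: TimeStep, action: int) -> TimeStep:
        # dm_env/envpool-style auto reset: if the *previous* timestep was terminal,
        # ignore the action and start a new episode instead of advancing.
        # This only works because `done` travels inside the timestep itself.
        if ts.done:
            return self.reset()
        t = ts.t + 1  # the action is irrelevant in this toy environment
        return TimeStep(t=t, done=t >= self.horizon)


env = CountdownEnv(horizon=2)
ts = env.reset()
trace: List[Tuple[int, bool]] = []
for _ in range(5):
    ts = env.step(ts, action=0)
    trace.append((ts.t, ts.done))
print(trace)  # [(1, False), (2, True), (0, False), (1, False), (2, True)]
```

Note that the terminal timestep is returned to the caller first, and the reset only appears on the following call, which is exactly why the previous done flag must be reachable from step.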

Therefore, we decided to make a minimal interface that covers just our needs, without aiming to be generic. The core of our library is interface-independent, and we plan to switch to a new interface when/if a better design becomes available (e.g. when a stable Gymnasium FuncEnv is released).

Rules and Goals 🔮

In XLand-MiniGrid, the system of rules and goals is the cornerstone of emergent complexity and diversity. In the original MiniGrid, some environments have dynamic goals, but the dynamics themselves never change. To train and evaluate highly adaptive agents, we need to be able to change the dynamics in non-trivial ways.

Rules are functions that change the environment state in a deterministic way when their conditions are met. Goals are similar to rules, except that they do not change the state; they only test conditions. Every task is described by a goal, a set of rules and initial objects; we call this combination a ruleset. Currently, we support only one goal per task.

To illustrate, we provide a visualization of a specific ruleset. To solve this task, the agent should pick up the blue pyramid and put it near the purple square, which transforms both objects into a red circle. To complete the goal, the red circle should then be placed near the green circle. However, placing the purple square near the yellow circle makes the task unsolvable in this trial. The initial positions of the objects are randomized on each reset.
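The ruleset described above can be mimicked with a toy model in plain Python. Everything here is a hypothetical sketch rather than the library's rule system: objects are (color, shape) pairs on a grid stored as a dict, "near" means 4-connected adjacency, a successful transformation leaves a single red circle in place of the two inputs, and the dead-end rule for the purple square and yellow circle is modelled, as an assumption, as simply consuming both objects.

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Pos = Tuple[int, int]
Tile = Tuple[str, str]   # (color, shape)
Grid = Dict[Pos, Tile]   # only occupied cells are stored

BLUE_PYRAMID, PURPLE_SQUARE = ("blue", "pyramid"), ("purple", "square")
RED_CIRCLE, GREEN_CIRCLE, YELLOW_CIRCLE = ("red", "circle"), ("green", "circle"), ("yellow", "circle")


def find_near(grid: Grid, a: Tile, b: Tile) -> Optional[Tuple[Pos, Pos]]:
    # Return positions of an `a` tile and an adjacent `b` tile (4-connected), if any.
    for (r, c), tile in grid.items():
        if tile != a:
            continue
        for n in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if grid.get(n) == b:
                return (r, c), n
    return None


def tile_near_rule(a: Tile, b: Tile, product: Optional[Tile]) -> Callable[[Grid], Grid]:
    # A rule deterministically rewrites the state when its condition holds.
    def rule(grid: Grid) -> Grid:
        hit = find_near(grid, a, b)
        if hit is None:
            return grid
        pos_a, pos_b = hit
        new = {p: t for p, t in grid.items() if p not in (pos_a, pos_b)}
        if product is not None:
            new[pos_b] = product
        return new
    return rule


def tile_near_goal(a: Tile, b: Tile) -> Callable[[Grid], bool]:
    # A goal only tests a condition and never changes the state.
    return lambda grid: find_near(grid, a, b) is not None


@dataclass
class Ruleset:
    goal: Callable[[Grid], bool]
    rules: List[Callable[[Grid], Grid]]
    init_tiles: List[Tile]


def reset(rs: Ruleset, size: int, rng: random.Random) -> Grid:
    # The objects are fixed by the ruleset; only their positions are randomized.
    cells = [(r, c) for r in range(size) for c in range(size)]
    return dict(zip(rng.sample(cells, len(rs.init_tiles)), rs.init_tiles))


def apply_rules(grid: Grid, rs: Ruleset) -> Grid:
    for rule in rs.rules:
        grid = rule(grid)
    return grid


ruleset = Ruleset(
    goal=tile_near_goal(RED_CIRCLE, GREEN_CIRCLE),
    rules=[
        tile_near_rule(BLUE_PYRAMID, PURPLE_SQUARE, RED_CIRCLE),
        tile_near_rule(PURPLE_SQUARE, YELLOW_CIRCLE, None),  # dead end (assumed to consume both)
    ],
    init_tiles=[BLUE_PYRAMID, PURPLE_SQUARE, GREEN_CIRCLE, YELLOW_CIRCLE],
)

# Good trial: the pyramid was dropped next to the square, and the green circle
# happens to sit right beside where the red circle appears.
good = apply_rules({(0, 0): BLUE_PYRAMID, (0, 1): PURPLE_SQUARE,
                    (0, 2): GREEN_CIRCLE, (3, 3): YELLOW_CIRCLE}, ruleset)
# Bad trial: the square touched the yellow circle first, so it is gone for good.
bad = apply_rules({(0, 0): BLUE_PYRAMID, (2, 2): PURPLE_SQUARE,
                   (2, 3): YELLOW_CIRCLE, (4, 4): GREEN_CIRCLE}, ruleset)
print(ruleset.goal(good), ruleset.goal(bad))  # True False
```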

For a more advanced introduction, see the corresponding section of the provided walkthrough notebook.

Benchmarks 🎲

While composing rules and goals by hand is flexible, it can quickly become cumbersome. Besides, it is hard to express efficiently in a JAX-compatible way due to the large number of heterogeneous computations.

To avoid significant overhead during training and to facilitate reliable comparisons between agents, we pre-sampled several benchmarks with up to five million unique tasks, following the procedure used to train the DeepMind AdA agent in the original XLand. These benchmarks differ in their generation configs, producing distributions with varying levels of diversity and average task difficulty. They can be used for different purposes: for example, the trivial-1m benchmark can be used to debug your agents, allowing very quick iterations. However, we caution against treating the benchmarks as a progression from simple to complex. They are just different 🤷.
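One way to see why pre-sampling helps is that JAX's jit works on fixed array shapes, while tasks differ in how many rules they contain. A common workaround, sketched below in plain Python, is to pad every ruleset with no-op rules up to a common length so the whole benchmark becomes one dense table and choosing a task reduces to indexing. The four-integer rule encoding and the all-zero "empty" rule are hypothetical choices made for this illustration, not a description of the library's storage format.

```python
from typing import List, Sequence, Tuple

Rule = Tuple[int, int, int, int]  # hypothetical encoding: (rule_type, tile_a, tile_b, product)
EMPTY_RULE: Rule = (0, 0, 0, 0)   # hypothetical no-op rule used only as padding


def pad_rules(rules: Sequence[Rule], max_rules: int) -> List[Rule]:
    # Append no-op rules so every task has exactly `max_rules` rows.
    if len(rules) > max_rules:
        raise ValueError(f"{len(rules)} rules exceed max_rules={max_rules}")
    return list(rules) + [EMPTY_RULE] * (max_rules - len(rules))


def stack_rulesets(rulesets: Sequence[Sequence[Rule]]) -> List[List[Rule]]:
    # Pad to the longest ruleset, giving a dense (num_rulesets, max_rules, 4) table.
    # With JAX arrays, sampling a task would then be a single index into this table.
    max_rules = max(len(r) for r in rulesets)
    return [pad_rules(r, max_rules) for r in rulesets]


table = stack_rulesets([
    [(1, 3, 5, 7)],
    [(1, 2, 4, 6), (1, 4, 8, 9), (2, 1, 1, 0)],
])
print(len(table), [len(rows) for rows in table])  # 2 [3, 3]
```

The cost of this approach is that every task pays for the longest rule list, which is one reason to fix the generation configs ahead of time.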

Pre-sampled benchmarks are hosted on HuggingFace and will be downloaded and cached on the first use:

import jax.random
import xminigrid
from xminigrid.benchmarks import Benchmark

# downloading to path specified by XLAND_MINIGRID_DATA,
# ~/.xland_minigrid by default
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")
# reusing cached on the second use
benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m")

# users can sample or get specific rulesets
benchmark.sample_ruleset(jax.random.PRNGKey(0))
benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1)

# or split them for train & test
train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)
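Conceptually, shuffle followed by split(prop=0.8) partitions ruleset indices into two disjoint sets. The plain-Python sketch below shows that idea under two assumptions: the split is by fraction of rulesets, and a fixed seed (standing in for the JAX key) makes it reproducible. It is not the library's implementation. Shuffling first matters whenever the stored order is not random, since a plain prefix split would otherwise be biased.

```python
import random
from typing import List, Tuple


def shuffle_ids(num_rulesets: int, seed: int) -> List[int]:
    # Deterministic permutation: the same seed always yields the same order.
    ids = list(range(num_rulesets))
    random.Random(seed).shuffle(ids)
    return ids


def split_ids(ids: List[int], prop: float) -> Tuple[List[int], List[int]]:
    # The first `prop` fraction goes to train, the rest to test (disjoint by construction).
    if not 0.0 < prop < 1.0:
        raise ValueError("prop must be in (0, 1)")
    cut = int(len(ids) * prop)
    return ids[:cut], ids[cut:]


train_ids, test_ids = split_ids(shuffle_ids(1000, seed=0), prop=0.8)
print(len(train_ids), len(test_ids))  # 800 200
```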

We also provide the script used to generate these benchmarks, which users can adapt for their own purposes:

python scripts/ruleset_generator.py --help

An in-depth description of all available benchmarks is provided here (soon).

P.S. Be aware that benchmarks can change, as we are currently testing and balancing them!

Environments 🌍

We provide environments from two domains. XLand is our main focus for meta-learning: for this domain we provide a single environment and numerous registered variants with different grid layouts and sizes. All of them can be combined with arbitrary rulesets.

To demonstrate the generality of our library, we also port the majority of non-language-based tasks from the original MiniGrid. Similarly, some of these environments come with multiple registered variants. However, we have no current plans to actively develop and support them (but that may change).

| Name | Domain | Goal |
| --- | --- | --- |
| XLand-MiniGrid | XLand | specified by the provided ruleset |
| MiniGrid-Empty | MiniGrid | go to the green goal |
| MiniGrid-EmptyRandom | MiniGrid | go to the green goal from different starting positions |
| MiniGrid-FourRooms | MiniGrid | go to the green goal, but goal and starting positions are randomized |
| MiniGrid-LockedRoom | MiniGrid | find the key to unlock the door, then go to the green goal |
| MiniGrid-Memory | MiniGrid | remember the initial object and choose it at the end of the corridor |
| MiniGrid-Playground | MiniGrid | goal is not specified |
| MiniGrid-Unlock | MiniGrid | unlock the door with the key |
| MiniGrid-UnlockPickUp | MiniGrid | unlock the door and pick up the object in another room |
| MiniGrid-BlockedUnlockPickUp | MiniGrid | unlock the door blocked by an object and pick up the object in another room |
| MiniGrid-DoorKey | MiniGrid | unlock the door and go to the green goal |

Users can get all registered environments with xminigrid.registered_environments(). We also provide a manual control script to explore the environments easily:

python -m xminigrid.manual_control --env-id="MiniGrid-Empty-8x8"
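Registered variants, such as the different MiniGrid-Empty sizes, are typically one environment factory bound to different parameters. The sketch below illustrates that pattern with a tiny registry in plain Python. The function names mirror register, make and registered_environments from the text, but their signatures and bodies here are hypothetical and are not xminigrid's actual code; the 16x16 id is used only as an example.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: env id -> (factory, default keyword arguments).
_REGISTRY: Dict[str, Tuple[Callable[..., Any], Dict[str, Any]]] = {}


def register(env_id: str, entry_point: Callable[..., Any], **kwargs: Any) -> None:
    if env_id in _REGISTRY:
        raise ValueError(f"{env_id} is already registered")
    _REGISTRY[env_id] = (entry_point, kwargs)


def make(env_id: str, **overrides: Any) -> Any:
    if env_id not in _REGISTRY:
        raise KeyError(f"unknown environment: {env_id}")
    entry_point, kwargs = _REGISTRY[env_id]
    # Call-time overrides win over the registered defaults.
    return entry_point(**{**kwargs, **overrides})


def registered_environments() -> List[str]:
    return sorted(_REGISTRY)


def empty_room(height: int, width: int) -> Dict[str, int]:
    # Stand-in "environment" that only records its layout parameters.
    return {"height": height, "width": width}


register("MiniGrid-Empty-8x8", empty_room, height=8, width=8)
register("MiniGrid-Empty-16x16", empty_room, height=16, width=16)
print(registered_environments())  # ['MiniGrid-Empty-16x16', 'MiniGrid-Empty-8x8']
```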

Baselines 🚀

In addition to the environments, we provide high-quality, almost single-file implementations of recurrent PPO baselines in the style of PureJaxRL. With the help of the magical jax.pmap transformation, they can scale to multiple accelerators, reaching millions of frames per second during training.
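jax.pmap runs a function once per device and maps it over the leading axis of its inputs, so a batch of environments has to be laid out as (num_devices, envs_per_device, ...). The plain-Python helper below mimics that layout and the even-division requirement; the 8-device, 8192-environment numbers are illustrative assumptions, not the baselines' defaults.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_pmap(batch: Sequence[T], num_devices: int) -> List[List[T]]:
    # pmap maps over the leading axis, one slice per device, so a flat batch of
    # environments is regrouped as (num_devices, envs_per_device).
    if len(batch) % num_devices != 0:
        raise ValueError(f"{len(batch)} envs do not split evenly over {num_devices} devices")
    per_device = len(batch) // num_devices
    return [list(batch[i * per_device:(i + 1) * per_device]) for i in range(num_devices)]


num_devices = 8  # illustrative; with JAX this would come from jax.local_device_count()
env_ids = list(range(8192))
shards = shard_for_pmap(env_ids, num_devices)
print(len(shards), len(shards[0]))  # 8 1024
```

With real arrays the same regrouping is a single reshape, and gradients computed on each device are then usually averaged with a collective such as jax.lax.pmean.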

Agents can be trained from the terminal and default arguments can be overwritten from the command line or from the yaml config:

# for meta learning
python training/train_meta_task.py \
    --config-path='some-path/config.yaml' \
    --env_id='XLand-MiniGrid-R1-9x9'

# for minigrid envs
python training/train_single_task.py \
    --config-path='some-path/config.yaml' \
    --env_id='MiniGrid-Empty-8x8'
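The override behaviour described above can be pictured as layered defaults. In the sketch below, which is plain Python with hypothetical field names and default values, dataclass defaults are overridden by values read from a YAML file, and those in turn by command-line flags. That precedence order is an assumption about how such training scripts usually behave, not a statement about these scripts; the YAML values are passed in as an already-parsed dict, and python training/train_meta_task.py --help lists the real options.

```python
from dataclasses import asdict, dataclass, replace
from typing import Any, Dict


@dataclass(frozen=True)
class TrainConfig:
    # Hypothetical subset of fields and defaults, for illustration only.
    env_id: str = "XLand-MiniGrid-R1-9x9"
    num_envs: int = 8192
    lr: float = 1e-3


def resolve_config(yaml_values: Dict[str, Any], cli_values: Dict[str, Any]) -> TrainConfig:
    # Assumed precedence (lowest to highest): dataclass defaults < YAML file < command line.
    cfg = TrainConfig()
    for source in (yaml_values, cli_values):
        unknown = set(source) - set(asdict(cfg))
        if unknown:
            raise ValueError(f"unknown config keys: {sorted(unknown)}")
        cfg = replace(cfg, **source)
    return cfg


cfg = resolve_config(
    yaml_values={"num_envs": 4096, "lr": 3e-4},  # as if parsed from some-path/config.yaml
    cli_values={"env_id": "XLand-MiniGrid-R9-25x25", "lr": 1e-4},
)
print(cfg)
```

Rejecting unknown keys early is a deliberate choice in this sketch: a typo in a YAML file then fails loudly instead of silently training with a default.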

For the source code and the available hyperparameters, see /training or run python training/train_meta_task.py --help. Furthermore, we provide standalone implementations that can be trained in Colab: xland, minigrid.

P.S. Do not expect the provided baselines to solve the hardest environments or benchmarks available. How much fun would that be 🤔? However, we hope that they will help you get started quickly!

Roadmap 🗓️

With the initial release of XLand-MiniGrid, things are just getting started. There is a long way to go in terms of polishing the code, adding new features, and improving the overall user experience. Here is what we plan to improve in forthcoming releases:

  1. Tweaks to the benchmark generation, time-limits
  2. Documentation (in code and as standalone site)
  3. Full type hints coverage, type checking
  4. Tests
  5. More examples and tutorials

After that, we will start thinking about new major features, environments and benchmarks. However, we should perfect the core first.

Contributing 🔨

We welcome anyone interested in helping out! Please take a look at our contribution guide for further instructions and open an issue if something is not clear.

See Also 🔎

A lot of other work is moving in a similar direction, transforming RL through JAX. Many of these projects have inspired us, and we encourage users to check them out as well.

  • Brax - a fully differentiable physics engine used for research and development in robotics.
  • Gymnax - implements classic environments, including classic control, bsuite, MinAtar and simple meta-learning tasks.
  • Jumanji - a diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx - JAX implementations of classic board games, such as Chess, Go and Shogi.
  • JaxMARL - multi-agent RL in JAX with a wide range of commonly used environments.

Let's build together!

Citation 🙏

@inproceedings{
    nikulin2023xlandminigrid,
    title={{XL}and-MiniGrid: Scalable Meta-Reinforcement Learning Environments in {JAX}},
    author={Alexander Nikulin and Vladislav Kurenkov and Ilya Zisman and Viacheslav Sinii and Artem Agarkov and Sergey Kolesnikov},
    booktitle={Intrinsically-Motivated and Open-Ended Learning Workshop, NeurIPS2023},
    year={2023},
    url={https://openreview.net/forum?id=xALDC4aHGz}
}

Download files

Source Distribution

xminigrid-0.4.0.tar.gz (47.9 kB)

Built Distribution

xminigrid-0.4.0-py3-none-any.whl (51.8 kB)

File details

Details for the file xminigrid-0.4.0.tar.gz.

File metadata

  • Download URL: xminigrid-0.4.0.tar.gz
  • Upload date:
  • Size: 47.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.12

File hashes

Hashes for xminigrid-0.4.0.tar.gz
Algorithm Hash digest
SHA256 bc3971577cdabba3282ac585607a50e92a67ec5fe444bfb8e705cf6ee8dfc9c5
MD5 abef89c2c4d19ec6d68a3b007bc3ad16
BLAKE2b-256 7019175938728534cc1bd21e8305f2e722e0f285322918235af0f9020135fef6

File details

Details for the file xminigrid-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: xminigrid-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 51.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.12

File hashes

Hashes for xminigrid-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 87d0b59f694ff4089d093f61af06db123da5fc4a791405a490f26b38e10a5992
MD5 125205ee523061395099eeaa3ef5ad90
BLAKE2b-256 e86d295a166f87d4b6aec741ba37d5f9985989cdb2af84e9a0a540ffb58e6eb8
