JAX-native platform for massively parallel control, system identification, and active learning of uncertain, stochastic systems.

Myriad

Python 3.11+ | License: MIT

[!WARNING] Myriad is in early active development. APIs will change, documentation has gaps, and some features are still taking shape. Contributions, feedback, and ideas are very welcome! Open a discussion or reach out to Robin (robin.henry@eng.ox.ac.uk).

Documentation: docs

At a Glance 🌟

Myriad is a playground to explore RL, traditional control, system identification, and active learning — with a focus on problems where uncertainty, stochasticity, and rare discrete dynamics play a big role and force us to study very large numbers of variants in parallel (think: biology → system = cell, chemistry → system = reactor). 🛝

It's a ready-to-go experimental platform. You can use the already-implemented tasks and algorithms, or implement your own and simply plug them in. Myriad will handle the intricacies of JAX/GPU optimization, training/evaluation loops, hyperparameter tracking, metrics logging, and many more not-so-fun things, freeing time for the more fun science and engineering bits. 👩🏾‍🔬👨🏻‍🔬

Last but not least, it yields results that are reproducible on the same GPU hardware. 🌟

Interested in the story behind Myriad? Read our Motivation & Philosophy.

Key Features

  • ⚡ Massive GPU Parallelism: run algorithms on 1M+ environments simultaneously.

  • 🏎️ JAX JIT Optimization: Myriad is fast, even on CPU.

  • ✅ Reproducible: the same seed + config yields bit-identical results on the same GPU hardware, which is great for science. Minor numerical differences may occur across different GPU models or driver versions due to non-deterministic GPU kernel optimizations.

  • 🎲 Exact Stochastic Simulations: native JAX implementation of the Gillespie Algorithm (aka SSA) for discrete, asynchronous molecular events.

  • ∇ Differentiable "White-Box" Physics: exposes underlying physics, ODEs, and jump processes for gradient-based system ID and active learning.

  • 🛠 Research-Ready: pre-configured with Hydra, Pydantic, and W&B support.
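
To make the Gillespie bullet above concrete, here is a minimal sketch of the direct method for a single birth-death species (constant production, first-order degradation). It is written in plain, dependency-free Python rather than JAX, and it is not Myriad's implementation; the function name `gillespie_birth_death`, the rate constants, and the seeds are all made up for illustration.

```python
import math
import random


def gillespie_birth_death(k_prod, k_deg, n0, t_max, seed):
    """Simulate a birth-death process with the Gillespie direct method.

    Reactions: 0 -> X at rate k_prod, X -> 0 at rate k_deg * n.
    Returns the event times and the molecule count after each event.
    (Hypothetical helper for illustration, not part of Myriad.)
    """
    rng = random.Random(seed)
    t, n = 0.0, n0
    times, counts = [t], [n]
    while True:
        # Propensity of each reaction in the current state.
        a_prod = k_prod
        a_deg = k_deg * n
        a_total = a_prod + a_deg
        if a_total == 0.0:
            break  # no reaction can fire any more
        # Waiting time to the next event is exponential with rate a_total.
        t += -math.log(1.0 - rng.random()) / a_total
        if t > t_max:
            break
        # Pick the reaction that fires, with probability proportional to its propensity.
        if rng.random() * a_total < a_prod:
            n += 1
        else:
            n -= 1
        times.append(t)
        counts.append(n)
    return times, counts


# Independent "cells" differ only in their seed.
trajectories = [gillespie_birth_death(10.0, 0.1, 0, 50.0, seed=s) for s in range(100)]
final_counts = [counts[-1] for _, counts in trajectories]
print(sum(final_counts) / len(final_counts))
```

Because each trajectory depends only on its seed, rerunning with the same seed reproduces it exactly, and the independent cells in the list comprehension are the kind of workload a vectorized simulator can run in parallel. For this model the stationary mean count is `k_prod / k_deg` (100 here), so the printed average should land close to that value.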

Installation 🌱

Requirements: Python 3.11+, JAX 0.7+.

[!IMPORTANT] GPU Support: JAX installation can be hardware-specific. If you encounter issues, we strongly recommend first installing JAX for your CUDA/cuDNN version, then installing Myriad.

# Standard installation
pip install myriad-jax

# With generic GPU dependencies (checks for nvidia-related packages)
pip install "myriad-jax[gpu]"
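
If an install misbehaves, it can help to confirm the requirements listed above are actually met. The following is a small sketch using only the standard library; the helper names `installed_version` and `major_minor` are made up for this example, and the distribution names `jax` and `myriad-jax` are assumed to match the pip names used above.

```python
import re
import sys
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def major_minor(version):
    """Parse the leading 'X.Y' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


# Requirements stated in this README: Python 3.11+ and JAX 0.7+.
python_ok = sys.version_info >= (3, 11)
jax_version = installed_version("jax")
jax_ok = jax_version is not None and major_minor(jax_version) >= (0, 7)

print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {'ok' if python_ok else 'too old'}")
print(f"jax {jax_version}: {'ok' if jax_ok else 'missing or too old'}")
print(f"myriad-jax: {installed_version('myriad-jax')}")
```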

Quickstart 🏁

Myriad is designed to be used programmatically (for research loops) or via CLI (for massive sweeps).

Python API

from myriad import create_config, train_and_evaluate

# Configure a gene expression control experiment across 10k cells
config = create_config(
    env="gene-circuit-v1",      # A stochastic gene circuit (Gillespie)
    agent="dqn",                # The algorithm to use (e.g., 'pid', 'pqn')
    num_envs=10_000,            # 10k parallel simulations (cells)
    scan_autotune=True,         # Automatically optimize GPU training loop parameters
)

# Run the experiment (JIT-compiled & distributed on GPU)
results = train_and_evaluate(config)

# Inspect performance metrics
metrics = results.eval_metrics
print(f"Return (mean +/- std): {metrics.mean_return} +/- {metrics.std_return}")

# Quickly plot convergence curve: TO ADD
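
For readers unsure what a "mean +/- std" summary like the one printed above represents, here is a standalone sketch that aggregates a list of per-episode returns. The numbers are made up for illustration, and this README does not say how Myriad computes `mean_return` and `std_return` (for instance, population versus sample standard deviation), so treat the choice of `pstdev` as an assumption.

```python
import statistics

# Made-up per-episode returns, purely for illustration.
episode_returns = [12.0, 15.5, 9.0, 14.0, 11.5]

mean_return = statistics.fmean(episode_returns)
std_return = statistics.pstdev(episode_returns)  # population standard deviation (assumed)

print(f"Return (mean +/- std): {mean_return:.2f} +/- {std_return:.2f}")
```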

CLI Usage

Leverage Hydra to run massive parameter sweeps or experiments directly from the terminal.

# Train a DQN agent on 50,000 parallel cartpole environments
myriad train env=cartpole-control run.num_envs=50000 agent=dqn
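
One simple way to run a small sweep is to generate one such command per combination of overrides and hand them to a job scheduler. The sketch below only builds the command strings; `build_sweep_commands` is a hypothetical helper, the override keys are the ones shown above, and whether every agent and environment combination is valid is not something this README confirms. Hydra also has a built-in multirun mode, but whether the `myriad` CLI exposes it is not stated here.

```python
import itertools
import shlex


def build_sweep_commands(base, grid):
    """Return one command string per combination of override values in `grid`."""
    keys = list(grid)
    commands = []
    for values in itertools.product(*(grid[k] for k in keys)):
        # Hydra-style overrides take the form key=value.
        overrides = [f"{k}={v}" for k, v in zip(keys, values)]
        commands.append(shlex.join(base + overrides))
    return commands


commands = build_sweep_commands(
    base=["myriad", "train", "env=cartpole-control"],
    grid={"run.num_envs": [10000, 50000], "agent": ["dqn", "pqn"]},
)
for command in commands:
    print(command)
```

Building the commands as data keeps the sweep definition in one place and makes it easy to inspect before anything is launched.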

See the Documentation for further information, including tutorials.

Flagship Environments 🌍

To add.

Contributing 🛠️

Our goal is for Myriad to become a platform that accelerates RL/control research, especially in the life sciences. As such, we'd love to have others contribute!

Please take a look at the contributing guide for instructions on adding new environments or algorithms, and for other ways to contribute.

If in doubt, always feel free to reach out by opening a discussion.

See Also 🔎

Here is a non-exhaustive list of other JAX x RL libraries, some of which inspired the development of Myriad.

Environments:

  • Gymnax: classic environments including classic control, bsuite, MinAtar, and meta RL tasks.
  • Brax: a differentiable physics engine for rigid body control tasks.
  • JaxMARL: multi-agent RL tasks.
  • Jumanji: a diverse suite of environments, ranging from simple games to NP-hard combinatorial problems.
  • Pgx: classic board game environments such as Chess, Go, and Shogi.
  • XLand-Minigrid: meta RL gridworld environments.
  • Craftax: Crafter + NetHack in JAX.

Algorithms:


A little bit of history: Myriad is named after the Greek myrias ("ten thousand"), inspired by microfluidic "mother machines" that observe 100,000+ cells simultaneously. It brings this paradigm to computational research: providing a myriad of viewpoints from which to learn about and control complex systems — whether they are biological circuits, chemical reactors, or robotic swarms.
