# jax_mppi

JAX implementation of Model Predictive Path Integral (MPPI) control.

jax_mppi is a functional, JIT-compilable port of the pytorch_mppi library to JAX. It implements Model Predictive Path Integral (MPPI) control with a focus on performance and composability.
## Design Philosophy

This library embraces JAX's functional paradigm:

- Pure Functions: Core logic is implemented as pure functions: `command(state, mppi_state) -> (action, mppi_state)`.
- Dataclass State: State is held in `jax.tree_util.register_dataclass` containers, allowing easy integration with `jit`, `vmap`, and `grad`.
- No Side Effects: Unlike the PyTorch version, there is no mutable `self`; state transitions are explicit.
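As a rough illustration of this pattern (not the library's actual types), a controller state can be registered as a pytree dataclass and updated by pure functions. `ControllerState` and `shift` below are hypothetical names invented for this sketch:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

@dataclass
class ControllerState:
    u_seq: jax.Array  # nominal action sequence, shape (horizon, nu)
    key: jax.Array    # PRNG key threaded through calls

# Register as a pytree so instances pass transparently through jit/vmap/grad
# (register_dataclass requires a reasonably recent JAX version)
jax.tree_util.register_dataclass(
    ControllerState, data_fields=["u_seq", "key"], meta_fields=[]
)

def shift(state: ControllerState) -> ControllerState:
    # Pure update: return a new state instead of mutating self
    u = jnp.roll(state.u_seq, -1, axis=0).at[-1].set(0.0)
    key, _ = jax.random.split(state.key)
    return ControllerState(u_seq=u, key=key)

state = ControllerState(u_seq=jnp.zeros((20, 2)), key=jax.random.PRNGKey(0))
state = jax.jit(shift)(state)  # the dataclass flows through jit unchanged
```

Because the state is an explicit value, re-running a step from a saved state is trivially reproducible.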
## Key Features

- Core MPPI: Robust implementation of the standard MPPI algorithm.
- Smooth MPPI (SMPPI): Maintains action sequences and smoothness costs for better trajectory generation.
- Kernel MPPI (KMPPI): Uses kernel interpolation for control points, reducing the parameter space.
- Autotuning: Built-in hyperparameter optimization with multiple backends:
  - CMA-ES (via the `cma` library) - classic evolution strategy
  - CMA-ES, Sep-CMA-ES, OpenES (via `evosax`) - JAX-native, GPU-accelerated ⚡
  - Ray Tune - distributed hyperparameter search
  - CMA-ME (via `ribs`) - quality diversity optimization
- JAX Integration:
  - `jax.vmap` for efficient batch processing.
  - `jax.lax.scan` for fast horizon loops.
  - Fully compatible with JIT compilation for high-performance control loops.
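The core update behind these variants can be sketched in a few lines of JAX. The following is an illustrative, self-contained implementation of one MPPI step (sample noisy action sequences, roll them out with `lax.scan` under `vmap`, then softmax-reweight by cost), not the library's internal code; `mppi_step` and the toy dynamics are made-up names:

```python
import jax
import jax.numpy as jnp

def mppi_step(key, u_seq, x0, dynamics, running_cost,
              num_samples=128, lam=1.0, sigma=0.3):
    """One MPPI update: perturb the nominal action sequence, roll out
    each candidate, and average candidates weighted by exp(-cost / lam)."""
    horizon, nu = u_seq.shape
    noise = sigma * jax.random.normal(key, (num_samples, horizon, nu))
    candidates = u_seq[None, :, :] + noise          # (K, T, nu)

    def rollout(actions):
        def step(x, u):
            x_next = dynamics(x, u)
            return x_next, running_cost(x_next, u)
        _, costs = jax.lax.scan(step, x0, actions)  # costs: (T,)
        return jnp.sum(costs)

    total_costs = jax.vmap(rollout)(candidates)     # (K,)
    weights = jax.nn.softmax(-total_costs / lam)    # lower cost -> higher weight
    return jnp.einsum('k,ktu->tu', weights, candidates)

# Toy linear dynamics and quadratic cost, for illustration only
dynamics = lambda x, u: x + 0.1 * u
running_cost = lambda x, u: jnp.sum(x**2) + 0.01 * jnp.sum(u**2)

u_seq = jnp.zeros((20, 2))
u_seq = mppi_step(jax.random.PRNGKey(0), u_seq, jnp.ones(2),
                  dynamics, running_cost)
```

The whole step is pure, so it can be wrapped in `jax.jit` and called repeatedly in a receding-horizon loop.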
## Installation

```bash
# Install from PyPI
pip install jax-mppi

# Or with optional dependencies
pip install jax-mppi[dev]              # Development tools
pip install jax-mppi[docs]             # Documentation
pip install jax-mppi[autotuning]       # Autotuning (cma + evosax)
pip install jax-mppi[autotuning-extra] # Ray Tune, Hyperopt, Ribs
```
### Development Installation

For contributors who want to work on the package (requires Python 3.12+):

```bash
# Clone the repository
git clone https://github.com/riccardo-enr/jax_mppi.git
cd jax_mppi

# Install in development mode
pip install -e .
```
## Versioning

This project uses Semantic Versioning, following the `major.minor.patch` scheme:
- Major: Breaking changes to the API or significant feature additions.
- Minor: New features or enhancements that are backward compatible.
- Patch: Bug fixes and minor updates.
See CHANGELOG for detailed version history.
## Usage

```python
import jax
import jax.numpy as jnp

from jax_mppi import mppi

# Define dynamics and cost functions
def dynamics(state, action):
    # Your dynamics model here
    return state + action

def running_cost(state, action):
    # Your cost function here
    return jnp.sum(state**2) + jnp.sum(action**2)

# Create configuration and initial state
config, mppi_state = mppi.create(
    nx=4, nu=2,
    noise_sigma=jnp.eye(2) * 0.1,
    horizon=20,
    lambda_=1.0
)

# Control loop
key = jax.random.PRNGKey(0)
current_obs = jnp.zeros(4)

# JIT compile the command function for performance
jitted_command = jax.jit(mppi.command, static_argnames=['dynamics', 'running_cost'])

for _ in range(100):
    key, subkey = jax.random.split(key)
    action, mppi_state = jitted_command(
        config,
        mppi_state,
        current_obs,
        dynamics=dynamics,
        running_cost=running_cost
    )
    # Apply action to environment...
```
## Autotuning

JAX-MPPI includes powerful hyperparameter optimization capabilities. You can automatically tune MPPI parameters such as `lambda_`, `noise_sigma`, and `horizon` using multiple optimization backends.
### Quick Example

```python
from jax_mppi import autotune, mppi

# Create MPPI configuration
config, state = mppi.create(nx=4, nu=2, horizon=20)
holder = autotune.ConfigStateHolder(config, state)

# Define what to tune
params_to_tune = [
    autotune.LambdaParameter(holder, min_value=0.1),
    autotune.NoiseSigmaParameter(holder, min_value=0.01),
]

# Define evaluation function
def evaluate():
    # Run MPPI, return cost
    # ... your evaluation logic ...
    return autotune.EvaluationResult(mean_cost=cost, ...)

# Choose an optimizer
from jax_mppi import autotune_evosax  # JAX-native, GPU-accelerated
optimizer = autotune_evosax.CMAESOpt(population=10, sigma=0.1)

# Or use classic CMA-ES
# optimizer = autotune.CMAESOpt(population=10, sigma=0.1)

# Run optimization
tuner = autotune.Autotune(
    params_to_tune=params_to_tune,
    evaluate_fn=evaluate,
    optimizer=optimizer,
)
best = tuner.optimize_all(iterations=50)
```
### Available Optimizers

| Optimizer | Backend | GPU Support | Best For |
|---|---|---|---|
| `autotune.CMAESOpt` | `cma` library | ❌ | Classic CMA-ES, stable |
| `autotune_evosax.CMAESOpt` | `evosax` | ✅ | JAX-native, 5-10x faster on GPU |
| `autotune_evosax.SepCMAESOpt` | `evosax` | ✅ | High-dimensional problems |
| `autotune_evosax.OpenESOpt` | `evosax` | ✅ | Large populations, parallelization |
| `autotune_global.RayOptimizer` | Ray Tune | ✅ | Distributed search |
| `autotune_qd.CMAMEOpt` | `ribs` | ❌ | Quality diversity |
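All of these backends follow the same ask/evaluate/tell pattern: sample candidate hyperparameters, score each by running the controller, and update the search distribution from the best results. The loop below is a minimal, dependency-free sketch of that pattern with a toy cost surface standing in for "run MPPI and measure trajectory cost"; it is not the API of any backend in the table:

```python
import numpy as np

# Invented stand-in objective over (lambda_, sigma); a real evaluate()
# would roll out the controller and return its average cost.
def evaluate(params):
    lam, sigma = params
    return (np.log(lam) - np.log(0.5)) ** 2 + (sigma - 0.2) ** 2

rng = np.random.default_rng(0)
mean = np.array([1.0, 0.5])   # initial guess for (lambda_, sigma)
step = 0.2                    # fixed sampling spread

for _ in range(60):
    cand = mean + step * rng.standard_normal((16, 2))  # "ask": sample candidates
    cand = np.clip(cand, 1e-3, None)                   # keep hyperparameters positive
    fitness = np.array([evaluate(c) for c in cand])    # score each candidate
    elite = cand[np.argsort(fitness)[:4]]              # "tell": keep the best 4
    mean = elite.mean(axis=0)                          # move the search distribution
```

Real strategies like CMA-ES additionally adapt the sampling covariance, which is what makes them far more sample-efficient than this fixed-spread loop.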
### Evosax vs CMA Library

Migrating from `cma` to `evosax`:

```python
# Before (cma library)
from jax_mppi.autotune import CMAESOpt
optimizer = CMAESOpt(population=10, sigma=0.1)

# After (evosax - JAX-native)
from jax_mppi.autotune_evosax import CMAESOpt
optimizer = CMAESOpt(population=10, sigma=0.1)
```

Benefits of evosax:

- ⚡ 5-10x faster on GPU due to JIT compilation
- 🔧 Multiple strategies (CMA-ES, Sep-CMA-ES, OpenES, SNES, xNES)
- 🎯 JAX-native - seamless integration with JAX code
- 📦 Pure Python - no external C++ dependencies

See `examples/autotune_evosax_comparison.py` for a detailed performance comparison.
## Project Structure

```
jax_mppi/
├── src/jax_mppi/
│   ├── mppi.py                # Core MPPI implementation
│   ├── smppi.py               # Smooth MPPI variant
│   ├── kmppi.py               # Kernel MPPI variant
│   ├── types.py               # Type definitions
│   ├── autotune.py            # Autotuning core & CMA-ES (cma lib)
│   ├── autotune_evosax.py     # JAX-native optimizers (evosax)
│   ├── autotune_global.py     # Ray Tune integration
│   └── autotune_qd.py         # Quality Diversity optimization
├── examples/
│   ├── pendulum.py                    # Pendulum environment example
│   ├── autotune_basic.py              # Basic autotuning example
│   ├── autotune_pendulum.py           # Autotuning pendulum
│   ├── autotune_evosax_comparison.py  # Evosax vs cma performance
│   └── smooth_comparison.py           # Comparison of MPPI variants
└── tests/                     # Unit and integration tests
```
## Roadmap

Development is structured in phases:
- Core MPPI: Basic implementation with JAX parity.
- Integration: Pendulum example and verification.
- Smooth MPPI: Implementation of smoothness constraints.
- Kernel MPPI: Kernel-based control parameterization.
- Comparisons: Benchmarking and visual comparisons.
- Autotuning: Parameter optimization using CMA-ES, Ray Tune, and QD.
## Credits
This project is a direct port of pytorch_mppi. We aim to maintain parity with the original implementation while leveraging JAX's unique features for performance and flexibility.