
A library providing performant NumPy and JAX implementations of an MPPI planner, along with implementations of related algorithms and tools.

This project has been archived.

The maintainers of this project have marked this project as archived. No new releases are expected.


trajax


A sampling-based trajectory planning library with NumPy and JAX backends for building MPPI (Model Predictive Path Integral) planners.
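At its core, MPPI perturbs a nominal control sequence many times, rolls each perturbed sequence out, and blends the perturbations using exponentially weighted costs. The following is a minimal sketch of the textbook MPPI update in plain NumPy, not trajax's internal implementation; `mppi_update`, the toy cost, and the array shapes are hypothetical. Its `temperature` argument presumably plays the same role as the `temperature` passed to `planner.step` in the Quick Start.

```python
import numpy as np


def mppi_update(nominal, noise, costs, temperature):
    """Return the MPPI-updated control sequence.

    nominal:     (horizon, control_dim) current nominal controls
    noise:       (rollouts, horizon, control_dim) sampled perturbations
    costs:       (rollouts,) total cost of each perturbed rollout
    temperature: > 0; lower values concentrate weight on the cheapest rollouts
    """
    # Subtracting the minimum cost keeps the largest weight at exp(0) = 1,
    # so large costs cannot underflow every weight to zero.
    weights = np.exp(-(costs - costs.min()) / temperature)
    weights /= weights.sum()
    # Cost-weighted average of the perturbations, added to the nominal sequence.
    return nominal + np.tensordot(weights, noise, axes=1)


# Toy usage: 256 rollouts over a 30-step horizon with 2 control inputs.
rng = np.random.default_rng(42)
nominal = np.zeros((30, 2))
noise = rng.normal(scale=[0.5, 0.2], size=(256, 30, 2))
target = np.array([1.0, 0.0])  # hypothetical: cost is squared distance to a fixed control
costs = ((nominal + noise - target) ** 2).sum(axis=(1, 2))
updated = mppi_update(nominal, noise, costs, temperature=50.0)
```

Lower temperatures push the update toward the single cheapest rollout; very high temperatures approach a plain average of the perturbations.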

Features

  • Dual Backend: Identical APIs for NumPy (prototyping) and JAX (GPU acceleration)
  • MPPI Planning: Sampling-based trajectory optimization with configurable cost functions
  • MPCC Support: Model Predictive Contouring Control for path-following tasks
  • Modular Design: Composable cost functions, samplers, and dynamical models
  • Risk-Aware Planning: Integration with risk metrics (CVaR, VaR, entropic risk)
  • Obstacle Avoidance: Circle and polygon collision checking with motion prediction
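The risk metrics named above reduce a set of sampled costs (for example, one cost per predicted obstacle motion) to a single number the planner can minimize. A minimal NumPy sketch of the textbook definitions follows; the function names, the `alpha`/`gamma`/`theta` parameters, and the convention that higher cost is worse are assumptions, and trajax's own conventions may differ.

```python
import numpy as np


def expected_value(costs):
    return costs.mean()


def mean_variance(costs, gamma):
    # Penalize spread as well as the average; gamma >= 0 sets the trade-off.
    return costs.mean() + gamma * costs.var()


def value_at_risk(costs, alpha):
    # The alpha-quantile of the cost distribution (higher cost is worse).
    return np.quantile(costs, alpha)


def conditional_value_at_risk(costs, alpha):
    # Empirical tail mean: average of the costs at or above VaR_alpha.
    threshold = value_at_risk(costs, alpha)
    return costs[costs >= threshold].mean()


def entropic_risk(costs, theta):
    # (1 / theta) * log E[exp(theta * cost)], shifted by the maximum so that
    # large costs do not overflow exp().
    scaled = theta * costs
    peak = scaled.max()
    return (peak + np.log(np.mean(np.exp(scaled - peak)))) / theta


# Hypothetical sampled costs, e.g. one per predicted obstacle motion.
costs = np.array([1.0, 2.0, 3.0, 4.0, 10.0])
```

With these costs, VaR at `alpha=0.8` interpolates to 5.2 and CVaR averages only the tail above it (here just 10.0), while the entropic risk moves from the mean toward the maximum as `theta` grows.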

Installation

pip install trajax

Or with uv:

uv add trajax

For GPU acceleration (Linux only):

pip install "trajax[cuda]"

Quick Start

from trajax.numpy import mppi, model, sampler, trajectory, types, extract
from numtypes import array

# Define position extractor for the cost function
def position(states):
    return types.positions(x=states.positions.x(), y=states.positions.y())

# Define the reference path to follow
reference = trajectory.waypoints(
    points=array([[0, 0], [10, 0], [20, 5], [30, 5]], shape=(4, 2)),
    path_length=35.0,
)

# Create an MPCC planner (path-following with contouring/lag costs)
planner, augmented_model, contouring_cost, lag_cost = mppi.mpcc(
    model=model.bicycle.dynamical(
        time_step_size=0.1,
        wheelbase=2.5,
        speed_limits=(0.0, 15.0),
        steering_limits=(-0.5, 0.5),
        acceleration_limits=(-3.0, 3.0),
    ),
    sampler=sampler.gaussian(
        standard_deviation=array([0.5, 0.2], shape=(2,)),
        rollout_count=256,
        to_batch=types.bicycle.control_input_batch.create,
        seed=42,
    ),
    reference=reference,
    position_extractor=extract.from_physical(position),
    config={
        "weights": {"contouring": 50.0, "lag": 100.0, "progress": 1000.0},
        "virtual": {"velocity_limits": (0.0, 15.0)},
    },
)

# Initialize state
initial_state = types.augmented.state.of(
    physical=types.bicycle.state.create(x=0.0, y=0.0, heading=0.0, speed=0.0),
    virtual=types.simple.state.zeroes(dimension=1),
)
nominal_input = types.augmented.control_input_sequence.of(
    physical=types.bicycle.control_input_sequence.zeroes(horizon=30),
    virtual=types.simple.control_input_sequence.zeroes(horizon=30, dimension=1),
)

# Run the planner
control = planner.step(
    temperature=50.0,
    nominal_input=nominal_input,
    initial_state=initial_state,
)

# control.optimal - the optimal control sequence
# control.nominal - the updated nominal for the next iteration
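The `mppi.mpcc` call returns contouring and lag costs and is tuned by the `contouring`, `lag`, and `progress` weights in `config`. The following sketch shows the error decomposition commonly used in the MPCC literature: the offset from the reference point is split into a contouring part normal to the path and a lag part along it, and progress along the path is rewarded. The function names, the sign convention, and the exact cost form are assumptions; the README does not state how trajax computes these terms.

```python
import numpy as np


def contouring_and_lag_errors(x, y, ref_x, ref_y, ref_heading):
    """Split a position error into contouring (normal) and lag (tangential) parts.

    (ref_x, ref_y) is the reference point at the current path parameter and
    ref_heading is the path's tangent angle there, in radians.
    """
    dx, dy = x - ref_x, y - ref_y
    # Component of the error normal to the path; with these signs it is
    # negative when the vehicle is to the left of the direction of travel.
    contouring = np.sin(ref_heading) * dx - np.cos(ref_heading) * dy
    # Component along the path; positive when the vehicle trails the reference point.
    lag = -np.cos(ref_heading) * dx - np.sin(ref_heading) * dy
    return contouring, lag


def mpcc_stage_cost(contouring, lag, progress_rate, weights):
    """Quadratic error penalties minus a reward for advancing along the path."""
    return (
        weights["contouring"] * contouring**2
        + weights["lag"] * lag**2
        - weights["progress"] * progress_rate
    )


# Weights copied from the Quick Start config; progress_rate is hypothetical
# (the speed of the virtual path parameter at this step).
weights = {"contouring": 50.0, "lag": 100.0, "progress": 1000.0}
e_c, e_l = contouring_and_lag_errors(5.0, 1.0, ref_x=5.0, ref_y=0.0, ref_heading=0.0)
cost = mpcc_stage_cost(e_c, e_l, progress_rate=0.01, weights=weights)
```

Here the vehicle sits 1 m to the left of a path heading along +x, so the contouring error is -1, the lag error is 0, and the stage cost is 50 - 10 = 40.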

Switching to JAX

To switch to the JAX backend (and GPU acceleration), change only the imports:

# Change this:
from trajax.numpy import mppi, model, sampler, trajectory, types, extract

# To this:
from trajax.jax import mppi, model, sampler, trajectory, types, extract

All APIs remain identical between backends.

Documentation

Architecture

┌─────────────────────────────────────────────────────────────┐
│                         MPPI Planner                        │
├─────────────┬─────────────┬─────────────┬───────────────────┤
│   Sampler   │    Model    │    Cost     │     Filter        │
│  (Gaussian, │  (Bicycle,  │  (Tracking, │   (Savitzky-      │
│   Halton)   │ Integrator) │   Safety)   │    Golay)         │
└─────────────┴─────────────┴─────────────┴───────────────────┘
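The diagram includes a Savitzky-Golay filter stage. A common use in MPPI is smoothing the control sequence along the horizon, which removes sampling jitter while preserving low-order trends. A minimal sketch with SciPy's `savgol_filter` (SciPy is already a requirement); `smooth_controls`, the window length, the polynomial order, and the assumption that the filter acts on the control sequence are illustrative, not taken from trajax.

```python
import numpy as np
from scipy.signal import savgol_filter


def smooth_controls(controls, window_length=9, polyorder=3):
    """Smooth a (horizon, control_dim) control sequence along the time axis.

    Each output sample is the value of a degree-`polyorder` polynomial fitted
    by least squares to a sliding window of `window_length` samples.
    """
    return savgol_filter(controls, window_length, polyorder, axis=0)


# Toy usage: a smooth two-input sequence corrupted by sampling noise.
rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 100)
clean = np.stack([np.sin(2 * np.pi * t), 0.5 * t], axis=1)
noisy = clean + rng.normal(scale=0.1, size=clean.shape)
smoothed = smooth_controls(noisy)
```

Because the filter fits a local polynomial, a sequence that is already a polynomial of degree at most `polyorder` passes through unchanged under SciPy's default `mode="interp"`; in that mode `window_length` must not exceed the horizon.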

Available Components

Category        Components
Models          Kinematic bicycle, Unicycle, Integrator
Samplers        Gaussian, Halton-spline
Costs           Contouring, Lag, Progress, Collision, Boundary, Control smoothing
Trajectories    Waypoints (spline), Line
Risk Metrics    Expected value, Mean-variance, VaR, CVaR, Entropic risk
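The README does not describe how the Halton-spline sampler works. One plausible reading, sketched here purely as an assumption: draw low-discrepancy Halton points for a handful of spline knots, map them to Gaussian values with the inverse normal CDF, and interpolate across the horizon with a cubic spline, which yields smooth perturbations that cover the sample space evenly. `halton_spline_samples`, `knot_count`, and the clipping constant are hypothetical; the SciPy calls (`qmc.Halton`, `norm.ppf`, `CubicSpline`) are real.

```python
import numpy as np
from scipy.interpolate import CubicSpline
from scipy.stats import norm, qmc


def halton_spline_samples(rollouts, horizon, std, knot_count=6, seed=0):
    """Draw smooth perturbations of shape (rollouts, horizon, control_dim)."""
    std = np.asarray(std, dtype=float)
    control_dim = std.shape[0]
    # One Halton dimension per (knot, control input) pair.
    halton = qmc.Halton(d=knot_count * control_dim, scramble=True, seed=seed)
    uniform = halton.random(rollouts)
    # Keep points strictly inside (0, 1) so the inverse normal CDF stays finite.
    uniform = np.clip(uniform, 1e-9, 1 - 1e-9)
    knots = norm.ppf(uniform).reshape(rollouts, knot_count, control_dim) * std
    # Spread the knots evenly over the horizon and interpolate between them.
    knot_times = np.linspace(0, horizon - 1, knot_count)
    spline = CubicSpline(knot_times, knots, axis=1)
    return spline(np.arange(horizon))


# Toy usage matching the Quick Start sampler's rollout count and standard deviations.
samples = halton_spline_samples(256, 30, [0.5, 0.2], seed=42)
```

Compared with independent Gaussian noise at every step, interpolating a few knots trades some flexibility for much smoother control perturbations.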

Requirements

  • Python ≥ 3.13
  • NumPy, JAX, SciPy

Changelog

See CHANGELOG.md for release history.

Contributing

Contributions are welcome! See CONTRIBUTING.md for development setup, coding style, and testing guidelines.

License

MIT License — see LICENSE for details.



Download files


Source Distribution

trajax-0.2.0.tar.gz (141.5 kB)

Built Distribution


trajax-0.2.0-py3-none-any.whl (176.7 kB)

File details

Details for the file trajax-0.2.0.tar.gz.

File metadata

  • Download URL: trajax-0.2.0.tar.gz
  • Upload date:
  • Size: 141.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.18 (publish subcommand, Debian GNU/Linux 12 "bookworm", CI)

File hashes

Hashes for trajax-0.2.0.tar.gz
Algorithm     Hash digest
SHA256        76bad25f2a01b616ba3178d7297ab68d5db9b8d0d5b678243ab2403e38ef617f
MD5           8b051a49da1429ec03b6873da1973185
BLAKE2b-256   af8cf1dafaaf1c1f9a2bc65b6a7e9fa371a2e9f0176de4a668d3bb76a7d56f37


File details

Details for the file trajax-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: trajax-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 176.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.18 (publish subcommand, Debian GNU/Linux 12 "bookworm", CI)

File hashes

Hashes for trajax-0.2.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        a0a70eab44473fcf4c98eccc4f27a9e82f7cc827fbd46f0489d697c72154b06b
MD5           11bb45c6bd4b6a2b642407a906bb0c97
BLAKE2b-256   5450b4d661f75675b2cc4c9c2105ea9a4b0a623192c46f85b6548764ad452268

