
High-performance JAX implementation of Control Barrier Functions for safe control


CBFJAX — Control Barrier Functions in JAX

Python 3.9+ · JAX · License: MIT

CBFJAX is a high-performance JAX implementation of Control Barrier Functions (CBFs) for safety-critical control. It provides a clean, functional, JIT-compatible API for building safe controllers — from simple closed-form filters up to Backup-CBFs and NMPC with barrier constraints — and runs efficiently on CPU and GPU.

This project is the JAX successor to the CBFTorch framework.


Features

  • Pure JAX, end-to-end JIT: barriers, dynamics, and safety filters are all Equinox modules with functional semantics, and trajectory rollouts use Diffrax.
  • Higher-Order CBFs (HOCBFs) with automatic differentiation for arbitrary relative degree.
  • A toolbox of safe-control backends:
    • Closed-form min-intervention safe control (MinIntervCFSafeControl)
    • QP-based safe control with slack variables (MinIntervQPSafeControl)
    • Input-constrained QP (MinIntervInputConstQPSafeControl)
    • Backup-CBF with forward invariance (MinIntervBackupSafeControl)
    • NMPC with barrier constraints (acados / do-mpc — optional)
    • Constrained iLQR with barrier-aware cost (trajax — optional)
  • Composable barrier algebra: MultiBarriers, SoftCompositionBarrier, NonSmoothCompositionBarrier, BackupBarrier.
  • Built-in dynamics: unicycle, single/double integrator, bicycle, inverted pendulum, reduced-order unicycle — plus a generic AffineInControlDynamics base.
  • Map editor (cbfjax-map-editor) for visually authoring obstacle/boundary maps.
  • 64-bit precision by default for the numerical stability that CBF methods require.
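The functional, JIT-compatible design means per-state functions can be batched with jax.vmap and compiled with jax.jit. As a minimal sketch in plain JAX (not CBFJAX's own classes; the disk barrier is an assumed example), the snippet below compiles a batched barrier evaluation together with its gradient, which is the quantity a CBF filter needs at every step:

```python
import jax
import jax.numpy as jnp

def h(x):
    # Per-state barrier (assumed example): stay outside the unit disk at the origin
    return jnp.linalg.norm(x[:2]) - 1.0

# Vectorize over a leading batch axis, then compile each function once with jit
batched_h = jax.jit(jax.vmap(h))
batched_grad_h = jax.jit(jax.vmap(jax.grad(h)))

xs = jnp.array([[2.0, 0.0, 0.0, 0.0],
                [0.0, 3.0, 1.0, 0.5]])
values = batched_h(xs)      # shape (2,)
grads = batched_grad_h(xs)  # shape (2, 4)
```

Writing the barrier for a single state and letting vmap handle the batch keeps the math readable while still producing one fused, compiled kernel.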

Installation

From PyPI

pip install cbfjax

The core install is lightweight — it pulls in JAX, Equinox, Diffrax, qpax, NumPy, and SciPy, and is sufficient for the closed-form, QP, and Backup-CBF safety filters.

Optional extras

| Extra    | Adds                                       | Install                      |
|----------|--------------------------------------------|------------------------------|
| examples | matplotlib, animation deps                 | pip install cbfjax[examples] |
| gpu      | JAX CUDA 12 wheels                         | pip install cbfjax[gpu]      |
| nmpc     | CasADi + do-mpc (IPOPT backend)            | pip install cbfjax[nmpc]     |
| dev      | pytest, build, twine, ruff, black, mypy, … | pip install cbfjax[dev]      |

The nmpc extra provides an IPOPT-based NMPC backend out of the box. For the acados SQP backend, install acados separately from source.

The iLQR controllers depend on Google's trajax, which is not on PyPI; install it directly from GitHub:

pip install "trajax @ git+https://github.com/google/trajax.git"

NMPC and iLQR controllers are lazily imported, so the core package keeps working even when these optional dependencies are not installed — the ImportError is only raised when you actually instantiate the controller.
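Deferring an optional import to the constructor is a common Python pattern for this kind of behavior. The sketch below is a hypothetical LazyOptionalController (not part of CBFJAX, whose exact mechanism may differ): importing or defining the class never fails, and a missing backend only surfaces as an ImportError when the controller is instantiated.

```python
import importlib


class LazyOptionalController:
    """Hypothetical controller whose optional backend is imported only when it is built."""

    def __init__(self, backend_module: str = "trajax"):
        try:
            # Deferred import: nothing is resolved when this class is defined or imported
            self._backend = importlib.import_module(backend_module)
        except ImportError as exc:
            raise ImportError(
                f"{type(self).__name__} requires the optional dependency "
                f"'{backend_module}'; install it to use this controller."
            ) from exc


# Defining the class never touches the backend; only instantiation does.
stdlib_backed = LazyOptionalController("math")  # succeeds: 'math' is always available
```

Chaining with `from exc` keeps the original ModuleNotFoundError in the traceback while the outer message tells the user which extra to install.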

From source

git clone https://github.com/amirsaeid254/cbfjax.git
cd cbfjax
pip install -e .[dev,examples]

Quick Start

The pattern is the same for every safe controller in CBFJAX:

dynamics  ──┐
            ├──► safety_filter ──► safe action u(x)
barrier   ──┤
            ├──► assign_desired_control(u_des)
desired u ──┘

Minimal example — QP safety filter on a unicycle

import jax.numpy as jnp
import cbfjax
from cbfjax.barriers import Barrier

# 1. Dynamics
dynamics = cbfjax.UnicycleDynamics()  # state: [x, y, v, theta], control: [a, omega]

# 2. Barrier: stay outside the unit disk centered at the origin
#    h(x) = ||p|| - 1.0 >= 0  (relative degree 2 for unicycle position)
barrier = (
    Barrier.create_empty()
    .assign(barrier_func=lambda x: jnp.linalg.norm(x[:2]) - 1.0, rel_deg=2)
    .assign_dynamics(dynamics)
)

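# 3. Desired (nominal) controller: drive to a goal.
#    Takes a single state [x, y, v, theta] and returns [a, omega]; gains are illustrative.
goal = jnp.array([5.0, 5.0])
def desired_control(x):
    dx, dy = goal[0] - x[0], goal[1] - x[1]
    heading_err = jnp.arctan2(dy, dx) - x[3]
    heading_err = jnp.arctan2(jnp.sin(heading_err), jnp.cos(heading_err))  # wrap to [-pi, pi]
    v_des = 0.5 * jnp.hypot(dx, dy)                       # slow down near the goal
    return jnp.array([v_des - x[2], 2.0 * heading_err])   # [a, omega]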

# 4. QP-based min-intervention safety filter
safety_filter = (
    cbfjax.MinIntervQPSafeControl(
        action_dim=dynamics.action_dim,
        alpha=lambda h: 1.0 * h,
        params={"slack_gain": 200.0, "slacked": True},
    )
    .assign_dynamics(dynamics)
    .assign_state_barrier(barrier)
    .assign_desired_control(desired_control)
)

# 5. Query the safe action
x0 = jnp.array([[-2.0, -2.0, 0.0, 0.0]])     # batched (1, 4)
u_safe, _ = safety_filter.optimal_control(x0, safety_filter.get_init_state())
print(u_safe)

Multiple barriers via MultiBarriers

import jax.numpy as jnp
from cbfjax.barriers import Barrier, MultiBarriers
import cbfjax

dynamics = cbfjax.UnicycleDynamics()

# Two obstacle barriers + workspace boundary, each with dynamics already assigned.
def obstacle(center, radius):
    return lambda x: jnp.linalg.norm(x[:2] - jnp.array(center)) - radius

barriers = [
    Barrier.create_empty().assign(obstacle([2.0, 2.0], 0.5), rel_deg=2).assign_dynamics(dynamics),
    Barrier.create_empty().assign(obstacle([-1.0, 3.0], 0.5), rel_deg=2).assign_dynamics(dynamics),
    Barrier.create_empty().assign(
        lambda x: 10.0 - jnp.maximum(jnp.abs(x[0]), jnp.abs(x[1])), rel_deg=2
    ).assign_dynamics(dynamics),
]

# Pass infer_dynamics=True to pick up the dynamics from the first barrier.
multi = MultiBarriers.create_empty().add_barriers(barriers, infer_dynamics=True)
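When several barriers must hold at once, the joint safe set is their intersection, i.e. min_i h_i(x) ≥ 0. The sketch below contrasts that non-smooth minimum with a log-sum-exp soft minimum, one common smooth under-approximation, using the same three barrier functions as plain JAX functions. Whether CBFJAX's NonSmoothCompositionBarrier and SoftCompositionBarrier use exactly these formulas is an assumption here, as is the sharpness parameter k.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def obstacle(center, radius):
    c = jnp.asarray(center)
    return lambda x: jnp.linalg.norm(x[:2] - c) - radius

barrier_funcs = [
    obstacle([2.0, 2.0], 0.5),
    obstacle([-1.0, 3.0], 0.5),
    lambda x: 10.0 - jnp.maximum(jnp.abs(x[0]), jnp.abs(x[1])),
]

def hard_min(x):
    # Non-smooth composition: safe only if every individual barrier is non-negative
    return jnp.min(jnp.stack([h(x) for h in barrier_funcs]))

def soft_min(x, k=10.0):
    # Smooth composition: -(1/k) * log(sum_i exp(-k * h_i(x)))
    hs = jnp.stack([h(x) for h in barrier_funcs])
    return -logsumexp(-k * hs) / k

x = jnp.array([0.0, 0.0, 0.0, 0.0])
h_hard = hard_min(x)
h_soft = soft_min(x)
```

The soft minimum satisfies min_i h_i - log(N)/k ≤ h_soft ≤ min_i h_i, so its zero-superlevel set is an inner (conservative) approximation of the intersection, and it tightens as k grows.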

Closed-loop simulation

# Continues the minimal example above: reuses safety_filter and x0.
trajs = safety_filter.get_optimal_trajs(
    x0=x0,
    sim_time=10.0,
    timestep=0.01,
    method="euler",
)
print(trajs.shape)  # (T, batch, state_dim)
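To make sim_time, timestep, and method="euler" concrete, here is a hedged sketch of a fixed-step forward-Euler rollout written with jax.lax.scan. It is not CBFJAX's get_optimal_trajs: euler_rollout and the x_dot = -x closed loop are hypothetical stand-ins, and whether CBFJAX's T axis includes the initial state is not stated here (this sketch includes it).

```python
import jax
import jax.numpy as jnp

def euler_rollout(closed_loop_rhs, x0, sim_time, timestep):
    """Fixed-step forward-Euler rollout of x_dot = closed_loop_rhs(x) for a batch of states."""
    num_steps = int(round(sim_time / timestep))

    def step(x, _):
        # One Euler step applied to every state in the batch
        x_next = x + timestep * jax.vmap(closed_loop_rhs)(x)
        return x_next, x_next

    _, traj = jax.lax.scan(step, x0, xs=None, length=num_steps)
    # Prepend x0 so traj[0] == x0; shape (num_steps + 1, batch, state_dim)
    return jnp.concatenate([x0[None], traj], axis=0)

# Hypothetical closed loop: x_dot = -x
rhs = lambda x: -x
x0 = jnp.array([[1.0, 2.0]])
traj = euler_rollout(rhs, x0, sim_time=1.0, timestep=0.1)
```

Using lax.scan keeps the whole rollout inside one compiled loop, which is what makes long closed-loop simulations cheap to jit.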

More end-to-end scripts live under examples/unicycle/ (closed-form, QP, input-constrained QP, NMPC, iLQR, hierarchical) and examples/backup_examples/ (Backup-CBF).

cd examples/unicycle
python 03_unicycle_qp.py

Map editor

CBFJAX ships with a browser-based visual editor for authoring obstacle/boundary maps:

cbfjax-map-editor

It opens an HTML canvas in your default browser where you can place cylinders, ellipses, norm-boxes, and boundaries, then export the layout as a CBFJAX map_config.py.
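Each of these shapes corresponds naturally to a barrier over the planar position. The formulas below are one common choice, written as per-state functions in the style of the Quick Start; the helper names and the p = 8 exponent for the norm-box are assumptions, not necessarily what make_map.py generates.

```python
import jax.numpy as jnp

def cylinder(center, radius):
    # Outside a disk: h = ||p - c|| - r
    c = jnp.asarray(center)
    return lambda x: jnp.linalg.norm(x[:2] - c) - radius

def ellipse(center, semi_axes):
    # Outside an axis-aligned ellipse: h = sum(((p - c) / a)^2) - 1
    c, a = jnp.asarray(center), jnp.asarray(semi_axes)
    return lambda x: jnp.sum(((x[:2] - c) / a) ** 2) - 1.0

def norm_box(center, half_sizes, p=8.0):
    # Outside a rounded box: h = ||(p - c) / s||_p - 1; larger p gives sharper corners
    c, s = jnp.asarray(center), jnp.asarray(half_sizes)
    return lambda x: jnp.sum(jnp.abs((x[:2] - c) / s) ** p) ** (1.0 / p) - 1.0

def disk_boundary(center, radius):
    # Inside a circular workspace: h = r - ||p - c||
    c = jnp.asarray(center)
    return lambda x: radius - jnp.linalg.norm(x[:2] - c)

x = jnp.array([3.0, 0.0, 0.0, 0.0])  # state [x, y, v, theta]
h_cyl = cylinder([0.0, 0.0], 1.0)(x)
h_ell = ellipse([0.0, 0.0], [1.5, 1.0])(x)
h_box = norm_box([0.0, 0.0], [1.0, 1.0])(x)
h_bnd = disk_boundary([0.0, 0.0], 10.0)(x)
```

All four are non-negative exactly on the safe side of the shape, so they can be passed to Barrier.create_empty().assign(...) like the Quick Start barriers.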


Architecture

cbfjax/
├── barriers/                       # CBF & HOCBF
│   ├── barrier.py                  #  Single barrier
│   ├── multi_barrier.py            #  Multiple barriers
│   ├── composite_barrier.py        #  Soft / non-smooth composition
│   └── backup_barrier.py           #  Backup-CBF
├── dynamics/                       # Affine-in-control system dynamics
│   ├── base_dynamic.py
│   ├── unicycle.py
│   ├── unicycle_reduced_order.py
│   ├── double_integrator.py
│   ├── single_integrator.py
│   ├── bicycle.py
│   └── inverted_pendulum.py
├── controls/                       # Nominal/optimal controllers
│   ├── base_control.py
│   ├── ilqr_control.py             #  (optional: trajax)
│   ├── nmpc_control.py             #  (optional: casadi + acados/do-mpc)
│   └── control_types.py
├── safe_controls/                  # Safety filters
│   ├── base_safe_control.py
│   ├── closed_form_safe_control.py
│   ├── qp_safe_control.py
│   ├── backup_safe_control.py
│   ├── nmpc_safe_control.py        #  (optional)
│   └── ilqr_safe_control.py        #  (optional)
├── utils/
│   ├── integration.py              #  Diffrax-based ODE rollouts
│   ├── make_map.py                 #  Map / barrier factory
│   ├── jax2casadi/                 #  JAX → CasADi conversion (used by NMPC)
│   ├── map_editor/                 #  HTML map editor
│   ├── profile_utils.py
│   ├── run_map_editor.py
│   └── utils.py
└── config.py                       # JAX configuration helpers

Key concepts

Control Barrier Functions

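For a control-affine system ẋ = f(x) + g(x) u and a safe set C = {x | h(x) ≥ 0}, h is a control barrier function if, for every x in (a neighborhood of) C, there exists an input u satisfying

L_f h(x) + L_g h(x) · u ≥ -α(h(x))

where L_f h = ∇h · f and L_g h = ∇h · g are Lie derivatives and α is an (extended) class-K function. Any locally Lipschitz controller that satisfies this inequality renders C forward invariant.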
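For a single barrier, the min-intervention problem min ‖u - u_des‖² subject to the inequality above has a closed-form solution: keep u_des if it already satisfies the constraint, otherwise add the smallest correction along L_g h(x). A minimal sketch in plain JAX, assuming a single-integrator system and a unit-disk obstacle (it is not MinIntervCFSafeControl's implementation):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # mirrors CBFJAX's stated 64-bit default

def f(x):
    # Assumed example system: single integrator x_dot = u, so the drift is zero
    return jnp.zeros_like(x)

def g(x):
    # Input matrix of the single integrator: identity
    return jnp.eye(x.shape[0])

def h(x):
    # Safe set: outside the unit disk centered at the origin
    return jnp.linalg.norm(x) - 1.0

def alpha(s):
    # Linear class-K function
    return 1.0 * s

def min_interv_filter(x, u_des):
    """Closed-form solution of min ||u - u_des||^2 s.t. L_f h + L_g h u >= -alpha(h)."""
    grad_h = jax.grad(h)(x)
    lf_h = grad_h @ f(x)   # L_f h(x), a scalar
    lg_h = grad_h @ g(x)   # L_g h(x), shape (action_dim,)
    # Constraint residual at u_des; a non-negative value means u_des is already safe
    residual = lf_h + lg_h @ u_des + alpha(h(x))
    # Smallest correction along L_g h that brings the residual back to zero
    lam = jnp.maximum(0.0, -residual) / (lg_h @ lg_h)
    return u_des + lam * lg_h

x = jnp.array([2.0, 0.0])
u_des = jnp.array([-5.0, 0.0])        # pushes straight toward the obstacle
u_safe = min_interv_filter(x, u_des)  # [-1., 0.]: slowed just enough to satisfy the constraint
```

The division assumes L_g h(x) ≠ 0. Where it vanishes, u does not appear in the first derivative of h, which is exactly the higher-relative-degree case that HOCBFs address.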

Higher-Order CBFs

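For barriers of relative degree r > 1, CBFJAX automatically constructs the HOCBF series ψ_0, ψ_1, …, ψ_r from a user-provided list of class-K functions, passed as Barrier.create_empty().assign(barrier_func, rel_deg=r, alphas=[α_1, …, α_r]). In the standard HOCBF construction, ψ_0 = h and ψ_i = ψ̇_{i-1} + α_i(ψ_{i-1}); the control input first appears in ψ_r, so enforcing ψ_r ≥ 0 gives the constraint on u.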
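As a hedged sketch of that construction in plain JAX (independent of CBFJAX's internals; the 1-D double integrator, the barrier h = p - 1, and the linear α_i are assumptions), the helpers below build ψ_0 = h and ψ_i = L_f ψ_{i-1} + α_i(ψ_{i-1}) up to ψ_{r-1}, then return the affine-in-u coefficients (a, b) of ψ_r ≥ 0. At x = [3, -1] this gives ψ_0 = 2, ψ_1 = 1, and the constraint u ≥ 0.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Drift of a 1-D double integrator, x = [p, v]: p_dot = v, v_dot = 0
    return jnp.array([x[1], 0.0])

def g(x):
    # The input enters the velocity channel only: v_dot = u
    return jnp.array([[0.0], [1.0]])

def h(x):
    # Keep the position to the right of p = 1 (relative degree 2)
    return x[0] - 1.0

alphas = [lambda s: 1.0 * s, lambda s: 1.0 * s]  # alpha_1, alpha_2 (assumed linear)

def make_hocbf_series(h, f, alphas):
    """Return [psi_0, ..., psi_{r-1}] with psi_i = L_f psi_{i-1} + alpha_i(psi_{i-1})."""
    psis = [h]
    for alpha in alphas[:-1]:
        def psi_next(x, prev=psis[-1], alpha=alpha):
            # Below the relative degree, L_g psi_{i-1} = 0, so d/dt psi_{i-1} = L_f psi_{i-1}
            return jax.grad(prev)(x) @ f(x) + alpha(prev(x))
        psis.append(psi_next)
    return psis

def hocbf_constraint(psis, f, g, alpha_r, x):
    """Coefficients (a, b) of the input constraint a + b @ u >= 0, i.e. psi_r >= 0."""
    psi_last = psis[-1]
    grad_psi = jax.grad(psi_last)(x)
    return grad_psi @ f(x) + alpha_r(psi_last(x)), grad_psi @ g(x)

psis = make_hocbf_series(h, f, alphas)
x = jnp.array([3.0, -1.0])
a, b = hocbf_constraint(psis, f, g, alphas[-1], x)
```

Because each ψ_i is an ordinary JAX function, the nested jax.grad calls handle the higher derivatives automatically, which is what makes arbitrary relative degree practical.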


Citation

If you use CBFJAX in your research, please cite it as:

@software{CBFJAX,
  author       = {Safari, Amirsaeid},
  title        = {{CBFJAX}: Control Barrier Functions in {JAX}},
  howpublished = {\url{https://github.com/amirsaeid254/cbfjax}},
  year         = {2025}
}

Related work

  • CBFTorch — PyTorch implementation of CBFs.



Download files


Source Distribution

cbfjax-0.1.0.tar.gz (103.0 kB)

Uploaded Source

Built Distribution


cbfjax-0.1.0-py3-none-any.whl (117.6 kB)

Uploaded Python 3

File details

Details for the file cbfjax-0.1.0.tar.gz.

File metadata

  • Download URL: cbfjax-0.1.0.tar.gz
  • Size: 103.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for cbfjax-0.1.0.tar.gz
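| Algorithm   | Hash digest                                                      |
|-------------|------------------------------------------------------------------|
| SHA256      | 43d7109328c90d7ec73834add12e97f7f66e0d915c2ffab02061b76e0e1b2add |
| MD5         | 502674e8036222e48045b4ec8edc64ed                                 |
| BLAKE2b-256 | 2ec7df871dec07522e9f7d72141a17d49bca6079d3dc81fda6a78bff20bf6c59 |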


File details

Details for the file cbfjax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: cbfjax-0.1.0-py3-none-any.whl
  • Size: 117.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.18

File hashes

Hashes for cbfjax-0.1.0-py3-none-any.whl
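| Algorithm   | Hash digest                                                      |
|-------------|------------------------------------------------------------------|
| SHA256      | fad6cb50945df42e2e1df45ae323ba8f8f82d6455917bb2319e29fff5a77764d |
| MD5         | a36bc69b09743f1aca9d66fce22d7aa2                                 |
| BLAKE2b-256 | 480689c31273c9dbbd27ce75ba225e27f2a56ed1460671a63bbab894eb22e28e |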

