Pure JAX/Flax NNX implementation of the Mamba3 state-space model with state caching, weight loading, and time-series forecasting.

Mamba3-JAX

A pure JAX/Flax NNX implementation of the Mamba3 state-space model with SSM state caching, weight loading from the canonical state-spaces layout, causal language modeling, and time-series forecasting.

This is the standalone PyPI package for the Mamba3 implementation authored by Cosmo Santoni, following the same pure-JAX, kernel-free approach as its sibling mamba2-jax.

Independent community reimplementation — not affiliated with, endorsed by, or connected to the original Mamba authors or state-spaces/mamba.

Supported Models

Mamba3ForCausalLM and Mamba3Forecaster are supported across every canonical configuration: SISO and MIMO (is_mimo, mimo_rank), ngroups, is_outproj_norm, rope_fraction ∈ {0.5, 1.0}, and packed variable-length sequences (cu_seqlens). No pretrained Mamba-3 checkpoints are publicly available at the time of writing; the architecture, config, and weight loader follow the canonical state-spaces/mamba (Mamba-3) layout.
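To make the SISO/MIMO distinction concrete, here is a minimal JAX sketch of a rank-R state update. It assumes a per-head state h of shape (N, P) that is decayed by a scalar and then updated with B @ X.T, where B is (N, R) and X is (P, R). The names siso_step and mimo_step are hypothetical and this is not the package's kernel; the canonical Mamba-3 MIMO formulation has additional structure (for example, how outputs are read out across the R ranks) that is omitted here.

```python
import jax
import jax.numpy as jnp


def siso_step(h, alpha, b, x):
    """Rank-1 (SISO) update: decay the (N, P) state, add outer(b, x).

    b: (N,) input projection, x: (P,) head input.
    """
    return alpha * h + jnp.outer(b, x)


def mimo_step(h, alpha, B, X):
    """Rank-R (MIMO) update: decay the (N, P) state, add B @ X.T.

    B: (N, R), X: (P, R). B @ X.T is a sum of R outer products.
    """
    return alpha * h + B @ X.T


N, P, R = 4, 3, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
h0 = jax.random.normal(k1, (N, P))
B = jax.random.normal(k2, (N, R))
X = jax.random.normal(k3, (P, R))

# A rank-R update equals R accumulated rank-1 (SISO-style) updates.
h_mimo = mimo_step(h0, 0.9, B, X)
h_sum = 0.9 * h0 + sum(jnp.outer(B[:, r], X[:, r]) for r in range(R))
print(jnp.allclose(h_mimo, h_sum, atol=1e-5))  # True

# With R = 1 the MIMO update reduces to the SISO update.
print(jnp.allclose(mimo_step(h0, 0.9, B[:, :1], X[:, :1]),
                   siso_step(h0, 0.9, B[:, 0], X[:, 0]), atol=1e-6))  # True
```

Under this view, mimo_rank trades extra work per step for a richer state update at the same state size.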

Features

  • Flax NNX modules (no legacy init/apply ceremony)
  • SSM + rotary + trapezoidal state caching for O(n) autoregressive generation (no convolution state)
  • MIMO rank-R recurrence, plus ngroups, is_outproj_norm, rope_fraction — every canonical config flag
  • Variable-length / packed sequences (cu_seqlens) for padding-free batched training
  • Causal language modeling (Mamba3ForCausalLM) with tied or untied embeddings
  • Time-series forecasting (Mamba3Forecaster)
  • Weight loading from the canonical state-spaces/mamba (Mamba-3) layout
  • Kernel-verified parity against the reference Triton (SISO) and TileLang (MIMO) implementations
  • Fully compatible with jax.jit, jax.grad, jax.vmap
  • Runs on CPU, GPU (CUDA), and TPU
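The packed-sequence feature can be illustrated with a toy scan. This is a minimal sketch, assuming the common cu_seqlens convention of cumulative offsets with a leading 0 (so [0, 3, 5, 9] packs sequences of lengths 3, 2 and 4). The recurrent state is reset to zero at each sequence start, which makes the packed result equal to running every sequence separately and concatenating. The helpers starts_from_cu_seqlens and toy_scan are hypothetical, and the scalar recurrence stands in for the real SSM.

```python
import jax
import jax.numpy as jnp
import numpy as np


def starts_from_cu_seqlens(cu_seqlens, total_len):
    """Boolean mask that is True at the first token of each packed sequence."""
    starts = jnp.zeros((total_len,), dtype=bool)
    return starts.at[cu_seqlens[:-1]].set(True)


def toy_scan(a, x, reset):
    """h_t = a_t * h_{t-1} + x_t, with h_{t-1} treated as 0 where reset is True."""
    def step(h, inputs):
        a_t, x_t, r_t = inputs
        h = jnp.where(r_t, 0.0, a_t * h) + x_t
        return h, h
    _, hs = jax.lax.scan(step, jnp.zeros(()), (a, x, reset))
    return hs


cu_seqlens = jnp.array([0, 3, 5, 9], dtype=jnp.int32)  # lengths 3, 2, 4
total = int(cu_seqlens[-1])
ka, kx = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(ka, (total,))
x = jax.random.normal(kx, (total,))

# Packed: one scan over all 9 tokens, resetting at sequence boundaries.
packed = toy_scan(a, x, starts_from_cu_seqlens(cu_seqlens, total))

# Reference: run each sequence on its own and concatenate.
bounds = np.asarray(cu_seqlens)
per_seq = [toy_scan(a[s:e], x[s:e], jnp.zeros((e - s,), dtype=bool))
           for s, e in zip(bounds[:-1], bounds[1:])]
print(jnp.allclose(packed, jnp.concatenate(per_seq), atol=1e-6))
```

No padding tokens are processed, which is what makes packed batches padding-free.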

State Space Caching

The SSM state cache makes autoregressive generation O(n) in the number of generated tokens, instead of the O(n²) cost of re-running the full prefix at every step. Mamba-3 carries a fixed-size per-layer state: the cumulative rotary angle, the SSM state, and the previous B/x required by the trapezoidal recurrence (there is no convolution state). Each decode step is therefore O(1) in sequence length. Cached and non-cached generation produce identical tokens.
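Why the previous B/x must be cached can be seen from a simplified trapezoidal-style recurrence, h_t = α_t·h_{t-1} + β_t·outer(B_{t-1}, x_{t-1}) + γ_t·outer(B_t, x_t): each update depends on the previous step's input, not just the state. This is a sketch under stated assumptions: α, β, γ are given random inputs rather than derived from Δ and A as in Mamba-3, the rotary angle is omitted, the sketch caches the product outer(B_{t-1}, x_{t-1}) for brevity (the package, per the description above, keeps B and x), and trap_step / init_cache are hypothetical names. It checks that prefill plus step-by-step decoding reproduces a single full-sequence scan.

```python
import jax
import jax.numpy as jnp


def init_cache(N, P):
    """Fixed-size cache (h, previous outer(B, x)); no convolution state."""
    return (jnp.zeros((N, P)), jnp.zeros((N, P)))


def trap_step(cache, inputs):
    """One step of a simplified trapezoidal-style recurrence.

    h_t = alpha_t * h_{t-1} + beta_t * outer(B_{t-1}, x_{t-1})
          + gamma_t * outer(B_t, x_t);   y_t = C_t @ h_t
    """
    h, bx_prev = cache
    alpha, beta, gamma, b, x, c = inputs
    bx = jnp.outer(b, x)
    h = alpha * h + beta * bx_prev + gamma * bx
    return (h, bx), c @ h


T, N, P = 16, 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), 6)
inputs = (
    jax.random.uniform(keys[0], (T,), minval=0.5, maxval=1.0),  # alpha
    jax.random.uniform(keys[1], (T,)),                          # beta
    jax.random.uniform(keys[2], (T,)),                          # gamma
    jax.random.normal(keys[3], (T, N)),                         # B
    jax.random.normal(keys[4], (T, P)),                         # x
    jax.random.normal(keys[5], (T, N)),                         # C
)

# Non-cached: one scan over the whole sequence.
_, y_full = jax.lax.scan(trap_step, init_cache(N, P), inputs)

# Cached: prefill the first 10 steps, then decode one token at a time.
prefill = jax.tree_util.tree_map(lambda a: a[:10], inputs)
cache, y_prefill = jax.lax.scan(trap_step, init_cache(N, P), prefill)
y_decode = []
for t in range(10, T):
    step_inputs = jax.tree_util.tree_map(lambda a: a[t], inputs)
    cache, y_t = trap_step(cache, step_inputs)
    y_decode.append(y_t)
y_cached = jnp.concatenate([y_prefill, jnp.stack(y_decode)])

print(jnp.allclose(y_full, y_cached, atol=1e-4))
```

The cache stays (N, P) + (N, P) no matter how many tokens have been processed, which is where the O(1) per-step cost comes from.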

Installation

From PyPI

pip install mamba3-jax

From source

git clone https://github.com/jax-state-spaces/mamba3-jax.git
cd mamba3-jax
pip install -e ".[dev]"

For GPU or TPU support, install the appropriate JAX backend as described in the JAX installation guide.
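After installing a backend, you can confirm which one JAX selected with jax.default_backend() and jax.devices(). A quick check, using only standard JAX calls:

```python
import jax
import jax.numpy as jnp

# Which backend did JAX pick? Typically "cpu", "gpu" (CUDA/ROCm), or "tpu".
backend = jax.default_backend()
print(backend)
print(jax.devices())

# A tiny computation on the default device as a smoke test.
x = jnp.arange(4.0)
print(float((x * x).sum()))  # 14.0
```

If this reports "cpu" on a GPU machine, the accelerator-enabled jaxlib is usually not installed.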

Usage

Causal Language Model

import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3Config, Mamba3ForCausalLM

cfg = Mamba3Config(vocab_size=1024, hidden_size=256, num_hidden_layers=4,
                   state_size=64, head_dim=32, chunk_size=64)
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))

input_ids = jnp.ones((2, 64), dtype=jnp.int32)
outputs = model(input_ids, labels=input_ids)

print(outputs["logits"].shape)  # (2, 64, 1024)
print(float(outputs["loss"]))
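Passing labels=input_ids suggests the model shifts labels internally, as Hugging Face causal LMs do, so that position t is scored against token t + 1. We have not confirmed how mamba3_jax computes its loss; the sketch below is a self-contained illustration of the standard shifted cross-entropy, with a hypothetical causal_lm_loss helper. With uniform logits the loss equals log(vocab_size), a useful sanity check, since a freshly initialized model typically starts near that value.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy: position t predicts labels[:, t + 1]."""
    shift_logits = logits[:, :-1, :]  # the last position has no target
    shift_labels = labels[:, 1:]      # the first token is never predicted
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return nll.mean()


vocab = 1024
logits = jnp.zeros((2, 64, vocab))            # uniform predictions
labels = jnp.ones((2, 64), dtype=jnp.int32)
loss = causal_lm_loss(logits, labels)
print(float(loss), float(jnp.log(vocab)))     # both ≈ 6.93
```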

Cached generation

import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3ForCausalLM, Mamba3Config

cfg = Mamba3Config.tiny()
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))

# Prefill
prompt = jnp.array([[1, 2, 3]], dtype=jnp.int32)
out = model(prompt)
cache = out["cache"]

# Decode with cache (O(1) per step)
next_token = jnp.array([[4]], dtype=jnp.int32)
out = model(next_token, cache=cache)
print(out["logits"].shape)  # (1, 1, vocab_size)

Time-series forecasting

import jax.numpy as jnp
from mamba3_jax import Mamba3Forecaster, create_random_forecaster

model = create_random_forecaster(input_dim=10, d_model=256, n_layers=4,
                                 output_dim=1, forecast_horizon=24)

x = jnp.ones((8, 100, 10))  # (batch, seq_len, features)
y = model(x)
print(y.shape)  # (8, 24, 1)
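One common way to obtain a (batch, forecast_horizon, output_dim) output is to project the final time step's hidden state to forecast_horizon * output_dim values and reshape. We have not checked how Mamba3Forecaster builds its head, so treat this as a hypothetical design sketch; forecast_head and the random stand-in for the backbone output are illustrative only.

```python
import jax
import jax.numpy as jnp


def forecast_head(hidden, W, b, horizon, output_dim):
    """Map the last time step's hidden state to a (horizon, output_dim) forecast.

    hidden: (batch, seq_len, d_model); W: (d_model, horizon * output_dim).
    """
    last = hidden[:, -1, :]   # (batch, d_model)
    flat = last @ W + b       # (batch, horizon * output_dim)
    return flat.reshape(hidden.shape[0], horizon, output_dim)


batch, seq_len, d_model, horizon, output_dim = 8, 100, 256, 24, 1
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
hidden = jax.random.normal(k1, (batch, seq_len, d_model))  # stand-in for backbone output
W = jax.random.normal(k2, (d_model, horizon * output_dim)) / jnp.sqrt(d_model)
b = jnp.zeros((horizon * output_dim,))

y = forecast_head(hidden, W, b, horizon, output_dim)
print(y.shape)  # (8, 24, 1)
```

Emitting the whole horizon in one shot avoids feeding predictions back in, unlike the token-by-token decoding used for language modeling.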

Numerical Parity

Verified directly against the canonical state-spaces/mamba kernels (Triton for SISO, TileLang for MIMO) with identical random weights, across ngroups ∈ {1, 2}, is_outproj_norm ∈ {False, True}, and rope_fraction ∈ {0.5, 1.0}. Outputs match to within ~1–2% mean relative error (the Triton SISO kernel casts its SSD core to bf16; MIMO runs in fp32 throughout). The varlen path is verified both against the kernel and via varlen(packed) == concat(per-sequence). To reproduce on any CUDA GPU, run scripts/verify_mimo_kernel.py.
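The README does not spell out how the mean relative error is computed. A minimal sketch, assuming one common norm-based definition (mean absolute difference divided by mean absolute reference value), shows the kind of check involved and why a bf16 cast alone rules out exact equality with an fp32 reference. mean_relative_error is a hypothetical helper, not the verification script's code.

```python
import jax
import jax.numpy as jnp


def mean_relative_error(out, ref):
    """Mean |out - ref| divided by mean |ref| (a norm-based definition)."""
    return jnp.mean(jnp.abs(out - ref)) / jnp.mean(jnp.abs(ref))


ref = jax.random.normal(jax.random.PRNGKey(0), (1024, 64), dtype=jnp.float32)
# Round-trip through bfloat16, as a bf16 kernel stage would.
out = ref.astype(jnp.bfloat16).astype(jnp.float32)

err = float(mean_relative_error(out, ref))
print(f"{err:.4%}")
assert 0.0 < err < 0.02  # nonzero, but within a 2% tolerance
```

A single round-trip already gives a small nonzero error; a kernel performs many bf16 operations, so parity checks use a tolerance rather than exact equality.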

Project Structure

mamba3-jax/
├── mamba3_jax/
│   ├── __init__.py          # Public API
│   ├── modeling.py          # Config, SSD kernels, all model classes
│   └── params.py            # Weight loading & parameter utilities
├── tests/
│   ├── test_mamba3.py       # Comprehensive test suite
│   ├── run_model.py         # Generation demo script
│   └── artifacts/           # Golden parity data & generator
├── scripts/
│   └── verify_mimo_kernel.py    # GPU parity check against the canonical kernels
├── LICENSE                  # Apache 2.0
├── pyproject.toml
└── README.md

Contributing

Contributions are welcome! Areas where help is particularly valuable:

  • Performance optimization and profiling
  • Test coverage expansion
  • Bug reports and feature requests

Please open an issue or submit a pull request on GitHub.

Acknowledgments

Original Mamba Authors:

  • Aakash Lahoti, Kevin Y. Li, Berlin Chen, Caitlin Wang, Aviv Bick, J. Zico Kolter, Tri Dao, and Albert Gu for the Mamba3 architecture and the original mamba_ssm implementation
  • The State Spaces team for advancing SSM research

JAX Ecosystem:

  • The JAX, Flax, and Optax teams at Google for the excellent frameworks

License

This project is licensed under the Apache License 2.0.

Citation

If you use this implementation in your research, please cite the original Mamba3 paper and this JAX implementation:

@misc{lahoti2026mamba3,
  title={Mamba-3: Improved Sequence Modeling using State Space Principles},
  author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu},
  year={2026},
  eprint={2603.15569},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2603.15569}
}

@software{mamba3jax,
  author  = {Cosmo Santoni},
  title   = {mamba3-jax: Pure JAX Implementation of Mamba3},
  year    = {2026},
  url     = {https://github.com/jax-state-spaces/mamba3-jax}
}


Download files

Source Distribution

mamba3_jax-1.0.1.tar.gz (24.6 kB)

Uploaded Source

Built Distribution

mamba3_jax-1.0.1-py3-none-any.whl (22.7 kB)

Uploaded Python 3

File details

Details for the file mamba3_jax-1.0.1.tar.gz.

File metadata

  • Download URL: mamba3_jax-1.0.1.tar.gz
  • Size: 24.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for mamba3_jax-1.0.1.tar.gz
Algorithm Hash digest
SHA256 22c8c2edc16dbb3a036cb7c8119d2390ce461932ed1ab0c71144b174f1ef62dc
MD5 823e71b09b856deec956ac99fccb4a62
BLAKE2b-256 14b125aff8ee364dc9070f8ed3d72d532023624d8860706f511370772bbfb597

File details

Details for the file mamba3_jax-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: mamba3_jax-1.0.1-py3-none-any.whl
  • Size: 22.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for mamba3_jax-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 8300c186604d0d6b6fd3e861437e1b14fc4fd57d4c2ab1a54a23ed9cf996e723
MD5 be5b9bc3974a87df78b86a65799fcbe9
BLAKE2b-256 d0f2c0a8ff8b44adc56bf9e07004912777a7537ad240ba0cf4fac33d0190e463
