Mamba3-JAX
A pure JAX/Flax NNX implementation of the Mamba3 state-space model with SSM state caching, weight loading from the canonical state-spaces layout, causal language modeling, and time-series forecasting.
This is the standalone PyPI package for the Mamba3 implementation authored by Cosmo Santoni, following the same pure-JAX, kernel-free approach as its sibling mamba2-jax.
Supported Models
Mamba3ForCausalLM and Mamba3Forecaster are supported across every canonical configuration: SISO and MIMO (is_mimo, mimo_rank), ngroups, is_outproj_norm, rope_fraction ∈ {0.5, 1.0}, and packed variable-length sequences (cu_seqlens). No pretrained Mamba-3 checkpoints are public yet. The architecture, config, and weight loader match the canonical state-spaces/mamba (Mamba-3) layout, so its state_dict can be loaded directly once weights are released.
Features
- Flax NNX modules (no legacy init/apply ceremony)
- SSM + rotary + trapezoidal state caching for O(n) autoregressive generation (no convolution state)
- MIMO rank-R recurrence, plus ngroups, is_outproj_norm, rope_fraction (every canonical config flag)
- Variable-length / packed sequences (cu_seqlens) for padding-free batched training
- Causal language modeling (Mamba3ForCausalLM) with tied or untied embeddings
- Time-series forecasting (Mamba3Forecaster)
- Weight loading from the canonical state-spaces/mamba (Mamba-3) layout
- Kernel-verified parity against the reference Triton (SISO) and TileLang (MIMO) implementations
- Fully compatible with jax.jit, jax.grad, jax.vmap
- Runs on CPU, GPU (CUDA), and TPU
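The packed-sequence input (cu_seqlens) listed among the features can be illustrated with a small sketch. It assumes the common convention in which cu_seqlens holds cumulative sequence boundaries starting at 0, so that sequence i occupies packed tokens cu_seqlens[i]:cu_seqlens[i+1]. The helper names make_cu_seqlens and segment_ids are hypothetical and are not part of mamba3_jax.

```python
import jax.numpy as jnp

def make_cu_seqlens(lengths):
    """Cumulative boundaries [0, l0, l0+l1, ...] for a list of sequence lengths."""
    lengths = jnp.asarray(lengths, dtype=jnp.int32)
    return jnp.concatenate([jnp.zeros((1,), jnp.int32), jnp.cumsum(lengths)])

def segment_ids(cu_seqlens, total_tokens):
    """Map each packed token position to the index of the sequence it belongs to."""
    positions = jnp.arange(total_tokens)
    # side="right" counts how many boundaries are <= position.
    return jnp.searchsorted(cu_seqlens, positions, side="right") - 1

cu = make_cu_seqlens([3, 1, 4])      # three sequences packed into 8 tokens
seg = segment_ids(cu, int(cu[-1]))
print(cu)   # [0 3 4 8]
print(seg)  # [0 0 0 1 2 2 2 2]
```

Packing this way avoids padding every sequence to the longest one; the boundaries tell a recurrence where its state has to be reset.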
State Space Caching
The SSM state cache enables O(n) autoregressive generation instead of O(n²) re-computation. Mamba-3 carries a fixed-size per-layer state — the cumulative rotary angle, the SSM state, and the previous B/x required by the trapezoidal recurrence (no convolution state) — so each decode step is O(1) in sequence length. Cached and non-cached generation produce identical tokens.
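To make the fixed-size-state argument concrete, here is a minimal sketch of a toy trapezoid-style linear recurrence. It is an illustration only and is not the package's kernel: the update rule, the 0.5 averaging weights, and the names step and run are all assumptions chosen for brevity. The point it shows is that the cache (the state plus the previous input contribution) has the same size regardless of how many tokens came before, and that prefill followed by one cached step reproduces the output of recomputing the full sequence.

```python
import jax
import jax.numpy as jnp

def step(cache, inputs):
    """One decode step of a toy trapezoid-style recurrence (illustrative only)."""
    h, prev_bx = cache          # fixed-size state, independent of sequence length
    a, bx = inputs              # per-step decay and input contribution B_t * x_t
    h = a * h + 0.5 * (prev_bx + bx)
    return (h, bx), h           # new cache, per-step output

def run(cache, a_seq, bx_seq):
    """Scan `step` over a sequence; returns the final cache and all outputs."""
    return jax.lax.scan(step, cache, (a_seq, bx_seq))

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a_seq = jax.random.uniform(key_a, (6, 4), minval=0.5, maxval=1.0)
bx_seq = jax.random.normal(key_b, (6, 4))
init = (jnp.zeros(4), jnp.zeros(4))

# Full recomputation over all 6 steps.
_, full = run(init, a_seq, bx_seq)

# Prefill the first 5 steps, then take a single cached step for the last one.
cache, _ = run(init, a_seq[:5], bx_seq[:5])
_, last = step(cache, (a_seq[5], bx_seq[5]))

# Compare the cached step against the last output of the full recomputation.
print(bool(jnp.allclose(full[-1], last, atol=1e-6)))
```

The single call to step touches only the cache and the new token's inputs, which is why the per-step decode cost does not grow with the prefix length.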
Installation
From PyPI
pip install mamba3-jax
From source
git clone https://github.com/CosmoNaught/mamba3-jax.git
cd mamba3-jax
pip install -e ".[dev]"
For loading pretrained weights
pip install "mamba3-jax[pretrained]"
For GPU or TPU support, install the appropriate JAX backend as described in the JAX installation guide.
Usage
Causal Language Model
import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3Config, Mamba3ForCausalLM
cfg = Mamba3Config(vocab_size=1024, hidden_size=256, num_hidden_layers=4,
                   state_size=64, head_dim=32, chunk_size=64)
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))
input_ids = jnp.ones((2, 64), dtype=jnp.int32)
outputs = model(input_ids, labels=input_ids)
print(outputs["logits"].shape) # (2, 64, 1024)
print(float(outputs["loss"]))
Loading weights
from mamba3_jax import Mamba3ForCausalLM
# Loads from a canonical state-spaces/mamba (Mamba-3) checkpoint or HuggingFace id.
# (No public Mamba-3 checkpoints exist yet; the loader is ready for them.)
model = Mamba3ForCausalLM.from_pretrained("state-spaces/mamba3-...")
Cached generation
import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3ForCausalLM, Mamba3Config
cfg = Mamba3Config.tiny()
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))
# Prefill
prompt = jnp.array([[1, 2, 3]], dtype=jnp.int32)
out = model(prompt)
cache = out["cache"]
# Decode with cache (O(1) per step)
next_token = jnp.array([[4]], dtype=jnp.int32)
out = model(next_token, cache=cache)
print(out["logits"].shape) # (1, 1, vocab_size)
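The prefill-then-decode pattern above extends naturally to a loop. The following is a minimal greedy-decoding sketch that relies only on the call shape shown in this section: a model that returns a dict with "logits" and "cache" and accepts cache=... on later calls. greedy_generate is a hypothetical helper, not part of the package, and toy_model is a stand-in so that the snippet runs without mamba3_jax installed.

```python
import jax
import jax.numpy as jnp

def greedy_generate(model, prompt, max_new_tokens):
    """Greedy decoding: prefill once, then feed one token per step with the cache."""
    out = model(prompt)                                   # prefill over the prompt
    tokens = [prompt]
    for _ in range(max_new_tokens):
        # Pick the most likely next token from the last position's logits.
        next_token = jnp.argmax(out["logits"][:, -1, :], axis=-1)[:, None]
        next_token = next_token.astype(prompt.dtype)
        tokens.append(next_token)
        out = model(next_token, cache=out["cache"])       # single-token cached step
    return jnp.concatenate(tokens, axis=1)

def toy_model(input_ids, cache=None, vocab_size=8):
    """Stand-in with the same call shape; always predicts (token + 1) % vocab_size."""
    logits = jax.nn.one_hot((input_ids + 1) % vocab_size, vocab_size)
    return {"logits": logits, "cache": cache}

prompt = jnp.array([[1, 2, 3]], dtype=jnp.int32)
generated = greedy_generate(toy_model, prompt, max_new_tokens=4)
print(generated)  # [[1 2 3 4 5 6 7]]
```

With a real model, the Mamba3ForCausalLM instance from the snippet above would be passed in place of toy_model. This sketch uses a plain Python loop for readability.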
Time-series forecasting
import jax.numpy as jnp
from mamba3_jax import Mamba3Forecaster, create_random_forecaster
model = create_random_forecaster(input_dim=10, d_model=256, n_layers=4,
                                 output_dim=1, forecast_horizon=24)
x = jnp.ones((8, 100, 10)) # (batch, seq_len, features)
y = model(x)
print(y.shape) # (8, 24, 1)
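The shapes above imply a training setup with (batch, seq_len, features) inputs and (batch, forecast_horizon, output_dim) targets. As a rough sketch of how such pairs could be cut from one long multivariate series, the hypothetical helper below slides a window over a (time, features) array. The name make_windows, the choice of a single target column, and the stride of 1 are assumptions for illustration and are not part of the package.

```python
import jax.numpy as jnp

def make_windows(series, context_len, horizon, target_col=0):
    """Slice a (time, features) array into (inputs, targets) training pairs."""
    n = series.shape[0] - context_len - horizon + 1       # number of windows
    starts = jnp.arange(n)
    ctx_idx = starts[:, None] + jnp.arange(context_len)[None, :]
    tgt_idx = starts[:, None] + context_len + jnp.arange(horizon)[None, :]
    inputs = series[ctx_idx]                               # (n, context_len, features)
    targets = series[tgt_idx][:, :, target_col:target_col + 1]   # (n, horizon, 1)
    return inputs, targets

series = jnp.arange(200 * 10, dtype=jnp.float32).reshape(200, 10)
x_win, y_win = make_windows(series, context_len=100, horizon=24)
print(x_win.shape, y_win.shape)  # (77, 100, 10) (77, 24, 1)
```

Each target window starts at the time step immediately after its context window ends, so no future values leak into the inputs.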
Performance
Numerical parity is verified directly against the canonical state-spaces/mamba kernels (Triton for SISO, TileLang for MIMO) using identical random weights. Agreement is at the bf16 level, about 1–2% mean relative error, because the kernels accumulate in bf16 while JAX runs in fp32. The checks cover ngroups ∈ {1, 2}, is_outproj_norm ∈ {False, True}, and rope_fraction ∈ {0.5, 1.0}. The varlen path is verified both against the kernel and via the identity varlen(packed) == concat(per-sequence). To reproduce on any CUDA GPU, run scripts/verify_mimo_kernel.py. Throughput benchmarks are forthcoming.
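For readers who want a feel for what a bf16-level discrepancy looks like, the sketch below compares a float32 matmul with the same matmul in bfloat16. The exact metric used by the verification script is not stated here, so the definition mean(|a - b|) / mean(|b|) and the helper name mean_relative_error are assumptions, and the number this toy prints is not comparable to the parity figures quoted above.

```python
import jax
import jax.numpy as jnp

def mean_relative_error(candidate, reference):
    """mean(|candidate - reference|) / mean(|reference|), computed in float32."""
    candidate = candidate.astype(jnp.float32)
    reference = reference.astype(jnp.float32)
    return jnp.mean(jnp.abs(candidate - reference)) / jnp.mean(jnp.abs(reference))

x = jax.random.normal(jax.random.PRNGKey(0), (4, 128), dtype=jnp.float32)
w = jax.random.normal(jax.random.PRNGKey(1), (128, 128), dtype=jnp.float32)

ref = x @ w                                               # float32 reference
low = x.astype(jnp.bfloat16) @ w.astype(jnp.bfloat16)     # bfloat16 inputs and output

err = float(mean_relative_error(low, ref))
print(f"mean relative error: {err:.4f}")
```

The observed value depends on the backend and on how it accumulates bfloat16 products, so it is best treated as an order-of-magnitude indicator.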
Project Structure
mamba3-jax/
├── mamba3_jax/
│ ├── __init__.py # Public API
│ ├── modeling.py # Config, SSD kernels, all model classes
│ └── params.py # Weight loading & parameter utilities
├── tests/
│ ├── test_mamba3.py # Comprehensive test suite
│ ├── run_model.py # Generation demo script
│ └── artifacts/ # Golden parity data & generator
├── scripts/
│ └── verify_mimo_kernel.py # GPU verification against the canonical kernels
├── LICENSE # Apache 2.0
├── pyproject.toml
└── README.md
Contributing
Contributions are welcome! Areas where help is particularly valuable:
- Loading pretrained weights once Mamba-3 checkpoints are released
- Performance optimization and profiling
- Test coverage expansion
- Bug reports and feature requests
Please open an issue or submit a pull request on GitHub.
Acknowledgments
Original Mamba Authors:
- Aakash Lahoti, Kevin Y. Li, Berlin Chen, Caitlin Wang, Aviv Bick, J. Zico Kolter, Tri Dao, and Albert Gu for the Mamba3 architecture and the original mamba_ssm implementation
- The State Spaces team for advancing SSM research
JAX Ecosystem:
- The JAX, Flax, and Optax teams at Google for the excellent frameworks
License
This project is licensed under the Apache License 2.0.
Citation
If you use this implementation in your research, please cite the original Mamba3 paper and this JAX implementation:
@misc{lahoti2026mamba3,
title={Mamba-3: Improved Sequence Modeling using State Space Principles},
author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu},
year={2026},
eprint={2603.15569},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2603.15569}
}
@software{mamba3jax,
author = {Cosmo Santoni},
title = {mamba3-jax: Pure JAX Implementation of Mamba3},
year = {2026},
url = {https://github.com/CosmoNaught/mamba3-jax}
}