Pure JAX/Flax NNX implementation of the Mamba2 state-space model with state caching, pretrained weight loading, and time-series forecasting.
Mamba2-JAX
A pure JAX/Flax NNX implementation of the Mamba2 state-space model with SSM state caching, pretrained weight loading from HuggingFace, causal language modeling, and time-series forecasting.
This is the standalone PyPI package for the Mamba2 implementation authored by Cosmo Santoni and merged into Google's JAX ML Bonsai library. The upstream source lives at bonsai/models/mamba2.
Supported Models
Features
- Flax NNX modules (no legacy init/apply ceremony)
- SSM + convolution state caching for O(n) autoregressive generation
- Pretrained weight loading from HuggingFace (state-spaces/mamba2-130m, etc.)
- Causal language modeling (Mamba2ForCausalLM) with tied or untied embeddings
- Time-series forecasting (Mamba2Forecaster)
- Golden parity tests against the reference mamba_ssm PyTorch implementation
- Fully compatible with jax.jit, jax.grad, jax.vmap
- Runs on CPU, GPU (CUDA), and TPU
State Space Caching
The SSM state cache enables O(n) autoregressive generation instead of O(n^2) re-computation: each new token updates a fixed-size state rather than re-running the model over the whole prefix. The project's benchmark shows a ~30x speedup on the 780M parameter model running on a TPU v6e when caching is enabled.
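To make the complexity argument concrete, here is a minimal sketch using a toy diagonal linear recurrence. It is not the package's SSD kernel and does not model the convolution state; the names ssm_step and run_full and the parameters a, b, c are illustrative assumptions. It only shows that one step from a cached state gives the same output as re-running the full prefix.

```python
import jax
import jax.numpy as jnp


def ssm_step(state, x_t, a, b, c):
    """One recurrent step of a toy diagonal SSM: update the state, then read out."""
    new_state = a * state + b * x_t
    y_t = jnp.sum(c * new_state)
    return new_state, y_t


def run_full(xs, a, b, c):
    """Process a whole sequence from a zero state; return final state and all outputs."""
    init = jnp.zeros_like(a)
    return jax.lax.scan(lambda s, x: ssm_step(s, x, a, b, c), init, xs)


# Hypothetical toy parameters (not taken from any Mamba2 checkpoint).
a = jnp.array([0.9, 0.5, 0.1])    # per-channel decay
b = jnp.array([1.0, 0.5, 0.25])   # input projection
c = jnp.array([0.3, 0.3, 0.4])    # output projection
xs = jnp.arange(1.0, 9.0)         # 8 scalar inputs

# Without a cache: re-run the full prefix to obtain the newest output.
_, ys_full = run_full(xs, a, b, c)

# With a cache: keep the state after the first 7 inputs, then take a single step.
cached_state, _ = run_full(xs[:-1], a, b, c)
_, y_last = ssm_step(cached_state, xs[-1], a, b, c)

print(jnp.allclose(ys_full[-1], y_last))
```

Re-running the prefix costs work proportional to its length for every generated token, which adds up to O(n^2) over n tokens. The cached path does a constant amount of work per token, hence O(n) overall.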
Installation
From PyPI
pip install mamba2-jax
From source
git clone https://github.com/CosmoNaught/mamba2-jax.git
cd mamba2-jax
pip install -e ".[dev]"
For loading pretrained weights
pip install "mamba2-jax[pretrained]"
For GPU or TPU support, install the appropriate JAX backend as described in the JAX installation guide.
Usage
Causal Language Model
import jax.numpy as jnp
from flax import nnx
from mamba2_jax import Mamba2Config, Mamba2ForCausalLM
cfg = Mamba2Config(vocab_size=1024, hidden_size=256, num_hidden_layers=4,
                   state_size=64, head_dim=32, chunk_size=64)
model = Mamba2ForCausalLM(cfg, rngs=nnx.Rngs(0))
input_ids = jnp.ones((2, 64), dtype=jnp.int32)
outputs = model(input_ids, labels=input_ids)
print(outputs["logits"].shape) # (2, 64, 1024)
print(float(outputs["loss"]))
Loading pretrained weights
from mamba2_jax import Mamba2ForCausalLM
model = Mamba2ForCausalLM.from_pretrained("state-spaces/mamba2-130m")
Cached generation
import jax.numpy as jnp
from flax import nnx
from mamba2_jax import Mamba2ForCausalLM, Mamba2Config
cfg = Mamba2Config.tiny()
model = Mamba2ForCausalLM(cfg, rngs=nnx.Rngs(0))
# Prefill
prompt = jnp.array([[1, 2, 3]], dtype=jnp.int32)
out = model(prompt)
cache = out["cache"]
# Decode with cache (O(1) per step)
next_token = jnp.array([[4]], dtype=jnp.int32)
out = model(next_token, cache=cache)
print(out["logits"].shape) # (1, 1, vocab_size)
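The prefill and decode calls above can be chained into a simple generation loop. The sketch below is a hypothetical helper, not part of the package: greedy_decode is an invented name, and it assumes that the cached call also returns an updated "cache" entry alongside "logits", which the snippet above does not confirm.

```python
import jax.numpy as jnp


def greedy_decode(model, prompt, max_new_tokens):
    """Greedy decoding: one prefill pass, then cached single-token steps.

    Assumes model(ids) and model(ids, cache=...) both return a dict with
    "logits" of shape (batch, seq_len, vocab_size) and "cache".
    """
    out = model(prompt)                      # prefill over the whole prompt
    tokens = [prompt]
    for step in range(max_new_tokens):
        # Most likely token at the last position, kept as shape (batch, 1).
        next_token = jnp.argmax(out["logits"][:, -1, :], axis=-1)[:, None]
        next_token = next_token.astype(prompt.dtype)
        tokens.append(next_token)
        if step + 1 < max_new_tokens:
            # Feed only the new token; the cache carries the history.
            out = model(next_token, cache=out["cache"])
    return jnp.concatenate(tokens, axis=1)
```

With the model and prompt defined above, a call such as greedy_decode(model, prompt, max_new_tokens=16) would return an array of shape (1, 19) under these assumptions. Sampling strategies other than argmax would slot in at the next_token line.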
Time-series forecasting
import jax.numpy as jnp
from mamba2_jax import Mamba2Forecaster, create_random_forecaster
model = create_random_forecaster(input_dim=10, d_model=256, n_layers=4,
                                 output_dim=1, forecast_horizon=24)
x = jnp.ones((8, 100, 10)) # (batch, seq_len, features)
y = model(x)
print(y.shape) # (8, 24, 1)
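The forecaster consumes inputs shaped (batch, seq_len, features) and returns (batch, forecast_horizon, output_dim). One way to build matching training pairs from a raw series is a sliding window, sketched below. The helper make_windows is hypothetical and not part of the package, and taking the target from the first feature column is an assumption made only for illustration.

```python
import jax.numpy as jnp


def make_windows(series, seq_len, horizon):
    """Slice a (time, features) array into (inputs, targets) pairs.

    inputs:  (num_windows, seq_len, features)
    targets: (num_windows, horizon, 1), taken from the first feature column.
    """
    num_windows = series.shape[0] - seq_len - horizon + 1
    starts = jnp.arange(num_windows)[:, None]
    in_idx = starts + jnp.arange(seq_len)[None, :]
    out_idx = starts + seq_len + jnp.arange(horizon)[None, :]
    inputs = series[in_idx]                # gather the context windows
    targets = series[out_idx][..., :1]     # future values of feature 0
    return inputs, targets


# Synthetic series: 200 time steps, 10 features.
series = jnp.arange(200 * 10, dtype=jnp.float32).reshape(200, 10)
x, y = make_windows(series, seq_len=100, horizon=24)
print(x.shape, y.shape)  # (77, 100, 10) (77, 24, 1)
```

Batches drawn from x and y then line up with the input and output shapes in the example above.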
Performance
Benchmarked on a TPU v6e with the state-spaces/mamba2-130m checkpoint:
Project Structure
mamba2-jax/
├── mamba2_jax/
│ ├── __init__.py # Public API
│ ├── modeling.py # Config, SSD kernel, all model classes
│ └── params.py # Weight loading & parameter utilities
├── tests/
│ ├── test_mamba2.py # Comprehensive test suite
│ ├── run_model.py # Generation demo script
│ └── artifacts/ # Golden parity data
├── docs/ # Benchmark figures
├── LICENSE # Apache 2.0
├── pyproject.toml
└── README.md
Contributing
Contributions are welcome! Areas where help is particularly valuable:
- Performance optimization and profiling
- Test coverage expansion
- Bug reports and feature requests
Please open an issue or submit a pull request on GitHub.
Acknowledgments
Original Mamba2 Authors:
- Tri Dao and Albert Gu for the Mamba2 architecture and the original mamba_ssm implementation
- The State Spaces team for advancing SSM research
JAX Ecosystem:
- The JAX, Flax, and Optax teams at Google for the excellent frameworks
License
This project is licensed under the Apache License 2.0.
Citation
If you use this implementation in your research, please cite the original Mamba2 paper and this JAX implementation:
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms
Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}
@software{mamba2jax,
author = {Cosmo Santoni},
title = {mamba2-jax: Pure JAX Implementation of Mamba2},
year = {2025},
url = {https://github.com/CosmoNaught/mamba2-jax}
}