Pure JAX/Flax NNX implementation of the Mamba3 state-space model with state caching, pretrained weight loading, and time-series forecasting.

Mamba3-JAX

A pure JAX/Flax NNX implementation of the Mamba3 state-space model with SSM state caching, weight loading from the canonical state-spaces layout, causal language modeling, and time-series forecasting.

This is the standalone PyPI package for the Mamba3 implementation authored by Cosmo Santoni, following the same pure-JAX, kernel-free approach as its sibling mamba2-jax.

Supported Models

Mamba3ForCausalLM and Mamba3Forecaster are supported across every canonical configuration: SISO and MIMO (is_mimo, mimo_rank), ngroups, is_outproj_norm, rope_fraction ∈ {0.5, 1.0}, and packed variable-length sequences (cu_seqlens). No pretrained Mamba-3 checkpoints are public yet. The architecture, config, and weight loader match the canonical state-spaces/mamba (Mamba-3) layout, so its state_dict should load directly once weights are released.

Features

  • Flax NNX modules (no legacy init/apply ceremony)
  • SSM + rotary + trapezoidal state caching for O(n) autoregressive generation (no convolution state)
  • MIMO rank-R recurrence, plus ngroups, is_outproj_norm, rope_fraction — every canonical config flag
  • Variable-length / packed sequences (cu_seqlens) for padding-free batched training
  • Causal language modeling (Mamba3ForCausalLM) with tied or untied embeddings
  • Time-series forecasting (Mamba3Forecaster)
  • Weight loading from the canonical state-spaces/mamba (Mamba-3) layout
  • Kernel-verified parity against the reference Triton (SISO) and TileLang (MIMO) implementations
  • Fully compatible with jax.jit, jax.grad, jax.vmap
  • Runs on CPU, GPU (CUDA), and TPU
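
To make the packed-sequence feature concrete, here is a minimal pure-Python sketch of how cu_seqlens can describe sequence boundaries in a single flat stream. It assumes the common convention of cumulative offsets with a leading 0 (batch + 1 entries). The helper names and the toy recurrence are hypothetical stand-ins for illustration and are not part of the mamba3_jax API.

```python
from itertools import accumulate
from typing import List


def make_cu_seqlens(lengths: List[int]) -> List[int]:
    """Cumulative offsets with a leading 0 (assumed convention)."""
    return [0] + list(accumulate(lengths))


def toy_scan(xs: List[float]) -> List[float]:
    """Toy causal recurrence h_t = 0.5 * h_{t-1} + x_t, standing in for a real layer."""
    out, h = [], 0.0
    for x in xs:
        h = 0.5 * h + x
        out.append(h)
    return out


def packed_scan(packed: List[float], cu_seqlens: List[int]) -> List[float]:
    """Run the same recurrence over the packed stream, resetting state at each boundary."""
    starts = set(cu_seqlens[:-1])
    out, h = [], 0.0
    for i, x in enumerate(packed):
        if i in starts:
            h = 0.0  # a new sequence begins here, so no state may leak across
        h = 0.5 * h + x
        out.append(h)
    return out


seqs = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
cu_seqlens = make_cu_seqlens([len(s) for s in seqs])
packed = [x for s in seqs for x in s]              # no padding tokens anywhere
per_sequence = [y for s in seqs for y in toy_scan(s)]

print(cu_seqlens)
print(packed_scan(packed, cu_seqlens) == per_sequence)
```

The point of the sketch is the reset at each boundary: without it, state from one sequence would leak into the next and the packed result would no longer match the per-sequence result.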

State Space Caching

The SSM state cache makes autoregressive generation of n tokens cost O(n) in total, instead of the O(n²) incurred by re-running the full prefix at every step. Mamba-3 carries a fixed-size per-layer state (the cumulative rotary angle, the SSM state, and the previous B/x required by the trapezoidal recurrence; there is no convolution state), so each decode step is O(1) in sequence length. Cached and non-cached generation produce identical tokens.
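
The idea can be illustrated with a toy scalar recurrence in plain Python. This is only a sketch: the two-point update below is a hypothetical stand-in that needs the previous input term, loosely mirroring why a trapezoidal rule must cache the previous B/x. It does not use Mamba-3's actual coefficients, and it omits the rotary angle entirely.

```python
from typing import List, Tuple

State = Tuple[float, float]  # (h, prev_bx): running state and previous input term


def step(state: State, a: float, bx: float) -> Tuple[State, float]:
    """Advance the toy recurrence by one token in O(1) work."""
    h, prev_bx = state
    # Illustrative two-point (trapezoid-like) update; NOT Mamba-3's real coefficients.
    h = a * h + 0.5 * (prev_bx + bx)
    return (h, bx), h


def run(a: float, inputs: List[float], state: State = (0.0, 0.0)) -> Tuple[List[float], State]:
    """Process a whole sequence, returning all outputs and the final state."""
    outputs = []
    for bx in inputs:
        state, y = step(state, a, bx)
        outputs.append(y)
    return outputs, state


a = 0.9
prefix = [1.0, -2.0, 0.5, 3.0]
new_input = 1.5

# Without a cache: re-run the entire prefix plus the new token.
full_outputs, _ = run(a, prefix + [new_input])

# With a cache: prefill once, then advance by a single O(1) step.
_, cache = run(a, prefix)
cache, cached_output = step(cache, a, new_input)

print(cached_output == full_outputs[-1])
```

Because the cached path performs the same arithmetic in the same order, the two results agree exactly in this toy setting. The cache has a fixed size regardless of how long the prefix was.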

Installation

From PyPI

pip install mamba3-jax

From source

git clone https://github.com/CosmoNaught/mamba3-jax.git
cd mamba3-jax
pip install -e ".[dev]"

For loading pretrained weights

pip install "mamba3-jax[pretrained]"

For GPU or TPU support, install the appropriate JAX backend as described in the JAX installation guide.

Usage

Causal Language Model

import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3Config, Mamba3ForCausalLM

cfg = Mamba3Config(vocab_size=1024, hidden_size=256, num_hidden_layers=4,
                   state_size=64, head_dim=32, chunk_size=64)
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))

input_ids = jnp.ones((2, 64), dtype=jnp.int32)
outputs = model(input_ids, labels=input_ids)

print(outputs["logits"].shape)  # (2, 64, 1024)
print(float(outputs["loss"]))

Loading weights

from mamba3_jax import Mamba3ForCausalLM

# Loads from a canonical state-spaces/mamba (Mamba-3) checkpoint or HuggingFace id.
# (No public Mamba-3 checkpoints exist yet; the loader is ready for them.)
model = Mamba3ForCausalLM.from_pretrained("state-spaces/mamba3-...")

Cached generation

import jax.numpy as jnp
from flax import nnx
from mamba3_jax import Mamba3ForCausalLM, Mamba3Config

cfg = Mamba3Config.tiny()
model = Mamba3ForCausalLM(cfg, rngs=nnx.Rngs(0))

# Prefill
prompt = jnp.array([[1, 2, 3]], dtype=jnp.int32)
out = model(prompt)
cache = out["cache"]

# Decode with cache (O(1) per step)
next_token = jnp.array([[4]], dtype=jnp.int32)
out = model(next_token, cache=cache)
print(out["logits"].shape)  # (1, 1, vocab_size)
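
The prefill-then-decode pattern above extends naturally to a greedy generation loop. The sketch below shows the control flow in plain Python with a hypothetical toy_step function standing in for the model call. In real use, the step would wrap model(tokens, cache=cache) and read out["logits"] and out["cache"]. The names greedy_decode and toy_step are assumptions for illustration and are not part of the package.

```python
from typing import Callable, List, Optional, Tuple

StepFn = Callable[[List[int], Optional[int]], Tuple[List[float], int]]

VOCAB_SIZE = 5


def toy_step(tokens: List[int], cache: Optional[int]) -> Tuple[List[float], int]:
    """Hypothetical stand-in for a model call: returns last-position logits and a new cache."""
    total = (cache or 0) + sum(tokens)       # fixed-size "state": a running sum
    target = (total + 1) % VOCAB_SIZE
    logits = [1.0 if v == target else 0.0 for v in range(VOCAB_SIZE)]
    return logits, total


def greedy_decode(step: StepFn, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Prefill on the whole prompt once, then feed back one token per step."""
    logits, cache = step(prompt, None)       # prefill
    generated: List[int] = []
    for _ in range(max_new_tokens):
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(next_token)
        logits, cache = step([next_token], cache)  # O(1) decode step
    return generated


generated = greedy_decode(toy_step, [1, 2, 3], 4)
print(generated)
```

Only the newly chosen token is passed on each iteration; everything the model needs from earlier tokens travels in the cache.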

Time-series forecasting

import jax.numpy as jnp
from mamba3_jax import Mamba3Forecaster, create_random_forecaster

model = create_random_forecaster(input_dim=10, d_model=256, n_layers=4,
                                 output_dim=1, forecast_horizon=24)

x = jnp.ones((8, 100, 10))  # (batch, seq_len, features)
y = model(x)
print(y.shape)  # (8, 24, 1)

Performance

Numerical parity is verified directly against the canonical state-spaces/mamba kernels (Triton for SISO, TileLang for MIMO) using identical random weights. Agreement is at the bf16 level (~1–2% mean relative error, since the kernels accumulate in bf16 while JAX runs in fp32) across ngroups ∈ {1, 2}, is_outproj_norm ∈ {False, True}, and rope_fraction ∈ {0.5, 1.0}. The varlen path is verified both against the kernel and by checking that varlen(packed) == concat(per-sequence). To reproduce on any CUDA GPU, run scripts/verify_mimo_kernel.py. Throughput benchmarks are forthcoming.
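
For intuition about what "bf16 level" means, the following standard-library sketch rounds values to bfloat16 and measures a mean relative error. bfloat16 stores 7 mantissa bits, so a single round-to-nearest step has relative error of at most about 2⁻⁸ (roughly 0.4%). Accumulating many such roundings inside a kernel can plausibly widen the gap toward the percent range. The error metric here is one common definition and is an assumption; the project's verification script may compute it differently.

```python
import struct
from typing import List


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def mean_relative_error(approx: List[float], reference: List[float]) -> float:
    """Mean of elementwise |approx - reference| / |reference| (reference must be nonzero)."""
    return sum(abs(a - r) / abs(r) for a, r in zip(approx, reference)) / len(reference)


reference = [0.1 * i for i in range(1, 101)]
rounded = [to_bf16(v) for v in reference]
err = mean_relative_error(rounded, reference)

print(f"mean relative error after one bf16 rounding: {err:.5f}")
```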

Project Structure

mamba3-jax/
├── mamba3_jax/
│   ├── __init__.py          # Public API
│   ├── modeling.py          # Config, SSD kernels, all model classes
│   └── params.py            # Weight loading & parameter utilities
├── tests/
│   ├── test_mamba3.py       # Comprehensive test suite
│   ├── run_model.py         # Generation demo script
│   └── artifacts/           # Golden parity data & generator
├── scripts/
│   └── verify_mimo_kernel.py  # GPU verification against the canonical kernels
├── LICENSE                  # Apache 2.0
├── pyproject.toml
└── README.md

Contributing

Contributions are welcome! Areas where help is particularly valuable:

  • Loading pretrained weights once Mamba-3 checkpoints are released
  • Performance optimization and profiling
  • Test coverage expansion
  • Bug reports and feature requests

Please open an issue or submit a pull request on GitHub.

Acknowledgments

Original Mamba Authors:

  • Aakash Lahoti, Kevin Y. Li, Berlin Chen, Caitlin Wang, Aviv Bick, J. Zico Kolter, Tri Dao, and Albert Gu for the Mamba3 architecture and the original mamba_ssm implementation
  • The State Spaces team for advancing SSM research

JAX Ecosystem:

  • The JAX, Flax, and Optax teams at Google for the excellent frameworks

License

This project is licensed under the Apache License 2.0.

Citation

If you use this implementation in your research, please cite the original Mamba3 paper and this JAX implementation:

@misc{lahoti2026mamba3,
  title={Mamba-3: Improved Sequence Modeling using State Space Principles},
  author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu},
  year={2026},
  eprint={2603.15569},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2603.15569}
}

@software{mamba3jax,
  author  = {Cosmo Santoni},
  title   = {mamba3-jax: Pure JAX Implementation of Mamba3},
  year    = {2026},
  url     = {https://github.com/CosmoNaught/mamba3-jax}
}
