Mamba-3: Improved Sequence Modeling using State Space Principles
pip install:
pip install mamba3-ssm · version: 0.2.1
A clean, readable, from-scratch PyTorch implementation of Mamba-3 (arXiv:2603.15569). Features CUDA-accelerated SSM scans (about 50× faster than the pure-Python scan) and MSVC/CUDA 12.1 compatibility on Windows.
Installation
pip install mamba3-ssm
Quick Start
import torch
from mamba3_ssm import Mamba3, MambaLMHeadModel, MambaConfig
model = Mamba3(d_model=256, d_state=64, expand=2, headdim=32, is_mimo=True, mimo_rank=4)
x = torch.randn(2, 128, 256)
y = model(x) # (2, 128, 256)
# Autoregressive decode
angle, state, prev = model.allocate_inference_cache(2)
out, angle, state, prev = model.step(torch.randn(2, 256), angle, state, prev)
# Full language model
cfg = MambaConfig(d_model=1536, n_layer=20, vocab_size=50000,
                  ssm_cfg={"d_state": 64, "is_mimo": True, "mimo_rank": 4})
lm = MambaLMHeadModel(cfg)
logits = lm(torch.randint(0, 50000, (1, 512)))
Training
# Train on TinyStories with preset config
python train.py --dataset tinystories --preset small --epochs 3
# Custom data
python train.py --dataset custom --data-path myfile.txt --preset medium --epochs 1
# Resume from checkpoint
python train.py --dataset tinystories --resume checkpoints/best.pt
# Generate text from trained model
python generate.py --checkpoint checkpoints/best.pt --prompt "Once upon a time"
Available presets: small (112M, seq=512), medium (306M, seq=256), large (367M, seq=256). See mamba3_ssm/presets.py for details.
Performance
Acceleration Tiers
The SSM scan — the core bottleneck — uses a tiered acceleration strategy:
| Tier | Speedup | Availability |
|---|---|---|
| CUDA kernel | ~50× vs Python | Requires MSVC + CUDA 12.1+ |
| JIT (torch.jit.script) | ~2–3× vs Python | All platforms, no compilation |
| Pure Python | 1× | Always works |
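The tier selection can be pictured as a try-in-order dispatch. The following is a minimal sketch, not the package's actual code: the loader names (`load_cuda_scan`, `load_jit_scan`, `load_python_scan`) and the `select_scan_backend` helper are hypothetical, the scan is a toy scalar recurrence, and the CUDA and JIT loaders are stubbed to fail so the example runs anywhere.

```python
from typing import Callable, List, Sequence, Tuple

Scan = Callable[[Sequence[float], float], List[float]]


def python_scan(xs: Sequence[float], decay: float) -> List[float]:
    """Reference tier: plain Python loop for h_t = decay * h_{t-1} + x_t."""
    h, out = 0.0, []
    for x in xs:
        h = decay * h + x
        out.append(h)
    return out


def load_cuda_scan() -> Scan:
    # Hypothetical loader: a real one would compile or import a CUDA extension.
    raise RuntimeError("no CUDA toolchain in this sketch")


def load_jit_scan() -> Scan:
    # Hypothetical loader: a real one might return a torch.jit.script function.
    raise RuntimeError("JIT tier not available in this sketch")


def load_python_scan() -> Scan:
    return python_scan


def select_scan_backend(
    loaders: Sequence[Tuple[str, Callable[[], Scan]]],
) -> Tuple[str, Scan]:
    """Return the first tier whose loader succeeds, fastest tier first."""
    errors = []
    for name, loader in loaders:
        try:
            return name, loader()
        except Exception as exc:  # any failure falls through to the next tier
            errors.append(f"{name}: {exc}")
    raise RuntimeError("no scan backend available: " + "; ".join(errors))


TIERS = [("cuda", load_cuda_scan), ("jit", load_jit_scan), ("python", load_python_scan)]
backend_name, scan = select_scan_backend(TIERS)
result = scan([1.0, 1.0, 1.0], 0.5)
```

Because the pure-Python tier never fails to load, the dispatch always returns a working scan; only the speed differs between tiers.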
Training Estimates (CUDA + JIT, RTX 4060 8GB Laptop)
SSM scan accelerated with CUDA SISO kernel (50× speedup) and JIT MIMO fallback:
| Preset | Params | Seq len | VRAM | Micro-batch time | Tok/s | Steps/epoch | TinyStories, 3 epochs |
|---|---|---|---|---|---|---|---|
| small | 112M | 512 | ~5.6 GB | 117 ms | 8,780 | 5,798 | ~4 h 30 m |
| medium | 306M | 256 | ~7.2 GB | 125 ms | 2,040 | 11,596 | ~19 h 24 m |
| large | 367M | 256 | ~8.6 GB | 144 ms | 1,777 | 11,596 | ~22 h 16 m |
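Note: All three presets were benchmarked on an 8 GB laptop GPU, although the large preset's ~8.6 GB estimate is slightly above the card's dedicated VRAM. Training small on TinyStories for 3 epochs completes in under 5 hours. Medium and large use the JIT MIMO fallback; a fixed CUDA MIMO kernel would further improve throughput.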
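The wall-clock column can be cross-checked from the other columns of the table. The sketch below assumes each optimizer step accumulates several micro-batches; the accumulation counts (8 for small, 16 for medium and large) are not stated in this README and are inferred here only because they make the table's numbers agree.

```python
def training_hours(steps_per_epoch: int, epochs: int,
                   micro_batch_ms: float, grad_accum: int) -> float:
    """Estimate wall-clock hours: optimizer steps x micro-batches x time each."""
    seconds = steps_per_epoch * epochs * grad_accum * micro_batch_ms / 1000.0
    return seconds / 3600.0


# Steps per epoch and micro-batch time come from the table above;
# the third value (grad_accum) is an assumption, not a documented setting.
presets = {
    "small": (5798, 117.0, 8),
    "medium": (11596, 125.0, 16),
    "large": (11596, 144.0, 16),
}

estimates = {
    name: training_hours(steps, 3, ms, accum)
    for name, (steps, ms, accum) in presets.items()
}

for name, hours in estimates.items():
    print(f"{name}: {hours:.1f} h")
```

Under these assumptions the estimates land within a few minutes of the times listed in the table.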
VRAM at bf16 (batch=2, seq_len=512 with grad_accum)
| Preset | Params | d_model | n_layer | VRAM |
|---|---|---|---|---|
| small | 112M | 1024 | 16 | ~5.6 GB |
| medium | 306M | 1536 | 20 | ~7.2 GB |
| large | 367M | 1536 | 24 | ~8.6 GB |
Core Ideas
1. Exponential-Trapezoidal Discretization
Mamba-2 used Zero-Order Hold (first-order). Mamba-3 uses the trapezoidal rule:
h_t = exp(A·dt_t) · h_{t-1} + dt_t · σ(trap_t) · (B_t·x_t + B_{t-1}·x_{t-1}) / 2
A learned trap gate blends between Euler (trap≈0) and full trapezoidal (trap≈1).
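To make the update concrete, here is a scalar transcription of the recurrence exactly as written above, using one state element and plain Python floats. It is an illustrative sketch, not the package's ssm_scan_siso: the function name `trapezoidal_scan` and its argument layout are assumptions, and the term B_{t-1}·x_{t-1} is taken to be zero before the first step.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def trapezoidal_scan(x: Sequence[float], B: Sequence[float], dt: Sequence[float],
                     trap: Sequence[float], A: float) -> List[float]:
    """Scalar version of
    h_t = exp(A*dt_t)*h_{t-1} + dt_t*sigmoid(trap_t)*(B_t*x_t + B_{t-1}*x_{t-1})/2.
    """
    h = 0.0
    prev_bx = 0.0  # B_{t-1} * x_{t-1}; assumed zero before the first step
    states = []
    for x_t, b_t, dt_t, trap_t in zip(x, B, dt, trap):
        bx = b_t * x_t
        h = math.exp(A * dt_t) * h + dt_t * sigmoid(trap_t) * (bx + prev_bx) / 2.0
        states.append(h)
        prev_bx = bx
    return states


# With A = 0 (no decay) and a saturated gate, the update reduces to the
# classical trapezoidal rule applied to a constant input of 1.
T = 4
states = trapezoidal_scan(x=[1.0] * T, B=[1.0] * T, dt=[1.0] * T,
                          trap=[50.0] * T, A=0.0)
```

In this special case each step adds the average of the current and previous inputs, so the state grows by one half on the first step and by one on every step after it.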
2. Complex-Valued (Rotary) State Space
Applies RoPE to B and C projections, giving the state an effective complex-valued structure for tracking phase-dependent dependencies.
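The link between rotary embeddings and a complex-valued state can be checked numerically: rotating a pair of real channels by an angle θ is the same as multiplying the complex number formed from that pair by e^{iθ}. The sketch below is illustrative only; `rotate_pair` is a hypothetical helper and is not the package's apply_rope.

```python
import cmath
import math
from typing import Tuple


def rotate_pair(a: float, b: float, theta: float) -> Tuple[float, float]:
    """Rotate the 2-D vector (a, b) by theta, as RoPE does per channel pair."""
    c, s = math.cos(theta), math.sin(theta)
    return a * c - b * s, a * s + b * c


a, b, theta = 0.3, -1.2, 0.7

# Real-valued rotation of the channel pair.
rotated = rotate_pair(a, b, theta)

# The same operation written as complex multiplication by e^{i*theta}.
z = complex(a, b) * cmath.exp(1j * theta)

# Rotations compose by adding angles.
twice = rotate_pair(*rotate_pair(a, b, theta), theta)
direct = rotate_pair(a, b, 2 * theta)
```

The real and imaginary parts of `z` match the two components of `rotated`, which is the sense in which a rotated real pair behaves like one complex channel.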
3. MIMO Formulation
Reuses a shared (H, D) state for R rank streams instead of SISO's (H, P, D) outer product:
| | SISO | MIMO |
|---|---|---|
| State shape | (H, P, D) | (H, D) |
| Decode FLOPs/byte | Low (memory-bound) | R× higher |
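For a sense of scale, the two state shapes can be compared by element count. The sketch assumes the common Mamba-2 convention H = expand · d_model / headdim and P = headdim, which this README does not state explicitly, and plugs in the Quick Start values. `state_elements` is a hypothetical helper.

```python
from typing import Tuple


def state_elements(d_model: int, expand: int, headdim: int,
                   d_state: int) -> Tuple[int, int]:
    """Per-sequence state sizes for the (H, P, D) and (H, D) shapes."""
    d_inner = expand * d_model
    H = d_inner // headdim   # number of heads (assumed convention)
    P = headdim              # per-head channel width (assumed convention)
    D = d_state
    siso = H * P * D         # state shape (H, P, D)
    mimo = H * D             # state shape (H, D)
    return siso, mimo


siso, mimo = state_elements(d_model=256, expand=2, headdim=32, d_state=64)
print(f"SISO: {siso} elements, MIMO: {mimo} elements, ratio: {siso // mimo}x")
```

Under these assumptions the ratio between the two counts is simply P, the head dimension.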
CUDA Acceleration
The SSM scan is accelerated with a fused CUDA kernel when the MSVC compiler is available:
- SISO: Fully fused kernel — one block per (batch, head), P threads hold all D state-values in registers, B/C loaded via shared memory each timestep. Replaces the Python for-loop entirely.
- MIMO: Split design — outer einsums in PyTorch, inner state scan in CUDA. Uses tree-reduction over D for the output.
- JIT fallback: If MSVC is unavailable, torch.jit.script provides a ~2–3× speedup with no compilation needed.
To compile the CUDA kernel, install Visual Studio Build Tools with MSVC and run any scan function (compilation happens automatically on first call).
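The tree reduction mentioned for the MIMO output can be illustrated on the host. The sketch below sums D values in log2(D) pairwise rounds, mirroring how threads in a block would combine partial sums. It is a plain-Python illustration with a hypothetical `tree_reduce_sum` helper, not the kernel itself, and it assumes D is a power of two.

```python
from typing import List, Sequence


def tree_reduce_sum(values: Sequence[float]) -> float:
    """Sum values with a pairwise (tree) reduction.

    Each round adds element i + stride into element i, halving the number
    of active elements, as a shared-memory reduction on a GPU would.
    """
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch assumes a non-empty power-of-two length")
    buf: List[float] = list(values)
    stride = n // 2
    while stride > 0:
        for i in range(stride):  # on a GPU these iterations run in parallel
            buf[i] += buf[i + stride]
        stride //= 2
    return buf[0]


# Example output of the form y = sum over d of C[d] * h[d] for one head.
C = [0.5, -1.0, 2.0, 0.25]
h = [4.0, 3.0, 1.0, 8.0]
y = tree_reduce_sum([c * s for c, s in zip(C, h)])
```

The result equals the ordinary sum of the products; the tree ordering only changes how the additions are grouped so that they can proceed in parallel.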
API Reference
Mamba3(d_model, d_state=128, expand=2, headdim=64, ngroups=1, rope_fraction=0.5, is_mimo=False, mimo_rank=4)
| Method | Description |
|---|---|
| forward(u) | (B, L, d_model) → (B, L, d_model) |
| step(u, angle, state, prev) | Single decode step, returns updated states |
| allocate_inference_cache(B) | Allocate zero states for decoding |
MambaLMHeadModel(config)
| Field | Default | Description |
|---|---|---|
| d_model | 2560 | Hidden size |
| n_layer | 64 | Number of blocks |
| vocab_size | 50277 | Padded to multiple of 8 |
| ssm_cfg | {} | Passed to Mamba3 |
| d_intermediate | 0 | SwiGLU MLP (0 = disabled) |
| tie_embeddings | True | Tie LM head to embedding |
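The padding of vocab_size can be checked with a one-line round-up. This is a sketch of the arithmetic only; `pad_to_multiple` is a hypothetical helper name, not a function exported by the package.

```python
def pad_to_multiple(n: int, multiple: int = 8) -> int:
    """Round n up to the nearest multiple (ceiling division, then scale back)."""
    return -(-n // multiple) * multiple


padded = pad_to_multiple(50277)  # default vocab_size from the table
print(padded)
```

For the default of 50277 this gives 50280, while a value that is already a multiple of 8, such as 50000, is left unchanged.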
Exports
from mamba3_ssm import (
    Mamba3, MambaLMHeadModel, MambaConfig, SSMConfig,
    RMSNorm, apply_rope, ssm_scan_siso, ssm_scan_mimo,
    CONFIGS,
)
Testing
python -m mamba3_ssm.tests
10/10 checks: shapes, numerical consistency (step-by-step == forward), gradient flow, parameter counting, edge cases.
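The step-by-step versus forward consistency check can be sketched on a toy scalar recurrence. The example below is not taken from mamba3_ssm.tests: it uses a stand-in recurrence h_t = a·h_{t-1} + x_t, computes it once over the whole sequence in closed form and once one step at a time with carried state, and compares the two.

```python
from typing import List, Sequence, Tuple


def forward(xs: Sequence[float], a: float) -> List[float]:
    """Whole-sequence form: h_t = sum over k <= t of a**(t-k) * x_k."""
    return [sum(a ** (t - k) * xs[k] for k in range(t + 1)) for t in range(len(xs))]


def step(x_t: float, state: float, a: float) -> Tuple[float, float]:
    """Single decode step: returns (output, updated state)."""
    new_state = a * state + x_t
    return new_state, new_state


xs = [0.5, -1.0, 2.0, 0.25, 1.5]
a = 0.9

full = forward(xs, a)

state = 0.0  # zero initial state, like a freshly allocated cache
stepped = []
for x_t in xs:
    out, state = step(x_t, state, a)
    stepped.append(out)

max_abs_diff = max(abs(f - s) for f, s in zip(full, stepped))
```

A check of this kind passes when `max_abs_diff` stays within floating-point tolerance, which confirms that the decode path and the full-sequence path implement the same recurrence.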
Project Structure
mamba3_ssm/
├── __init__.py # Public API
├── config.py # MambaConfig / SSMConfig
├── ops.py # RMSNorm, RoPE, SSM scans (CUDA/JIT/Python)
├── cuda_backend.py # CUDA kernel compilation + Python wrappers
├── layer.py # Mamba3 module (forward + step)
├── block.py # MambaBlock, MambaLMHeadModel
├── presets.py # RTX 4060 benchmarked configs
├── tests.py # 10 sanity checks
└── utils.py # Parameter counting
Dependencies
torch>=2.0
einops>=0.7
Optional: datasets for auto-downloading TinyStories/Wikitext, wandb for logging.
Changelog
v0.2.1 (2026-06-28)
- Fix autocast dtype override: MIMO CUDA einsum pre/post-mix now wrapped in autocast(enabled=False) to prevent bf16 autocast from overriding float32 tensors
- Benchmark training times: Added measured throughput for all three presets on RTX 4060 (small: 4.5h, medium: 19h, large: 22h for TinyStories×3ep)
v0.2.0 (2026-06-28)
- CUDA-accelerated SSM scan: Fused SISO kernel (50× speedup); MIMO split kernel
- JIT fallback: torch.jit.script gives a 1.9–2.7× speedup without CUDA compilation
- Bug fix: double sigmoid: Removed redundant sigmoid on trap gate (forward path)
- Bug fix: lm_head dimension: Swapped to Linear(d_model, vocab_size)
- MSVC 14.44 + CUDA 12.1 compatibility: Added _ALLOW_COMPILER_AND_STL_VERSION_MISMATCH workaround
v0.1.2 (2026-05-31)
- Fix tokenizer cache loading bug, checkpoint resume, steps_per_epoch calculation
- Optimized SSM scan with pre-computed decay/trap factors
v0.1.1 (2026-05-31)
- Add --preset flag, generate.py, presets.CONFIGS, RTX 4060 benchmarks
v0.1.0 (2026-05-31)
- Initial release — SISO & MIMO Mamba-3, 10/10 tests passing
License
MIT
References
- Lahoti et al., Mamba-3: Improved Sequence Modeling using State Space Principles, 2026. arXiv:2603.15569
- Official implementation: state-spaces/mamba
File details
Details for the file mamba3_ssm-0.2.1.tar.gz.
File metadata
- Download URL: mamba3_ssm-0.2.1.tar.gz
- Upload date:
- Size: 24.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 471a6a8c5929e5232886251e4ab5b62fbed29add609190387abcfe980f6d6e1b |
| MD5 | 3cd66faeeee31a362f558ca0c237918e |
| BLAKE2b-256 | 8c80afa9663fd0efba4a668fab9f16a647b8d4803a2bd76b6a07af2ef518842e |
File details
Details for the file mamba3_ssm-0.2.1-py3-none-any.whl.
File metadata
- Download URL: mamba3_ssm-0.2.1-py3-none-any.whl
- Upload date:
- Size: 22.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 425400d0b4056e588926c599b80cdc86f7601a1b5eaceaca6fe15dfe5b5ea370 |
| MD5 | 4ebd3a9c934f478af06ac4a0bf88d3ae |
| BLAKE2b-256 | fb80ffa4fa877507411a6706fb48a006f6f8b9475c0093efcae949305e9acf86 |
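A downloaded file can be checked against the SHA256 digests listed above with the standard library. The sketch below is a generic verification helper, not part of mamba3_ssm; the function names `sha256_of` and `matches` are chosen here for illustration.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 and return the hex digest."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: Path, expected_hex: str) -> bool:
    """Compare a file's digest against a published hex digest."""
    return sha256_of(path) == expected_hex.strip().lower()
```

For example, calling `matches(Path("mamba3_ssm-0.2.1.tar.gz"), ...)` with the SHA256 value from the source distribution table should return True for an intact download.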