Mamba-3: Improved Sequence Modeling using State Space Principles

Version 0.2.1 · Python 3.10+ · License: MIT

A clean, readable, from-scratch PyTorch implementation of Mamba-3 (arXiv:2603.15569). It features a CUDA-accelerated SSM scan (about 50× faster than the pure-Python loop) and MSVC/CUDA 12.1 compatibility on Windows.

Installation

pip install mamba3-ssm

Quick Start

import torch
from mamba3_ssm import Mamba3, MambaLMHeadModel, MambaConfig

model = Mamba3(d_model=256, d_state=64, expand=2, headdim=32, is_mimo=True, mimo_rank=4)
x = torch.randn(2, 128, 256)
y = model(x)  # (2, 128, 256)

# Autoregressive decode
angle, state, prev = model.allocate_inference_cache(2)
out, angle, state, prev = model.step(torch.randn(2, 256), angle, state, prev)

# Full language model
cfg = MambaConfig(d_model=1536, n_layer=20, vocab_size=50000,
                  ssm_cfg={"d_state": 64, "is_mimo": True, "mimo_rank": 4})
lm = MambaLMHeadModel(cfg)
logits = lm(torch.randint(0, 50000, (1, 512)))

Training

# Train on TinyStories with preset config
python train.py --dataset tinystories --preset small --epochs 3

# Custom data
python train.py --dataset custom --data-path myfile.txt --preset medium --epochs 1

# Resume from checkpoint
python train.py --dataset tinystories --resume checkpoints/best.pt

# Generate text from trained model
python generate.py --checkpoint checkpoints/best.pt --prompt "Once upon a time"

Available presets: small (112M, seq=512), medium (306M, seq=256), large (367M, seq=256). See mamba3_ssm/presets.py for details.

Performance

Acceleration Tiers

The SSM scan — the core bottleneck — uses a tiered acceleration strategy:

| Tier | Speedup | Availability |
|------|---------|--------------|
| CUDA kernel | ~50× vs Python | Requires MSVC + CUDA 12.1+ |
| JIT (torch.jit.script) | ~2–3× vs Python | All platforms, no compilation |
| Pure Python | 1× (baseline) | Always works |
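
The tiered strategy above amounts to trying the fastest backend first and falling back when it cannot be loaded. A minimal sketch of that pattern follows. The names `pick_backend`, `load_cuda`, `load_jit`, and `load_python` are hypothetical and are not taken from the package, whose actual dispatch logic is not shown in this README. The CUDA and JIT loaders fail on purpose here to exercise the fallback.

```python
def pick_backend(candidates):
    """Return (name, scan_fn) for the first backend whose loader succeeds."""
    for name, loader in candidates:
        try:
            return name, loader()
        except Exception:
            continue  # this tier is unavailable, try the next one
    raise RuntimeError("no backend available")


def load_cuda():
    # Hypothetical: simulate a machine without a usable compiler.
    raise RuntimeError("compiler not found")


def load_jit():
    # Hypothetical: simulate the JIT tier being unavailable as well.
    raise ImportError("jit unavailable")


def load_python():
    # Plain-Python linear scan: h_t = decay_t * h_{t-1} + u_t
    def scan(decay, inputs):
        h, out = 0.0, []
        for a, u in zip(decay, inputs):
            h = a * h + u
            out.append(h)
        return out
    return scan


name, scan = pick_backend(
    [("cuda", load_cuda), ("jit", load_jit), ("python", load_python)]
)
result = scan([0.5, 0.5], [1.0, 1.0])
```

Ordering the candidates from fastest to slowest keeps the calling code identical on every platform; only the selected tier changes.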

Training Estimates (CUDA + JIT, RTX 4060 8GB Laptop)

SSM scan accelerated with CUDA SISO kernel (50× speedup) and JIT MIMO fallback:

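| Preset | Params | Seq | VRAM | Micro-batch time | Tok/s | Steps/epoch | TinyStories, 3 epochs |
|--------|--------|-----|------|------------------|-------|-------------|-----------------------|
| small | 112M | 512 | ~5.6 GB | 117 ms | 8,780 | 5,798 | ~4 h 30 m |
| medium | 306M | 256 | ~7.2 GB | 125 ms | 2,040 | 11,596 | ~19 h 24 m |
| large | 367M | 256 | ~8.6 GB | 144 ms | 1,777 | 11,596 | ~22 h 16 m |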

Note: All three presets are reported to run on an 8 GB laptop GPU, although the ~8.6 GB listed for large exceeds 8 GB of dedicated VRAM. Training small on TinyStories for 3 epochs completes in under 5 hours. Medium and large use the JIT MIMO fallback; a fixed CUDA MIMO kernel would further improve throughput.
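
The columns of the training table can be cross-checked against each other with simple arithmetic. The sketch below assumes that the millisecond figure is the wall time of one micro-batch and that the steps figure counts optimizer steps per epoch; both readings are assumptions, since the README does not define the columns. It derives the implied tokens per micro-batch and the implied seconds per step from the reported numbers only.

```python
# Figures copied from the training table; total_s is the 3-epoch wall time.
rows = {
    "small":  {"ms": 117, "tok_s": 8780, "steps_ep": 5798,  "total_s": 4 * 3600 + 30 * 60},
    "medium": {"ms": 125, "tok_s": 2040, "steps_ep": 11596, "total_s": 19 * 3600 + 24 * 60},
    "large":  {"ms": 144, "tok_s": 1777, "steps_ep": 11596, "total_s": 22 * 3600 + 16 * 60},
}

EPOCHS = 3
derived = {}
for name, r in rows.items():
    derived[name] = {
        # tokens processed in one micro-batch = throughput * micro-batch time
        "tokens_per_micro_batch": r["tok_s"] * r["ms"] / 1000,
        # wall time divided by the total number of steps over all epochs
        "seconds_per_step": r["total_s"] / (EPOCHS * r["steps_ep"]),
    }

for name, d in derived.items():
    print(f"{name}: {d['tokens_per_micro_batch']:.0f} tokens/micro-batch, "
          f"{d['seconds_per_step']:.2f} s/step")
```

Under these assumptions the implied micro-batch is close to 1,024 tokens for small and close to 256 tokens for medium and large, and each step takes several micro-batch times, which would be consistent with gradient accumulation.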

VRAM at bf16 (batch=2, seq_len=512 with grad_accum)

| Preset | Params | d_model | n_layer | VRAM |
|--------|--------|---------|---------|------|
| small | 112M | 1024 | 16 | ~5.6 GB |
| medium | 306M | 1536 | 20 | ~7.2 GB |
| large | 367M | 1536 | 24 | ~8.6 GB |

Core Ideas

1. Exponential-Trapezoidal Discretization

Mamba-2 used Zero-Order Hold (first-order). Mamba-3 uses the trapezoidal rule:

h_t = exp(A·dt_t) · h_{t-1} + dt_t · σ(trap_t) · (B_t·x_t + B_{t-1}·x_{t-1}) / 2

The learned trap gate blends between Euler (trap≈0) and full trapezoidal (trap≈1).
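
To make the recurrence concrete, here is a pure-Python sketch that transcribes the formula exactly as written above for one scalar input channel with a D-dimensional state. It is an illustration, not the package's scan: the function name `trapezoidal_scan`, the scalar `A`, and the choice of a zero `B_{t-1}·x_{t-1}` term before the first step are all assumptions. In this literal form the gate `σ(trap_t)` scales the whole input term.

```python
import math


def trapezoidal_scan(A, dt, trap, B, x):
    """Run h_t = exp(A*dt_t)*h_{t-1}
                 + dt_t*sigmoid(trap_t)*(B_t*x_t + B_{t-1}*x_{t-1})/2.

    A: scalar (negative for a decaying state); dt, trap, x: length-L lists;
    B: length-L list of length-D lists. Returns the list of states h_1..h_L.
    """
    D = len(B[0])
    h = [0.0] * D
    prev = [0.0] * D  # assumed: B_{t-1}*x_{t-1} is zero before the first step
    states = []
    for t in range(len(x)):
        decay = math.exp(A * dt[t])
        gate = 1.0 / (1.0 + math.exp(-trap[t]))  # sigmoid of the trap logit
        cur = [b * x[t] for b in B[t]]           # B_t * x_t
        h = [decay * hi + dt[t] * gate * (c + p) / 2.0
             for hi, c, p in zip(h, cur, prev)]
        prev = cur
        states.append(h)
    return states


states = trapezoidal_scan(
    A=-1.0,
    dt=[0.1, 0.1],
    trap=[0.0, 0.0],
    B=[[1.0, 2.0], [1.0, 2.0]],
    x=[1.0, 0.0],
)
```

Note that the second state is non-zero even though `x[1]` is zero: the previous input still contributes through the `B_{t-1}·x_{t-1}` term, which is what distinguishes this update from a single-point rule.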

2. Complex-Valued (Rotary) State Space

Applies RoPE to B and C projections, giving the state an effective complex-valued structure for tracking phase-dependent dependencies.
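
The link between rotary embeddings and complex-valued state can be seen in a small sketch: rotating consecutive pairs of a vector by an angle is the same as multiplying complex numbers by e^{iθ}, and the inner product of two rotated vectors depends only on the difference of their angles. The helper `rotate_pairs` below is hypothetical and uses a single angle for every pair for simplicity; it is not the package's `apply_rope`, whose signature is not documented here.

```python
import math


def rotate_pairs(vec, theta):
    """Rotate consecutive pairs (v0, v1), (v2, v3), ... by angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    out = []
    for i in range(0, len(vec), 2):
        a, b = vec[i], vec[i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


B = [1.0, 0.0, 0.5, -0.5]
C = [0.3, 0.7, -1.0, 2.0]

# Same angle difference (0.5) in both cases, different absolute angles.
score_a = dot(rotate_pairs(C, 0.9), rotate_pairs(B, 0.4))
score_b = dot(rotate_pairs(C, 0.5), rotate_pairs(B, 0.0))
same = math.isclose(score_a, score_b)
```

Because only the relative angle survives in the C·B product, the readout becomes sensitive to the phase accumulated between the step that wrote to the state and the step that reads it.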

3. MIMO Formulation

Reuses a shared (H, D) state for R rank streams instead of SISO's (H, P, D) outer product:

| Property | SISO | MIMO |
|----------|------|------|
| State shape | (H, P, D) | (H, D) |
| Decode FLOPs/byte | Low (memory-bound) | R× higher |
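
As a rough illustration of the two state shapes, the sketch below counts state elements per sequence for the Quick Start configuration. It assumes the Mamba-2 style conventions H = expand · d_model / headdim, P = headdim, and D = d_state; the README does not state these relations explicitly, and `state_elements` is a hypothetical helper.

```python
def state_elements(d_model, expand, headdim, d_state):
    """Count recurrent-state elements for the two shapes in the table."""
    d_inner = expand * d_model
    H = d_inner // headdim  # assumed: number of heads
    P = headdim             # assumed: per-head channel dimension
    D = d_state             # assumed: state dimension
    return {"siso": H * P * D, "mimo": H * D}


sizes = state_elements(d_model=256, expand=2, headdim=32, d_state=64)
ratio = sizes["siso"] // sizes["mimo"]
```

Under these assumptions the shared (H, D) state is P times smaller than the (H, P, D) state, so each decode step reads and writes fewer state bytes for the work it performs.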

CUDA Acceleration

The SSM scan is accelerated with a fused CUDA kernel when the MSVC compiler is available:

  • SISO: Fully fused kernel — one block per (batch, head), P threads hold all D state-values in registers, B/C loaded via shared memory each timestep. Replaces the Python for-loop entirely.
  • MIMO: Split design — outer einsums in PyTorch, inner state scan in CUDA. Uses tree-reduction over D for the output.
  • JIT fallback: If MSVC is unavailable, torch.jit.script provides a ~2–3× speedup with no compilation needed.

To compile the CUDA kernel, install Visual Studio Build Tools with MSVC and run any scan function (compilation happens automatically on first call).
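
The tree-reduction mentioned for the MIMO kernel can be sketched in plain Python. The function below sums D values in about log2(D) rounds of pairwise additions, which is the access pattern a parallel reduction uses when each round runs across threads. This is a sequential illustration under the assumption of a standard pairwise scheme; `tree_reduce_sum` is a hypothetical name and the kernel's actual implementation may differ.

```python
def tree_reduce_sum(values):
    """Sum a non-empty sequence by pairwise combination.

    Returns (total, rounds). In each round, element i absorbs element
    i + stride; on a GPU the additions within a round run in parallel.
    """
    buf = list(values)
    n = len(buf)
    stride, rounds = 1, 0
    while stride < n:
        for i in range(0, n - stride, 2 * stride):
            buf[i] += buf[i + stride]
        stride *= 2
        rounds += 1
    return buf[0], rounds


# Output of one step as a reduction over D: y = sum_d C[d] * h[d]
C = [1, 2, 3, 4, 5, 6, 7, 8]
h = [1, 1, 1, 1, 1, 1, 1, 1]
y, rounds = tree_reduce_sum(c * s for c, s in zip(C, h))
```

For D = 8 the reduction finishes in 3 rounds instead of 7 sequential additions, which is why it suits a block of threads that each hold one state value.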

API Reference

Mamba3(d_model, d_state=128, expand=2, headdim=64, ngroups=1, rope_fraction=0.5, is_mimo=False, mimo_rank=4)

| Method | Description |
|--------|-------------|
| forward(u) | (B, L, d_model) → (B, L, d_model) |
| step(u, angle, state, prev) | Single decode step, returns updated states |
| allocate_inference_cache(B) | Allocate zero states for decoding |

MambaLMHeadModel(config)

| Field | Default | Description |
|-------|---------|-------------|
| d_model | 2560 | Hidden size |
| n_layer | 64 | Number of blocks |
| vocab_size | 50277 | Padded to multiple of 8 |
| ssm_cfg | {} | Passed to Mamba3 |
| d_intermediate | 0 | SwiGLU MLP (0 = disabled) |
| tie_embeddings | True | Tie LM head to embedding |
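
The vocabulary padding noted in the table is plain round-up arithmetic. The sketch below shows the usual way to round up to a multiple; `pad_to_multiple` is a hypothetical helper, and it assumes the package rounds up rather than down.

```python
def pad_to_multiple(n, multiple=8):
    """Round n up to the nearest multiple (n itself if already aligned)."""
    return ((n + multiple - 1) // multiple) * multiple


padded_default = pad_to_multiple(50277)  # default vocab_size from the table
padded_aligned = pad_to_multiple(50000)  # already a multiple of 8
```

With the default of 50277 this gives an embedding table of 50280 rows, so token ids from 50277 to 50279 would be unused padding entries.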

Exports

from mamba3_ssm import (
    Mamba3, MambaLMHeadModel, MambaConfig, SSMConfig,
    RMSNorm, apply_rope, ssm_scan_siso, ssm_scan_mimo,
    CONFIGS,
)

Testing

python -m mamba3_ssm.tests

10/10 checks: shapes, numerical consistency (step-by-step == forward), gradient flow, parameter counting, edge cases.
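
The numerical consistency check compares a whole-sequence forward pass with a token-by-token decode. The idea can be sketched on a toy scalar recurrence; the functions `forward` and `step` below are stand-ins written for this illustration and are not the package's test code.

```python
import math


def forward(x, a, b, c):
    """All outputs at once: y_t = c * sum_{s<=t} a**(t-s) * b * x_s."""
    return [
        c * sum(a ** (t - s) * b * x[s] for s in range(t + 1))
        for t in range(len(x))
    ]


def step(x_t, h, a, b, c):
    """One decode step of the same recurrence: h <- a*h + b*x_t, y = c*h."""
    h = a * h + b * x_t
    return c * h, h


x = [0.5, -1.0, 2.0, 0.25]
a, b, c = 0.9, 0.3, 1.5

y_forward = forward(x, a, b, c)

h = 0.0  # zero initial state, as allocated for decoding
y_steps = []
for x_t in x:
    y_t, h = step(x_t, h, a, b, c)
    y_steps.append(y_t)

consistent = all(
    math.isclose(p, q, rel_tol=1e-9, abs_tol=1e-12)
    for p, q in zip(y_forward, y_steps)
)
```

The comparison uses a tolerance rather than exact equality because the two code paths accumulate floating-point error in a different order.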

Project Structure

mamba3_ssm/
├── __init__.py      # Public API
├── config.py        # MambaConfig / SSMConfig
├── ops.py           # RMSNorm, RoPE, SSM scans (CUDA/JIT/Python)
├── cuda_backend.py  # CUDA kernel compilation + Python wrappers
├── layer.py         # Mamba3 module (forward + step)
├── block.py         # MambaBlock, MambaLMHeadModel
├── presets.py       # RTX 4060 benchmarked configs
├── tests.py         # 10 sanity checks
└── utils.py         # Parameter counting

Dependencies

torch>=2.0
einops>=0.7

Optional: datasets for auto-downloading TinyStories/Wikitext, wandb for logging.

Changelog

v0.2.1 (2026-06-28)

  • Fix autocast dtype override: MIMO CUDA einsum pre/post-mix now wrapped in autocast(enabled=False) to prevent bf16 autocast from overriding float32 tensors
  • Benchmark training times: Added measured throughput for all three presets on RTX 4060 (small: 4.5h, medium: 19h, large: 22h for TinyStories×3ep)

v0.2.0 (2026-06-28)

  • CUDA-accelerated SSM scan: Fused SISO kernel (50× speedup); MIMO split kernel
  • JIT fallback: torch.jit.script — 1.9–2.7× speedup without CUDA compilation
  • Bug fix: double sigmoid: Removed redundant sigmoid on trap gate (forward path)
  • Bug fix: lm_head dimension: Swapped to Linear(d_model, vocab_size)
  • MSVC 14.44 + CUDA 12.1 compatibility: Added _ALLOW_COMPILER_AND_STL_VERSION_MISMATCH workaround

v0.1.2 (2026-05-31)

  • Fix tokenizer cache loading bug, checkpoint resume, steps_per_epoch calculation
  • Optimized SSM scan with pre-computed decay/trap factors

v0.1.1 (2026-05-31)

  • Add --preset flag, generate.py, presets.CONFIGS, RTX 4060 benchmarks

v0.1.0 (2026-05-31)

  • Initial release — SISO & MIMO Mamba-3, 10/10 tests passing

License

MIT
