
Mamba-3: Improved Sequence Modeling using State Space Principles

Requires Python 3.10+ · License: MIT

A clean, readable, from-scratch implementation of Mamba-3 (arXiv:2603.15569) in pure PyTorch, with no Triton or custom CUDA kernels. The benchmarked presets train models of up to ~367M parameters on an 8 GB laptop GPU.

Installation

pip install mamba3-ssm

Quick Start

import torch
from mamba3_ssm import Mamba3, MambaLMHeadModel, MambaConfig

# A single Mamba-3 layer (MIMO, rank 4)
model = Mamba3(d_model=256, d_state=64, expand=2, headdim=32, is_mimo=True, mimo_rank=4)
x = torch.randn(2, 128, 256)  # (batch, seq_len, d_model)
y = model(x)  # (2, 128, 256)

# Autoregressive decode: feed one token per call and carry (angle, state, prev) forward
angle, state, prev = model.allocate_inference_cache(2)  # batch size 2
out, angle, state, prev = model.step(torch.randn(2, 256), angle, state, prev)

# Full language model (same d_model/n_layer as the "medium" preset)
cfg = MambaConfig(d_model=1536, n_layer=20, vocab_size=50000,
                  ssm_cfg={"d_state": 64, "is_mimo": True, "mimo_rank": 4})
lm = MambaLMHeadModel(cfg)
logits = lm(torch.randint(0, 50000, (1, 512)))  # input: token ids of shape (batch, seq_len)

Training

Presets (benchmarked on RTX 4060 Laptop 8GB)

Based on actual VRAM measurements (bf16 + AdamW):

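| Preset | Params | d_model | n_layer | d_state | batch | seq_len | VRAM | Status |
|---|---|---|---|---|---|---|---|---|
| small | 112M | 1024 | 16 | 64 | 2 | 512 | ~5.6 GB | ✅ Comfortable |
| medium | 306M | 1536 | 20 | 64 | 1 | 256 | ~7.2 GB | ✅ Recommended |
| large | 367M | 1536 | 24 | 64 | 1 | 256 | ~8.6 GB | ⚠️ At the limit |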

Effective batch size = batch × grad_accum (default grad_accum=16 for all presets).
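To see what the effective-batch rule means per optimizer update, the short sketch below multiplies it out for each preset. It assumes every sequence is packed to the full seq_len (no padding); the PRESETS dict and per_update_budget helper are hypothetical names for illustration, not part of mamba3_ssm.

```python
# Per-preset micro-batch and sequence length, copied from the table above
PRESETS = {
    "small": {"batch": 2, "seq_len": 512},
    "medium": {"batch": 1, "seq_len": 256},
    "large": {"batch": 1, "seq_len": 256},
}


def per_update_budget(batch, seq_len, grad_accum=16):
    """Return (sequences, tokens) consumed by one optimizer update.

    Assumes every sequence is exactly seq_len tokens long (no padding).
    """
    effective_batch = batch * grad_accum
    return effective_batch, effective_batch * seq_len


for name, p in PRESETS.items():
    seqs, toks = per_update_budget(p["batch"], p["seq_len"])
    print(f"{name}: {seqs} sequences, {toks} tokens per update")
```

With the default grad_accum=16 this gives 32 sequences (16,384 tokens) per update for small, and 16 sequences (4,096 tokens) for medium and large.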

# Train the 306M (medium) model on TinyStories (auto-downloads the dataset)
python train.py --dataset tinystories --preset medium --epochs 3

# Quick experiment with the 112M (small) preset on your own text file
python train.py --dataset custom --data-path myfile.txt --preset small --epochs 5

# Wikitext-103 benchmark
python train.py --dataset wikitext --preset medium --epochs 5

# Resume training
python train.py --dataset tinystories --preset medium --resume checkpoints/best.pt

# Full custom config
python train.py --dataset tinystories --d-model 1024 --n-layer 16 --d-state 64 \
    --batch-size 2 --seq-len 512 --grad-accum 8 --learning-rate 3e-4 --epochs 3

# With W&B logging
python train.py --dataset tinystories --preset medium --wandb --wandb-project my-mamba3

Text Generation

python generate.py --checkpoint checkpoints/best.pt \
    --prompt "Once upon a time" --max-tokens 200 --temperature 0.8
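generate.py takes a --temperature flag. A common way such a flag is applied, sketched here as an assumption rather than the script's actual code, is to divide the logits by the temperature before sampling, so values below 1 (like the 0.8 above) sharpen the distribution and values above 1 flatten it. The function name and the greedy fallback at temperature 0 are illustrative choices.

```python
import torch


def sample_next_token(logits, temperature=0.8, generator=None):
    """Pick next-token ids from a (B, vocab) logits tensor.

    temperature < 1 sharpens the distribution, > 1 flattens it,
    and temperature <= 0 falls back to greedy argmax decoding.
    """
    if temperature <= 0:
        return logits.argmax(dim=-1)
    probs = torch.softmax(logits / temperature, dim=-1)
    # One sample per row; squeeze the num_samples dimension -> (B,)
    return torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
```

In a decode loop you would call this on the logits of the last position at each step and append the result to the prompt.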

Custom Training Code

import torch
from mamba3_ssm import MambaLMHeadModel, MambaConfig, CONFIGS

# Use a preset or define your own config
cfg = CONFIGS["medium"]  # dict with d_model, n_layer, etc.
model = MambaLMHeadModel(MambaConfig(
    d_model=cfg["d_model"],
    n_layer=cfg["n_layer"],
    vocab_size=10000,
    ssm_cfg={"d_state": cfg["d_state"], "expand": 2, "headdim": 64,
             "is_mimo": True, "mimo_rank": 4},
)).cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
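The snippet above stops after building the optimizer. One way to continue is a next-token training loop with gradient accumulation. This is a minimal sketch, not the loop in train.py: it assumes model(tokens) returns a logits tensor of shape (B, L, vocab_size) as in the Quick Start, that batches yields token-id tensors, and that clipping gradients at norm 1.0 is acceptable. train_epoch and max_grad_norm are hypothetical names.

```python
import torch
import torch.nn.functional as F


def train_epoch(model, batches, optimizer, grad_accum=16, max_grad_norm=1.0,
                device="cpu", use_bf16=False):
    """One pass of next-token training with gradient accumulation.

    Assumes model(tokens) returns logits of shape (B, L, vocab_size) and that
    `batches` yields LongTensors of token ids with shape (B, L + 1).
    Returns the mean loss of each completed accumulation window; a trailing
    partial window is not applied (its gradients are cleared on the next call).
    """
    model.train()
    device_type = torch.device(device).type
    optimizer.zero_grad(set_to_none=True)
    step_losses, window_loss = [], 0.0
    for i, tokens in enumerate(batches, start=1):
        tokens = tokens.to(device)
        inputs, targets = tokens[:, :-1], tokens[:, 1:]  # shift by one token
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16,
                            enabled=use_bf16):
            logits = model(inputs)
        loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)),
                               targets.reshape(-1))
        # Scale so the accumulated gradient is the mean over the window
        (loss / grad_accum).backward()
        window_loss += loss.item()
        if i % grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            step_losses.append(window_loss / grad_accum)
            window_loss = 0.0
    return step_losses
```

With the model and optimizer built above, a call would look like train_epoch(model, loader, optimizer, grad_accum=16, device="cuda", use_bf16=True), where loader is any iterable of token batches you provide.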

Core Ideas

1. Exponential-Trapezoidal Discretization

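Mamba-2 discretizes with Zero-Order Hold, a first-order rule. Mamba-3 uses a trapezoidal rule that also draws on the previous input:

h_t = exp(A·dt_t) · h_{t-1} + dt_t · [ (1 − λ_t/2) · B_t·x_t + (λ_t/2) · B_{t-1}·x_{t-1} ],  where λ_t = σ(trap_t)

The learned gate λ_t blends between Euler (λ_t ≈ 0: only the current input B_t·x_t) and the full trapezoidal rule (λ_t ≈ 1: the average of B_t·x_t and B_{t-1}·x_{t-1}).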
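A scalar, single-channel sketch of such a gated recurrence helps check the two limits. It assumes the gate λ_t = σ(trap_t) weights the current and previous input terms by (1 − λ_t/2) and λ_t/2, so λ_t ≈ 0 reduces to Euler and λ_t ≈ 1 to the trapezoidal average. It is an illustration, not the package's ssm_scan_siso; trapezoidal_scan and its arguments are hypothetical.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def trapezoidal_scan(x, b, dt, a, trap_logits):
    """Scalar sketch of a gated exponential-trapezoidal recurrence.

    h_t = exp(a*dt_t) * h_{t-1}
          + dt_t * ((1 - lam_t/2) * b_t*x_t + (lam_t/2) * b_{t-1}*x_{t-1}),
    with lam_t = sigmoid(trap_t). lam_t -> 0 gives Euler (current input only);
    lam_t -> 1 gives the trapezoidal average of the two inputs.
    `a` should be negative (or zero) so exp(a*dt) does not amplify the state.
    """
    h, prev_bx, hs = 0.0, 0.0, []  # the input before t=0 is treated as zero
    for x_t, b_t, dt_t, z_t in zip(x, b, dt, trap_logits):
        lam = sigmoid(z_t)
        bx = b_t * x_t
        h = math.exp(a * dt_t) * h + dt_t * ((1 - lam / 2) * bx + (lam / 2) * prev_bx)
        prev_bx = bx
        hs.append(h)
    return hs
```

With an impulse input x = [1, 0] and no decay (a = 0), the Euler limit puts the full impulse into the state at step 0, while the trapezoidal limit splits it evenly across steps 0 and 1.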

2. Complex-Valued (Rotary) State Space

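Mamba-3 applies rotary position embeddings (RoPE) to the B and C projections. Rotating a pair of state dimensions by an angle is equivalent to multiplying a complex number by a phase, so the state behaves like a complex-valued SSM and can represent rotating (phase-dependent) patterns that a purely real, decaying state cannot.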
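To make the complex-valued view concrete, here is a small pure-Python illustration (not the package's apply_rope, whose exact signature is not assumed here). Rotating a pair (x1, x2) by θ equals multiplying x1 + i·x2 by e^{iθ}, and once C and B are each rotated by their own angles, their inner product depends only on the angle difference. The angles are hand-picked; in a model they would be produced per position.

```python
import cmath
import math


def rotate_pairs(vec, angles):
    """Rotate consecutive pairs (vec[2i], vec[2i+1]) by angles[i], RoPE-style."""
    out = []
    for i, theta in enumerate(angles):
        x1, x2 = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out += [x1 * c - x2 * s, x1 * s + x2 * c]
    return out


def rotate_as_complex(vec, angles):
    """Same rotation, written as multiplying z = x1 + i*x2 by e^{i*theta}."""
    out = []
    for i, theta in enumerate(angles):
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.exp(1j * theta)
        out += [z.real, z.imag]
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Rotate C by theta_t and B by theta_s: the score depends only on theta_t - theta_s
C, B = [0.2, -1.0, 0.7, 0.4], [1.5, 0.3, -0.2, 0.9]
theta_t, theta_s = [2.0, 0.5], [0.7, -0.4]
score = dot(rotate_pairs(C, theta_t), rotate_pairs(B, theta_s))
relative = dot(rotate_pairs(C, [t - s for t, s in zip(theta_t, theta_s)]), B)
print(abs(score - relative) < 1e-12)  # True
```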

3. MIMO Formulation

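MIMO (multi-input, multi-output) reuses a shared (H, D) state across R rank streams instead of SISO's (single-input, single-output) (H, P, D) outer-product state, where H is the number of heads, P = headdim, and D = d_state: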

| | SISO | MIMO |
|---|---|---|
| State shape | (H, P, D) | (H, D) |
| Decode FLOPs/byte | Low (memory-bound) | R× higher |
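Taking the state shapes in the table at face value, the per-layer, per-sequence state that decoding must keep can be computed for the Quick Start layer. This sketch assumes d_inner = expand × d_model and H = d_inner / headdim (the usual Mamba-2 convention, not confirmed here) and ignores the separate angle and prev caches; decode_state_elements is a hypothetical helper.

```python
def decode_state_elements(d_model, d_state, expand, headdim, is_mimo):
    """Per-layer, per-sequence state size implied by the SISO/MIMO table.

    Assumptions: d_inner = expand * d_model, H = d_inner // headdim,
    P = headdim, D = d_state. The angle and prev caches are not counted.
    """
    d_inner = expand * d_model
    if d_inner % headdim:
        raise ValueError("headdim must divide expand * d_model")
    H, P, D = d_inner // headdim, headdim, d_state
    return H * D if is_mimo else H * P * D


# Quick Start layer: d_model=256, d_state=64, expand=2, headdim=32
siso = decode_state_elements(256, 64, 2, 32, is_mimo=False)
mimo = decode_state_elements(256, 64, 2, 32, is_mimo=True)
print(siso, mimo)  # 32768 1024
```

Under these assumptions the two shapes differ by a factor of P (headdim, 32 here).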

API Reference

Mamba3(d_model, d_state=128, expand=2, headdim=64, ngroups=1, rope_fraction=0.5, is_mimo=False, mimo_rank=4)

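| Method | Description |
|---|---|
| forward(u) | Maps (B, L, d_model) → (B, L, d_model) |
| step(u, angle, state, prev) | Single decode step on u of shape (B, d_model); returns (out, angle, state, prev) |
| allocate_inference_cache(B) | Returns zero-initialized (angle, state, prev) for batch size B |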

MambaLMHeadModel(config)

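| Field | Default | Description |
|---|---|---|
| d_model | 2560 | Hidden size |
| n_layer | 64 | Number of blocks |
| vocab_size | 50277 | Vocabulary size; padded up to a multiple of 8 |
| ssm_cfg | {} | Keyword arguments passed to Mamba3 |
| d_intermediate | 0 | Hidden size of the optional SwiGLU MLP (0 = disabled) |
| tie_embeddings | True | Tie the LM head weights to the input embedding |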
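The padding noted for vocab_size amounts to rounding up to the next multiple of 8, which typically keeps the embedding and LM-head matrix shapes friendly to GPU tensor cores. A minimal sketch with a hypothetical helper name; how the package performs the padding internally is not shown here.

```python
def pad_vocab_size(vocab_size: int, multiple: int = 8) -> int:
    """Round vocab_size up to the nearest multiple of `multiple`."""
    return ((vocab_size + multiple - 1) // multiple) * multiple


print(pad_vocab_size(50277))  # 50280
print(pad_vocab_size(50000))  # 50000 (already a multiple of 8)
```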

Exports

from mamba3_ssm import (
    Mamba3, MambaLMHeadModel, MambaConfig, SSMConfig,
    RMSNorm, apply_rope, ssm_scan_siso, ssm_scan_mimo,
    CONFIGS,  # RTX 4060 benchmarked presets
)

Testing

python -m mamba3_ssm.tests

10/10 checks pass, covering output shapes, numerical consistency (step-by-step decoding matches forward), gradient flow, parameter counting, and edge cases.
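The step-by-step versus forward consistency check can be sketched generically. This relies only on the forward / step / allocate_inference_cache interface from the API reference; check_step_matches_forward and the ToyDecayScan stand-in (a simple decaying cumulative sum with the same interface) are hypothetical, so the sketch runs without the package. With a real Mamba3 layer in reduced precision you would likely need a looser atol.

```python
import torch
from torch import nn


@torch.no_grad()
def check_step_matches_forward(model, x, atol=1e-5):
    """Return (ok, max_abs_err) comparing model(x) with token-by-token step().

    Assumes forward maps (B, L, d_model) -> (B, L, d_model),
    allocate_inference_cache(B) returns (angle, state, prev), and
    step(u, angle, state, prev) returns (out, angle, state, prev).
    """
    y_full = model(x)
    angle, state, prev = model.allocate_inference_cache(x.shape[0])
    outs = []
    for t in range(x.shape[1]):
        out, angle, state, prev = model.step(x[:, t], angle, state, prev)
        outs.append(out)
    y_step = torch.stack(outs, dim=1)
    max_err = (y_full - y_step).abs().max().item()
    return max_err <= atol, max_err


class ToyDecayScan(nn.Module):
    """Hypothetical stand-in with the same interface: h_t = a*h_{t-1} + u_t."""

    def __init__(self, d_model, decay=0.9):
        super().__init__()
        self.d_model, self.decay = d_model, decay

    def forward(self, u):
        # Parallel form: y_t = sum_{s<=t} decay^(t-s) * u_s via a masked (L, L) matrix
        L = u.shape[1]
        idx = torch.arange(L, device=u.device)
        diff = (idx[:, None] - idx[None, :]).to(u.dtype)
        weights = torch.where(diff >= 0, self.decay ** diff, torch.zeros_like(diff))
        return torch.einsum("ts,bsd->btd", weights, u)

    def allocate_inference_cache(self, batch_size):
        zeros = torch.zeros(batch_size, self.d_model)
        return zeros.clone(), zeros.clone(), zeros.clone()  # angle, state, prev

    def step(self, u, angle, state, prev):
        # Recurrent form; angle is unused here and prev just stores the last input
        state = self.decay * state + u
        return state, angle, state, u
```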

Project Structure

mamba3_ssm/          # pip installable package
├── __init__.py      # Public API
├── config.py        # MambaConfig / SSMConfig
├── ops.py           # RMSNorm, RoPE, SSM scans
├── layer.py         # Mamba3 module (forward + step)
├── block.py         # MambaBlock, MambaLMHeadModel
├── presets.py       # RTX 4060 benchmarked configs
├── tests.py         # 10 sanity checks
└── utils.py         # Parameter counting

train.py             # Training script
generate.py          # Text generation
docs/
├── API.md           # Full API reference
└── TRAINING.md      # Training guide with tips

RTX 4060 Laptop Tips

  • bf16 is enabled automatically on RTX 40-series GPUs; no config needed
  • MIMO gives ~20% speedup in decode over SISO
  • d_state=64 is the sweet spot for 8GB; go to 128 only if you reduce d_model
  • grad_accum lets you simulate large batches without extra VRAM
  • If OOM: reduce seq_len first (512→256→128), then d_model

Hardware Requirements

| Component | Minimum | Recommended |
|---|---|---|
| GPU VRAM | 4 GB | 8 GB |
| RAM | 8 GB | 16 GB |
| Disk | 1 GB | 5 GB (with datasets) |

Tested on RTX 4060 Laptop (8GB), PyTorch 2.6+cu124.

Dependencies

torch>=2.0
einops>=0.7

Optional: the datasets package (for auto-downloading TinyStories/Wikitext) and wandb (for logging).

Changelog

v0.1.1 (2026-05-31)

  • Add --preset flag to train.py (small/medium/large) benchmarked on RTX 4060 8GB
  • Fix default config to fit 8GB VRAM (d_state=64, bs=1, seq_len=256)
  • Add generate.py for text generation from checkpoints
  • Add mamba3_ssm.presets.CONFIGS with VRAM-benchmarked configurations
  • Update README with training guide and benchmark table

v0.1.0 (2026-05-31)

  • Initial release: SISO & MIMO Mamba-3, 10/10 tests passing

License

MIT

Download files

Source Distribution

mamba3_ssm-0.1.1.tar.gz (20.1 kB)

Uploaded Source

Built Distribution

mamba3_ssm-0.1.1-py3-none-any.whl (18.8 kB)

Uploaded Python 3

File details

Details for the file mamba3_ssm-0.1.1.tar.gz.

File metadata

  • Download URL: mamba3_ssm-0.1.1.tar.gz
  • Upload date:
  • Size: 20.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for mamba3_ssm-0.1.1.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8f28b628058a2e1ab45eba9ab3e98e4d79d128078a6574b8b71dc39b0411fb1a |
| MD5 | 4155d5308a27c42f398b049e7c4a7bd8 |
| BLAKE2b-256 | bd080163135571746eec6486142fb84d631b17274efc54e1f7159253b8cdbe39 |

File details

Details for the file mamba3_ssm-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: mamba3_ssm-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 18.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.9

File hashes

Hashes for mamba3_ssm-0.1.1-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac71b6a2ea6eecaa9bd0c888188140c16aa2b5a9bfb5e25e8bf5096b6a63bc87 |
| MD5 | 0edf901c8ea1e83196adbef3b08a124f |
| BLAKE2b-256 | 90867bb87e51db4f71ffaa1cbfa24584736e4a072a4c13924e5a85dbf464adc8 |
