# mamba-scan-lite
Memory-efficient Mamba2 scan for HuggingFace Transformers. Fixes Zamba2 OOM on small GPUs. No CUDA compilation required.
## The Problem
Running Zamba2 models on GPUs with less than 8 GB of VRAM fails with `OutOfMemoryError`, even though the model weights fit. The cause is HuggingFace's naive Mamba2 scan implementation, which materializes multi-gigabyte intermediate tensors for the SSM computation.

The official fix is to install `mamba-ssm`, but that requires CUDA compilation, which fails on many setups (driver mismatches, missing headers, ABI conflicts).
## The Fix
```bash
pip install mamba-scan-lite
```

```python
import torch
import mamba_scan_lite  # patches HF Zamba2 automatically
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-2.7B-instruct",
    device_map="cuda",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-2.7B-instruct")
inputs = tokenizer("Hello, world!", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
That's it. One import, no configuration.
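For the curious, the import-time patching follows the standard monkeypatch pattern. A minimal sketch, where `patched_torch_forward` is a hypothetical stand-in for the package's actual replacement, not its real API:

```python
# Illustrative only: the usual import-time monkeypatch pattern.
from transformers.models.zamba2 import modeling_zamba2

def patched_torch_forward(self, hidden_states, *args, **kwargs):
    ...  # memory-efficient scan + conv1d, described in "What It Does"

modeling_zamba2.Zamba2MambaMixer.torch_forward = patched_torch_forward
```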
## What It Does
Replaces two components in HF's `Zamba2MambaMixer.torch_forward`:

- **Chunked vectorized SSM scan** (v0.2.0) instead of the chunked SSD that allocates GB-sized 6D tensors. Within each chunk, the linear recurrence is solved in closed form via cumulative products and prefix sums, with no Python loop. Across chunks, state is carried sequentially. Memory stays at a ~32 MB working set instead of 1+ GB.
- **Manual conv1d via unfold + einsum** instead of `F.conv1d`, which avoids cuDNN initialization failures on older drivers (Turing/SM75 GPUs). A sketch of this follows below.
The decode path (single-token generation with KV cache) is unchanged.
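For illustration, here is one way to write a causal depthwise conv1d with `unfold` + `einsum`. This is a minimal sketch under assumed tensor shapes, not the package's exact code:

```python
import torch

def depthwise_conv1d_unfold(x, weight, bias=None):
    """Causal depthwise conv1d via unfold + einsum (no cuDNN call).

    Assumed shapes (illustrative):
      x:      (batch, channels, seq_len)
      weight: (channels, kernel_size), one filter per channel
    Mirrors F.conv1d(x, weight[:, None], groups=channels) with left
    padding for causality.
    """
    channels, k = weight.shape
    x = torch.nn.functional.pad(x, (k - 1, 0))       # causal left-pad
    windows = x.unfold(dimension=2, size=k, step=1)  # (B, C, L, k) sliding windows
    y = torch.einsum("bclk,ck->bcl", windows, weight)
    if bias is not None:
        y = y + bias[None, :, None]
    return y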
## Benchmarks
Tested on an NVIDIA Quadro T2000 (4 GB VRAM, Turing SM75). Throughput is decode speed; VRAM is the patched peak:

| Model | Without Patch | v0.1.0 (sequential) | v0.2.0 (chunked) | Peak VRAM (patched) |
|---|---|---|---|---|
| Zamba2-1.2B | OOM | 2.3 tok/s | 2.3 tok/s | 1,560 MB |
| Zamba2-2.7B-Instruct | OOM | 1.0 tok/s | 1.0 tok/s | 3,134 MB |
Prefill speedup (v0.2.0 vs v0.1.0, Zamba2-2.7B on T2000):
| Sequence Length | v0.1.0 Prefill | v0.2.0 Prefill | Speedup |
|---|---|---|---|
| 32 tokens | 1.56s | 0.81s | 1.9x |
| 64 tokens | 3.14s | 0.85s | 3.7x |
| 128 tokens | 6.37s | 1.23s | 5.2x |
| 256 tokens | 13.04s | 2.24s | 5.8x |
Decode speed (single-token generation) is identical between versions — the chunked scan only affects prefill.
## What's New in v0.2.0
- Chunked vectorized scan replaces the token-by-token sequential loop from v0.1.0
- 1.9x–5.8x prefill speedup depending on sequence length (longer sequences benefit more)
- NaN-safe fallback: automatically falls back to sequential processing when the cumulative decay spans more than 80 orders of magnitude, a rare edge case with extreme decay rates (see the sketch after this list)
- Token-identical output: proven via comparison tests on real Zamba2 weights — the chunked and sequential scans produce bit-identical results
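As an illustration of that guard, a dynamic-range check might look like the following. The names and the exact criterion are assumptions, not the package's code:

```python
import math
import torch

MAX_LOG10_SPAN = 80  # beyond this, an fp32 cumprod will over/underflow

def scan_with_fallback(log_decay, chunked_fn, sequential_fn):
    # log_decay: per-token log decay factors; their running sum is the
    # exponent of the cumulative product the chunked scan relies on.
    running = torch.cumsum(log_decay, dim=-1) / math.log(10.0)
    if (running.max() - running.min()) > MAX_LOG10_SPAN:
        return sequential_fn()  # numerically safe token-by-token path
    return chunked_fn()         # fast vectorized path
```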
## When to Use This
- You get `OutOfMemoryError` running Zamba2 on a small GPU
- `mamba-ssm` won't compile on your system
- You get `cuDNN error: CUDNN_STATUS_NOT_INITIALIZED`
- You want to run Zamba2 without any CUDA kernel compilation
## When NOT to Use This
- You have `mamba-ssm` installed and working (the official kernels are faster)
- You're on a large GPU (24+ GB) where the naive path doesn't OOM
- You need maximum throughput (this trades speed for memory, though v0.2.0 closes the gap significantly)
## How It Works
The HF naive Mamba2 scan computes the structured state space dual (SSD) via chunked matrix operations:
```python
# HF naive path: allocates 1+ GB of intermediates
G = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]          # 6D tensor
M = (G * L).sum(-1)                                          # L: causal decay mask
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)   # OOM here
```
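For scale, take illustrative round-number shapes (not Zamba2's exact configuration): batch 1, 16 chunks of 64 tokens (the chunk length appears twice, once per side of the pairwise token interaction), 64 heads, and state size 64. The 6D tensor alone then holds 16 · 64 · 64 · 64 · 64 ≈ 2.7 × 10⁸ elements, roughly 1 GB in fp32, before any downstream product is allocated.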
The v0.2.0 chunked scan solves the recurrence in closed form per chunk:
```python
# mamba-scan-lite v0.2.0: ~32 MB working set (pseudocode)
for chunk in chunks(seq_len, chunk_size=32):
    cumA = cumprod(decay_factors)          # vectorized
    scaled_b = inputs / cumA               # element-wise
    prefix = cumsum(scaled_b)              # vectorized
    h_chunk = cumA * (h_init + prefix)     # element-wise
    y_chunk = readout(h_chunk, C) + D * x  # element-wise
    h_init = h_chunk[-1]                   # carry state to next chunk
```
Same math. O(state_size) memory. 4 vectorized ops per chunk instead of chunk_size loop iterations.
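Concretely, the recurrence h_t = a_t · h_{t-1} + b_t has the per-chunk closed form h_t = cumA_t · (h_init + Σ_{j≤t} b_j / cumA_j), where cumA_t = Π_{i≤t} a_i. A runnable scalar-channel sketch of exactly this idea (illustrative names, not the package's actual kernel):

```python
import torch

def scan_chunked(a, b, h0, chunk_size=32):
    # Solves h_t = a[t] * h[t-1] + b[t] chunk by chunk via cumprod/cumsum.
    h, ys = h0, []
    for a_c, b_c in zip(a.split(chunk_size), b.split(chunk_size)):
        cumA = torch.cumprod(a_c, dim=0)          # prod_{i<=t} a_i
        prefix = torch.cumsum(b_c / cumA, dim=0)  # sum_{j<=t} b_j / cumA_j
        h_c = cumA * (h + prefix)                 # all states in the chunk at once
        ys.append(h_c)
        h = h_c[-1]                               # carry final state across chunks
    return torch.cat(ys)

# Sanity check against the token-by-token recurrence (mild decay rates so the
# in-chunk cumprod stays well-conditioned; cf. the NaN-safe fallback above).
a = torch.rand(100) * 0.1 + 0.9
b = torch.randn(100)
h, ref = torch.tensor(0.0), []
for t in range(100):
    h = a[t] * h + b[t]
    ref.append(h)
assert torch.allclose(scan_chunked(a, b, torch.tensor(0.0)), torch.stack(ref), atol=1e-4)
```

The `split`/`cumprod`/`cumsum` calls replace chunk_size loop iterations with a handful of vectorized ops per chunk, which is where the prefill speedup comes from.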
## Requirements
- Python >= 3.9
- PyTorch >= 2.0
- Transformers >= 4.45 (Zamba2 support)
## Also From EchoLabs
- helix-substrate — Universal model weight compression via HXQ (2D Vector Quantization). Faster-than-dense inference on compressed models.
- EchoLabs33 on HuggingFace — 14 compressed models across Transformer, SSM, and Hybrid architectures.
## License
MIT