
mamba-scan-lite

Memory-efficient Mamba2 scan for HuggingFace Transformers. Fixes Zamba2 OOM on small GPUs. No CUDA compilation required.

The Problem

Running Zamba2 models on GPUs with less than 8 GB VRAM fails with OutOfMemoryError, even though the model weights fit. The cause is HuggingFace's naive Mamba2 scan implementation, which materializes GB-sized intermediate tensors for the SSM computation.

The official fix is to install mamba-ssm, but that requires compiling CUDA kernels, which fails on many setups (driver mismatches, missing headers, ABI conflicts).

The Fix

pip install mamba-scan-lite

import torch
import mamba_scan_lite  # patches HF Zamba2 automatically

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-2.7B-instruct",
    device_map="cuda",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-2.7B-instruct")

inputs = tokenizer("Hello, world!", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

That's it. One import, no configuration.

What It Does

Replaces two components in HF's Zamba2MambaMixer.torch_forward:

  1. Chunked vectorized SSM scan (v0.2.0) instead of the chunked SSD that allocates GB-sized 6D tensors. Within each chunk, the linear recurrence is solved in closed form via cumulative products and prefix sums — no Python loop. Across chunks, state is carried sequentially. Memory stays at ~32 MB working set instead of 1+ GB.

  2. Manual conv1d via unfold+einsum instead of F.conv1d, which avoids cuDNN initialization failures on older drivers (Turing/SM75 GPUs). A sketch of this pattern follows the list.
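
For reference, here is a minimal sketch of the unfold+einsum pattern for a causal depthwise conv1d (the conv shape Mamba2 uses). The function name and signature are illustrative assumptions, not the package's API:

import torch
import torch.nn.functional as F

def depthwise_causal_conv1d(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # x: (batch, channels, seq_len); weight: (channels, 1, kernel),
    # the layout nn.Conv1d uses when groups == channels (depthwise).
    k = weight.shape[-1]
    x = F.pad(x, (k - 1, 0))     # left-pad so each step sees only past tokens
    windows = x.unfold(2, k, 1)  # (batch, channels, seq_len, kernel) sliding windows
    y = torch.einsum("bclk,ck->bcl", windows, weight.squeeze(1))
    return y if bias is None else y + bias[None, :, None]

Because this reduces to a pad, a strided view, and an einsum, it runs through generic elementwise and GEMM kernels and never touches cuDNN's convolution planner.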

The decode path (single-token generation with KV cache) is unchanged.

Benchmarks

Tested on NVIDIA Quadro T2000 (4 GB VRAM, Turing SM75):

Model                 Without Patch   v0.1.0 Sequential   v0.2.0 Chunked                 VRAM Peak
Zamba2-1.2B           OOM             2.3 tok/s           2.3 tok/s (decode unchanged)   1,560 MB
Zamba2-2.7B-Instruct  OOM             1.0 tok/s           1.0 tok/s (decode unchanged)   3,134 MB

Prefill speedup (v0.2.0 vs v0.1.0, Zamba2-2.7B on T2000):

Sequence Length   v0.1.0 Prefill   v0.2.0 Prefill   Speedup
32 tokens         1.56 s           0.81 s           1.9x
64 tokens         3.14 s           0.85 s           3.7x
128 tokens        6.37 s           1.23 s           5.2x
256 tokens        13.04 s          2.24 s           5.8x

Decode speed (single-token generation) is identical between versions — the chunked scan only affects prefill.

What's New in v0.2.0

  • Chunked vectorized scan replaces the token-by-token sequential loop from v0.1.0
  • 1.9x–5.8x prefill speedup depending on sequence length (longer sequences benefit more)
  • NaN-safe fallback: automatically falls back to sequential processing when cumulative decay spans >80 orders of magnitude (a rare edge case with extreme decay rates; see the sketch after this list)
  • Token-identical output: proven via comparison tests on real Zamba2 weights — the chunked and sequential scans produce bit-identical results
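
For illustration, one way such an overflow guard could look. The function name, the log_decay input (per-token decay in log space), and the threshold handling are assumptions for the sketch, not the package's internals:

import math
import torch

def needs_sequential_fallback(log_decay: torch.Tensor, max_decades: float = 80.0) -> bool:
    # Hypothetical guard: float32 spans roughly 10^-38..10^38, so if the running
    # decay product would cover more than ~80 orders of magnitude, cumprod will
    # overflow or underflow and the scan should process tokens sequentially.
    log_cumA = torch.cumsum(log_decay, dim=-1)        # log of the running cumprod
    span = (log_cumA.max() - log_cumA.min()).item()
    return span > max_decades * math.log(10.0)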

When to Use This

  • You get OutOfMemoryError running Zamba2 on a small GPU
  • mamba-ssm won't compile on your system
  • You get cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
  • You want to run Zamba2 without any CUDA kernel compilation

When NOT to Use This

  • You have mamba-ssm installed and working (the official kernels are faster)
  • You're on a large GPU (24+ GB) where the naive path doesn't OOM
  • You need maximum throughput (this trades speed for memory, though v0.2.0 closes the gap significantly)

How It Works

The HF naive Mamba2 scan computes the structured state space dual (SSD) via chunked matrix operations:

# HF naive path — allocates ~1+ GB intermediate
G = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]         # 6D tensor
M = (G * L).sum(-1)                                         # L: causal decay mask
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)  # OOM here
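
For a rough sense of scale (illustrative dimensions, not Zamba2's exact config): a float32 tensor of shape (1, 1, 256, 256, 80, 128) holds 256 × 256 × 80 × 128 ≈ 671 M elements, about 2.7 GB, which by itself exhausts a 4 GB card once the weights are resident.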

The v0.2.0 chunked scan solves the recurrence in closed form per chunk:

# mamba-scan-lite v0.2.0 — ~32 MB working set
for chunk in chunks(seq_len, chunk_size=32):
    cumA = cumprod(decay_factors)          # vectorized
    scaled_b = inputs / cumA               # element-wise
    prefix = cumsum(scaled_b)              # vectorized
    h_chunk = cumA * (h_init + prefix)     # element-wise
    y_chunk = readout(h_chunk, C) + D * x  # element-wise
    h_init = h_chunk[-1]                   # carry state

Same math. O(state_size) memory. 4 vectorized ops per chunk instead of chunk_size loop iterations.
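
As a concrete stand-in, here is a self-contained sketch of the same closed-form trick for the scalar recurrence h[t] = a[t] * h[t-1] + b[t]. The real kernel carries a full per-head SSM state and batch dimensions; the names below are simplifications, not the package's code:

import torch

def chunked_scan(a: torch.Tensor, b: torch.Tensor, h0: float = 0.0, chunk: int = 32) -> torch.Tensor:
    # Solves h[t] = a[t] * h[t-1] + b[t] along dim 0 without a per-token loop.
    # Within a chunk: h[t] = cumA[t] * (h_init + sum_{s<=t} b[s] / cumA[s]).
    h_init = torch.tensor(h0, dtype=a.dtype)
    out = []
    for start in range(0, a.shape[0], chunk):
        a_c = a[start:start + chunk]
        b_c = b[start:start + chunk]
        cumA = torch.cumprod(a_c, dim=0)          # running decay products (vectorized)
        prefix = torch.cumsum(b_c / cumA, dim=0)  # rescaled prefix sums; this division
                                                  # is the underflow the NaN-safe
                                                  # fallback guards against
        h = cumA * (h_init + prefix)              # closed-form solution for the chunk
        h_init = h[-1]                            # carry final state to the next chunk
        out.append(h)
    return torch.cat(out, dim=0)

# Quick equivalence check against the naive token-by-token recurrence:
a = torch.rand(100) * 0.1 + 0.9                   # decay factors in [0.9, 1.0)
b = torch.randn(100)
h, ref = torch.tensor(0.0), []
for t in range(100):
    h = a[t] * h + b[t]
    ref.append(h)
assert torch.allclose(chunked_scan(a, b), torch.stack(ref), atol=1e-5)

The toy check uses a tolerance; the package's own comparison tests run on real Zamba2 weights.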

Requirements

  • Python >= 3.9
  • PyTorch >= 2.0
  • Transformers >= 4.45 (Zamba2 support)

Also From EchoLabs

  • helix-substrate — Universal model weight compression via HXQ (2D Vector Quantization). Faster-than-dense inference on compressed models.
  • EchoLabs33 on HuggingFace — 14 compressed models across Transformer, SSM, and Hybrid architectures.

License

MIT
