Project description

mamba-scan-lite

Memory-efficient Mamba2 scan for HuggingFace Transformers. Fixes Zamba2 OOM on small GPUs. No CUDA compilation required.

The Problem

Running Zamba2 models on GPUs with less than 8 GB VRAM fails with OutOfMemoryError, even though the model weights fit. The cause is HuggingFace's naive Mamba2 scan implementation, which materializes GB-sized intermediate tensors for the SSM computation.
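A back-of-envelope calculation shows the scale. The dimensions below are illustrative, not the exact Zamba2 configuration:

# One pairwise intermediate in the chunked SSD path, with assumed shapes
batch, chunk_len, heads, state = 1, 256, 64, 128
elems = batch * chunk_len * chunk_len * heads * state
print(f"{elems * 2 / 2**30:.1f} GiB in fp16")  # -> 1.0 GiB for a single tensor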

The official fix is to install mamba-ssm, but that requires CUDA compilation that fails on many setups (driver mismatches, missing headers, ABI conflicts).

The Fix

pip install mamba-scan-lite

import torch
import mamba_scan_lite  # patches HF Zamba2 automatically

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-2.7B-instruct",
    device_map="cuda",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-2.7B-instruct")

inputs = tokenizer("Hello, world!", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

That's it. One import, no configuration.
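Under the hood, the import applies a plain monkey patch. A minimal sketch of the pattern, assuming nothing about the package's actual internals beyond the class and method it names:

import transformers.models.zamba2.modeling_zamba2 as modeling_zamba2

def _patched_torch_forward(self, *args, **kwargs):
    # sequential scan + manual conv1d in place of the chunked SSD path
    ...

modeling_zamba2.Zamba2MambaMixer.torch_forward = _patched_torch_forward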

What It Does

Replaces two components in HF's Zamba2MambaMixer.torch_forward:

  1. A sequential SSM scan replaces the chunked SSD computation that allocates GB-sized 6D tensors. Tokens are processed one at a time, maintaining a [batch, heads, head_dim, state_size] state (~1 MB) instead of materializing the full scan matrix.

  2. A manual conv1d via unfold+einsum replaces F.conv1d, avoiding cuDNN initialization failures on older drivers (Turing/SM75 GPUs); a sketch follows below.

The decode path (single-token generation with KV cache) is unchanged.
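A minimal sketch of the unfold+einsum idea from item 2, assuming the depthwise layout Mamba2's conv uses (weight squeezed to [channels, kernel_size]); the function name is illustrative, not this package's API:

import torch
import torch.nn.functional as F

def manual_causal_conv1d(x, weight, bias):
    # x: [batch, channels, seq_len]; weight: [channels, kernel_size] (depthwise)
    k = weight.shape[-1]
    x = F.pad(x, (k - 1, 0))                           # left-pad so the output stays causal
    windows = x.unfold(-1, k, 1)                       # [batch, channels, seq_len, k]
    y = torch.einsum("bclk,ck->bcl", windows, weight)  # per-channel multiply-accumulate
    return y + bias[None, :, None]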

Benchmarks

Tested on NVIDIA Quadro T2000 (4 GB VRAM, Turing SM75):

Model                  Without Patch   With Patch   VRAM Peak
Zamba2-1.2B            OOM             2.3 tok/s    1,560 MB
Zamba2-2.7B-Instruct   OOM             1.0 tok/s    3,134 MB

Also fixes RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED on Turing GPUs.
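To reproduce numbers like these, a simple measurement loop works (illustrative, not the exact script behind the table; it reuses model and inputs from the example above):

import time
import torch

torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
elapsed = time.perf_counter() - start

new_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tok/s")
print(f"{torch.cuda.max_memory_allocated() / 2**20:,.0f} MB peak")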

When to Use This

  • You get OutOfMemoryError running Zamba2 on a small GPU
  • mamba-ssm won't compile on your system
  • You get cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
  • You want to run Zamba2 without any CUDA kernel compilation

When NOT to Use This

  • You have mamba-ssm installed and working (the official kernels are faster)
  • You're on a large GPU (24+ GB) where the naive path doesn't OOM
  • You need maximum throughput (this trades speed for memory)

How It Works

The HF naive Mamba2 scan computes the structured state space dual (SSD) via chunked matrix operations:

# HF naive path (simplified) — allocates GB-scale intermediates
G = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]         # 6D pairwise tensor
M = (G * L).sum(-1)                                         # L: lower-triangular decay mask
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)  # OOM typically lands here

The patched path computes the same recurrence sequentially:

# mamba-scan-lite — ~1 MB running state
for t in range(seq_len):
    h = exp(A * dt[t]) * h + outer(dt[t] * x[t], B[t])   # rank-1 state update
    y[t] = (C[t] * h).sum(-1) + D * x[t]                 # readout + skip connection

Same math, but the running state is O(heads * head_dim * state_size), constant in sequence length, instead of intermediates that grow with it.
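For concreteness, here is a self-contained toy version of the sequential scan with random tensors. Shapes and names are illustrative; the actual patch operates on Zamba2MambaMixer's tensors:

import torch

batch, seq_len, heads, head_dim, state = 1, 16, 4, 8, 16
x  = torch.randn(batch, seq_len, heads, head_dim)   # inputs per head
dt = torch.rand(batch, seq_len, heads)              # per-head step sizes
A  = -torch.rand(heads)                             # scalar decay per head (Mamba2)
B  = torch.randn(batch, seq_len, heads, state)
C  = torch.randn(batch, seq_len, heads, state)
D  = torch.randn(heads)                             # skip connection

h = torch.zeros(batch, heads, head_dim, state)      # the only persistent buffer
y = torch.empty_like(x)
for t in range(seq_len):
    decay = torch.exp(A * dt[:, t])[..., None, None]                        # [b, h, 1, 1]
    dBx = torch.einsum("bhp,bhn->bhpn", dt[:, t][..., None] * x[:, t], B[:, t])
    h = decay * h + dBx                                                     # rank-1 state update
    y[:, t] = torch.einsum("bhn,bhpn->bhp", C[:, t], h) + D[:, None] * x[:, t]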

Requirements

  • Python >= 3.9
  • PyTorch >= 2.0
  • Transformers >= 4.45 (Zamba2 support)

Also From EchoLabs

  • helix-substrate — Universal model weight compression via HXQ (2D Vector Quantization). Faster-than-dense inference on compressed models.
  • EchoLabs33 on HuggingFace — 12 compressed models across Transformer, SSM, and Hybrid architectures.

License

MIT
