Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end
lite-mamba
A minimal, pure-PyTorch version of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build needed; works on CPU or GPU with standard PyTorch ops.
Install
```bash
pip install torch einops
pip install lite-mamba
```
Usage
```python
from lite_mamba import Mamba
import torch

x = torch.randn(2, 128, 512)  # (batch, seq, d_model)
m = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
y = m(x)
print(y.shape)  # (2, 128, 512)
```
API quick reference
```python
Mamba(
    d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2,
    dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random",
    dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False,
    use_fast_path=False, layer_idx=None, device=None, dtype=None,
)
```
- `d_model` (int, required): input/output embedding size.
- `d_state` (int, default 16): SSM state dimension per channel. Larger values give longer memory at higher compute cost.
- `d_conv` (int, default 4): depthwise conv kernel size for each branch.
- `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; the effective receptive field per branch is `(d_conv - 1) * dilation` (see the snippet after this list).
- `expand` (float, default 2): inner width multiplier; sets `d_inner = expand * d_model`.
- `dt_rank` (int or `"auto"`, default `"auto"`): rank of the delta projection. `"auto"` sets `ceil(d_model / 16)`.
- `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
- `dt_init` (`"random"` | `"constant"`, default `"random"`), `dt_scale`, `dt_init_floor`: control the magnitude and stability of the delta initialization.
- `conv_bias` (bool, default True): include bias in the depthwise convs.
- `bias` (bool, default False): include bias in the input/output linear projections.
- `use_fast_path` (bool): ignored in this pure-PyTorch build; kept for API compatibility.
- `layer_idx` (int | None): identifier for streaming cache registration; required when using `allocate_inference_cache` + `inference_params`.
- `device`, `dtype`: standard module factory kwargs.
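The receptive-field rule above is easy to check by hand; the helper below is a plain-Python illustration of it, not part of the library API.

```python
# Per-branch causal receptive field for a depthwise conv of size d_conv
# at a given dilation, using the (d_conv - 1) * dilation rule from above.
def receptive_fields(d_conv, conv_dilations):
    return [(d_conv - 1) * d for d in conv_dilations]

print(receptive_fields(3, (1, 2, 4, 8)))    # [2, 4, 8, 16] past positions per branch
print(receptive_fields(4, (1, 2, 4, 16)))   # widest branch reaches 48 tokens back
```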
Inference / streaming helpers
- `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
- `step(hidden_states, conv_state, ssm_state)`: single-token forward; expects `hidden_states` with shape `(B, 1, d_model)`. A decoding sketch follows below.
- `forward(..., inference_params)`: if `inference_params` carries cached states (via `key_value_memory_dict` and `seqlen_offset`), they are used for streaming.
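A minimal greedy-decoding sketch for a single layer. It assumes, as in the reference Mamba implementation, that `allocate_inference_cache` returns a `(conv_state, ssm_state)` pair and that `step` returns the output together with the updated states; the `embed` and `lm_head` modules are hypothetical stand-ins for your own model.

```python
import torch
from lite_mamba import Mamba

# Hypothetical surrounding pieces -- replace with your own embedding and head.
vocab, d_model = 32000, 512
embed = torch.nn.Embedding(vocab, d_model)
lm_head = torch.nn.Linear(d_model, vocab, bias=False)

layer = Mamba(d_model=d_model, d_conv=3, conv_dilations=(1, 2, 4), layer_idx=0)

# Assumption: allocate_inference_cache returns (conv_state, ssm_state) buffers.
conv_state, ssm_state = layer.allocate_inference_cache(batch_size=1, max_seqlen=256)

token = torch.tensor([[1]])  # (B, 1) start-token id
generated = [token]
with torch.no_grad():
    for _ in range(32):
        h = embed(token)  # (B, 1, d_model)
        # Assumption: step returns (output, conv_state, ssm_state).
        h, conv_state, ssm_state = layer.step(h, conv_state, ssm_state)
        logits = lm_head(h[:, -1])                    # (B, vocab)
        token = logits.argmax(dim=-1, keepdim=True)   # greedy pick, (B, 1)
        generated.append(token)
```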
Highlights
- Multi-branch causal dilated convs (weighted sum via learned gates); an illustrative sketch follows this list.
- Pure Python: no custom C++/CUDA or Triton kernels.
- Streaming support via per-branch conv states and SSM state caching.
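The package does not document its internals here, but the front-end can be pictured roughly as below: several causal depthwise convs at different dilations, mixed by a learned softmax gate. This is a minimal sketch under those assumptions, with illustrative names; it is not the module lite-mamba actually ships.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiDilatedCausalConv(nn.Module):
    """Illustrative multi-branch causal depthwise conv with learned gates."""

    def __init__(self, channels, d_conv=3, dilations=(1, 2, 4)):
        super().__init__()
        self.branches = nn.ModuleList(
            nn.Conv1d(channels, channels, d_conv, dilation=d, groups=channels)
            for d in dilations
        )
        self.dilations = dilations
        self.d_conv = d_conv
        # One learnable gate per branch; softmax keeps the mixture normalized.
        self.gates = nn.Parameter(torch.zeros(len(dilations)))

    def forward(self, x):  # x: (B, C, L)
        weights = F.softmax(self.gates, dim=0)
        out = torch.zeros_like(x)
        for w, conv, d in zip(weights, self.branches, self.dilations):
            pad = (self.d_conv - 1) * d          # left-pad only, so the conv stays causal
            out = out + w * conv(F.pad(x, (pad, 0)))
        return out

y = MultiDilatedCausalConv(64)(torch.randn(2, 64, 128))
print(y.shape)  # torch.Size([2, 64, 128])
```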
Practical setups
- Local modeling / small context: `d_conv=3`, `conv_dilations=(1, 2, 4)`, `d_state` in 8–16, `expand=2`.
- Longer context: widen `conv_dilations` (e.g., `(1, 2, 4, 8, 16)`) or increase `d_state` to 32; expect higher memory/compute. Both presets are spelled out as constructor calls after this list.
- Streaming/AR decoding: call `allocate_inference_cache` once per layer, pass `inference_params` during forward, and use `step` inside your generation loop.
- Stability first: keep `dt_min >= 1e-4` and `dt_init_floor` small; leave defaults unless you observe drift or exploding activations.
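As a concrete illustration, the first two presets above translate into constructor calls like the following (values taken from the list, not tuned recommendations):

```python
from lite_mamba import Mamba

# Local modeling / small context: small kernel, short dilation ladder, modest state.
local_block = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4),
                    d_state=16, expand=2)

# Longer context: wider dilation ladder and a larger SSM state.
long_block = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8, 16),
                   d_state=32, expand=2)
```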
Notes
- Set different `conv_dilations` to adjust the receptive field; keep kernels small (e.g., 3–5) to avoid excessive padding.
- The `use_fast_path` flag is ignored here (kept for API compatibility).
- The reference selective scan is implemented in PyTorch for portability; faster fused kernels are omitted intentionally. A generic sketch of the recurrence follows below.
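For readers unfamiliar with what a "reference selective scan in PyTorch" means in practice, the function below is a generic, sequential formulation of the standard Mamba SSM recurrence. The shapes and names are assumptions for illustration and may differ from the package's internal implementation.

```python
import torch

def selective_scan_ref(u, delta, A, B, C, D):
    """Sequential selective-scan recurrence (illustrative, not the package API).
    Shapes: u, delta (b, l, d); A (d, n); B, C (b, l, n); D (d)."""
    b, l, d = u.shape
    x = u.new_zeros(b, d, A.shape[1])                  # hidden SSM state
    dA = torch.exp(delta.unsqueeze(-1) * A)            # discretized state transition
    dBu = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # discretized input
    ys = []
    for t in range(l):
        x = dA[:, t] * x + dBu[:, t]                   # per-timestep state update
        ys.append((x * C[:, t].unsqueeze(1)).sum(-1))  # readout through C
    return torch.stack(ys, dim=1) + u * D              # plus skip connection via D
```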
License
Apache-2.0
File details
Details for the file lite_mamba-0.1.4.tar.gz.
File metadata
- Download URL: lite_mamba-0.1.4.tar.gz
- Upload date:
- Size: 7.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `bed236e0843ea8052bd695a95d8051fae902414c31afee200007149dfe43aa04` |
| MD5 | `b62ed57bc662c8b689e7deb72ee60710` |
| BLAKE2b-256 | `a63f66313c1ffd1ef3fe4d3c2eb420a1aa89af36116a9dc8c354012b075f2935` |
File details
Details for the file lite_mamba-0.1.4-py3-none-any.whl.
File metadata
- Download URL: lite_mamba-0.1.4-py3-none-any.whl
- Upload date:
- Size: 7.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `3f5278fa270709083b628feeca5d7c6303ba40d69ca74e6ba269067e1ed6d8c4` |
| MD5 | `09020be5747b0500beb38ef7a73eb12f` |
| BLAKE2b-256 | `b78a285e57651a6ed6241d8bcb59b38e1ef14c1fe5cd80ee5196983ee8062094` |