Pure-PyTorch lightweight Mamba with multi-dilated causal conv front-end
lite-mamba
A minimal, pure-PyTorch implementation of Mamba with a multi-dilated causal depthwise conv front-end. No CUDA/Triton build needed; works on CPU or GPU with standard PyTorch ops.
Dual Framework Support: Includes both PyTorch and TensorFlow implementations with identical core logic and mathematical formulations, so you can use either framework depending on your environment.
Install
pip install torch einops
pip install lite-mamba
TensorFlow path (optional):
pip install "lite-mamba[tensorflow]"
Usage
PyTorch
from lite_mamba import Mamba, PTCNMamba, STCNMamba, DPWCMamba
import torch
x = torch.randn(2, 128, 512) # (batch, seq, d_model)
m = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
y = m(x)
print(y.shape) # (2, 128, 512)
TensorFlow
from lite_mamba import TFPTCNMamba
import tensorflow as tf
x = tf.random.normal((2, 128, 512)) # (batch, seq, d_model)
m = TFPTCNMamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
y = m(x)
print(y.shape) # (2, 128, 512)
Conv front-end variants
- PTCNMamba: identical to Mamba; mixes parallel dilated depthwise conv branches via learned softmax gates.
- STCNMamba: runs the same depthwise conv layers in sequence (no gating); each branch output feeds the next to create a deterministic dilation stack.
- DPWCMamba: pairs each depthwise branch with a pointwise (1×1) conv before the gating mix, adding extra channel mixing without stacking more layers.
All variants expose the same constructor signature (d_model, d_state, conv_dilations, etc.) and streaming helpers (allocate_inference_cache, step). Swap them simply by changing the imported class name:
from lite_mamba import STCNMamba
m = STCNMamba(d_model=512, d_state=16, conv_dilations=(1, 2, 4))
Use DPWCMamba for richer channel interactions in each branch, and STCNMamba when you want a straightforward sequential dilation pipeline (e.g., for debugging or reproducing the behavior of stacked TCN layers).
Baseline helper
BaselineMamba mirrors the upstream state-spaces/mamba block: a single depthwise causal convolution followed by the SSM parameter projection, selective scan recurrence, and streaming helpers. baseline_mamba is a thin functional alias that instantiates the class with the same defaults so you can reproduce the reference layout without duplicating constructor arguments.
from lite_mamba import BaselineMamba
m = BaselineMamba(d_model=512, d_conv=3)
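If you prefer the functional alias mentioned above, it can presumably be called the same way; this one-liner assumes baseline_mamba simply forwards its keyword arguments to the class:
from lite_mamba import baseline_mamba
m = baseline_mamba(d_model=512, d_conv=3)  # assumed to build a BaselineMamba with the same defaults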
TensorFlow variants
TensorFlow implementations mirror the exact same core logic and mathematical formulations as PyTorch:
- TFBaselineMamba - Single-branch baseline matching the reference Mamba
- TFPTCNMamba - Parallel dilated TCN branches with learned gating (the default TFMamba)
- TFSTCNMamba - Stacked/sequential dilated TCN branches
- TFDPWCMamba - Depthwise + pointwise convolution branches
All TensorFlow variants support:
- Same constructor API as PyTorch versions
- Streaming inference via allocate_inference_cache() and step()
- Custom dilated causal convolutions
- Identical SSM discretization and selective scan
import tensorflow as tf
from lite_mamba import TFBaselineMamba, TFSTCNMamba, TFDPWCMamba
# Baseline single-branch
m1 = TFBaselineMamba(d_model=512, d_conv=3)
# Stacked dilated branches
m2 = TFSTCNMamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4))
# Depthwise + pointwise
m3 = TFDPWCMamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
API quick reference
Mamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=False, layer_idx=None, device=None, dtype=None)
- d_model (int, required): input/output embedding size.
- d_state (int, default 16): SSM state dimension per channel. Larger gives longer memory; increases compute.
- d_conv (int, default 4): depthwise conv kernel size for each branch.
- conv_dilations (tuple[int], default (1,)): dilation per branch. Multiple values create parallel dilated convs; effective receptive field per branch is (d_conv - 1) * dilation.
- expand (float, default 2): inner width multiplier; sets d_inner = expand * d_model.
- dt_rank (int or "auto", default "auto"): rank of the delta projection. "auto" sets ceil(d_model / 16).
- dt_min, dt_max (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
- dt_init ("random" | "constant", default "random"), dt_scale, dt_init_floor: control delta init magnitude/stability.
- conv_bias (bool, default True): include bias in the depthwise convs.
- bias (bool, default False): include bias in the input/output linear projections.
- use_fast_path (bool): ignored in this pure-PyTorch build; kept for API compatibility.
- layer_idx (int | None): identifier for streaming cache registration; required when using allocate_inference_cache + inference_params.
- device, dtype: standard module factory kwargs.
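A minimal sketch of a non-default configuration built from the parameters above; the specific values are illustrative, not recommendations:
import torch
from lite_mamba import Mamba

# Wider SSM state and four parallel dilated branches; layer_idx is set so the
# block can register its cache for streaming (see the helpers below).
m = Mamba(
    d_model=256,
    d_state=32,                   # longer memory, more compute
    d_conv=3,
    conv_dilations=(1, 2, 4, 8),
    expand=2,
    dt_rank="auto",               # resolves to ceil(256 / 16) = 16
    layer_idx=0,
)
x = torch.randn(4, 64, 256)       # (batch, seq, d_model)
y = m(x)                          # output keeps the input shape: (4, 64, 256)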
Inference / streaming helpers
- allocate_inference_cache(batch_size, max_seqlen, dtype=None): preallocates conv and SSM state buffers for step-wise decoding.
- step(hidden_states, conv_state, ssm_state): single-token forward (expects hidden_states with shape (B, 1, d_model)).
- forward(..., inference_params): if inference_params has cached states (with key_value_memory_dict and seqlen_offset), uses them for streaming.
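A rough streaming sketch under two assumptions not spelled out above: that allocate_inference_cache returns a (conv_state, ssm_state) pair, and that step returns (output, conv_state, ssm_state) as in the upstream reference block. Treat it as a template rather than a guaranteed API:
import types
import torch
from lite_mamba import Mamba

m = Mamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4), layer_idx=0)
m.eval()

# Hypothetical container exposing the two attributes forward() looks for.
params = types.SimpleNamespace(
    key_value_memory_dict={0: m.allocate_inference_cache(batch_size=1, max_seqlen=256)},
    seqlen_offset=0,
)

prompt = torch.randn(1, 16, 512)             # (B, L, d_model)
y = m(prompt, inference_params=params)       # prefill populates the cached states
params.seqlen_offset += prompt.shape[1]

# Assumed unpacking of the cached buffers; decode one token at a time with step().
conv_state, ssm_state = params.key_value_memory_dict[0]
token = y[:, -1:, :]                         # (B, 1, d_model)
for _ in range(8):
    token, conv_state, ssm_state = m.step(token, conv_state, ssm_state)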
Highlights
- Dual Framework: PyTorch and TensorFlow implementations with identical core logic
- Multi-branch Architecture: Parallel or stacked causal dilated convolutions with learned gating
- Pure Python: No custom C++/CUDA or Triton kernels required
- Streaming Support: Per-branch conv states and SSM state caching for autoregressive generation
- Framework Parity: Mathematically equivalent implementations verified across both frameworks
Practical setups
- Local modeling / small context: d_conv=3, conv_dilations=(1, 2, 4), d_state=8–16, expand=2.
- Longer context: widen conv_dilations (e.g., (1, 2, 4, 8, 16)) or increase d_state to 32; expect higher memory/compute (see the sketch after this list).
- Streaming/AR decoding: call allocate_inference_cache once per layer, pass inference_params during forward, and use step inside your generation loop.
- Stability first: keep dt_min >= 1e-4 and dt_init_floor small; leave defaults unless you observe drift or exploding activations.
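A minimal sketch of the first two setups above; the values are illustrative:
import torch
from lite_mamba import Mamba

# Local modeling / small context: small kernel, short dilation ladder, modest state.
local = Mamba(d_model=256, d_conv=3, conv_dilations=(1, 2, 4), d_state=16, expand=2)

# Longer context: wider dilation ladder and larger SSM state; expect higher
# memory and compute per token.
long_ctx = Mamba(d_model=256, d_conv=3, conv_dilations=(1, 2, 4, 8, 16), d_state=32, expand=2)

x = torch.randn(2, 256, 256)                 # (batch, seq, d_model)
print(local(x).shape, long_ctx(x).shape)     # both (2, 256, 256)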
Framework Compatibility
PyTorch vs TensorFlow
Both implementations are mathematically equivalent and follow the same core Mamba architecture:
- Tensor Layout: PyTorch uses (B, D, L) (channels-first), TensorFlow uses (B, L, D) (channels-last)
- SSM Formulation: Identical selective state space model with zero-order hold discretization
- Initialization: Same parameter initialization schemes (A matrix, D parameter, dt bias)
- Variants: All architectural variants available in both frameworks
The implementations have been thoroughly tested to ensure numerical equivalence within floating-point precision.
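A rough shape-level check of the parity claim; verifying numerical equivalence additionally requires copying weights between the two frameworks, which this sketch does not attempt:
import numpy as np
import torch
import tensorflow as tf
from lite_mamba import Mamba, TFPTCNMamba

x = np.random.randn(2, 64, 256).astype("float32")   # (batch, seq, d_model)

pt_block = Mamba(d_model=256, d_conv=3, conv_dilations=(1, 2, 4))
tf_block = TFPTCNMamba(d_model=256, d_conv=3, conv_dilations=(1, 2, 4))

y_pt = pt_block(torch.from_numpy(x))
y_tf = tf_block(tf.constant(x))

# Both blocks consume and emit (batch, seq, d_model); the channels-first vs
# channels-last layouts noted above are internal details.
assert tuple(y_pt.shape) == tuple(y_tf.shape) == (2, 64, 256)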
Notes
- Set different conv_dilations to adjust the receptive field; keep kernels small (e.g., 3–5) to avoid excessive padding.
- The use_fast_path flag is ignored here (kept for API compatibility with reference implementations).
- The reference selective scan is implemented in pure Python for portability; faster fused kernels are omitted intentionally.
- The TensorFlow implementation uses a custom depthwise convolution via tf.nn.depthwise_conv2d for dilated causal convolutions.
License
Apache-2.0