Pure-TensorFlow lightweight Mamba with multi-dilated causal conv front-end
Project description
lite-mamba
A minimal, pure-TensorFlow implementation of Mamba with a multi-dilated causal depthwise conv front-end. No custom C++ or Triton kernels needed; works seamlessly on CPU, GPU, or TPU with standard TensorFlow ops.
Install
```
pip install lite-mamba
```
Usage
```python
import tensorflow as tf
from lite_mamba import TFPTCNMamba

x = tf.random.normal((2, 128, 512))  # (batch, seq, d_model)
m = TFPTCNMamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
y = m(x)
print(y.shape)  # (2, 128, 512)
```
Conv front-end variants
- `TFPTCNMamba`: default Mamba variant; mixes parallel dilated depthwise conv branches via learned softmax gates (see the sketch below).
- `TFSTCNMamba`: runs the same depthwise conv layers in sequence (no gating); each branch's output feeds the next, forming a deterministic dilation stack.
- `TFDPWCMamba`: pairs each depthwise branch with a pointwise (1x1) conv before the gating mix, adding extra channel mixing without stacking more layers.
- `TFBaselineMamba`: single-branch baseline matching the reference Mamba architecture from state-spaces.
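The parallel-branch mixing in `TFPTCNMamba` can be pictured with plain Keras ops. The sketch below is illustrative only, not the library's implementation: the layer name `GatedDilatedConvFrontEnd` and the zero-initialized gate logits are assumptions.

```python
import tensorflow as tf

class GatedDilatedConvFrontEnd(tf.keras.layers.Layer):  # hypothetical name
    """Parallel dilated causal depthwise convs mixed by learned softmax gates."""

    def __init__(self, d_conv=3, dilations=(1, 2, 4, 8)):
        super().__init__()
        self.d_conv = d_conv
        self.dilations = dilations
        # One depthwise conv per dilation; causality comes from left padding.
        self.branches = [
            tf.keras.layers.DepthwiseConv1D(
                kernel_size=d_conv, dilation_rate=d, padding="valid")
            for d in dilations
        ]
        # One logit per branch; softmax turns them into mixing weights.
        self.gate_logits = self.add_weight(
            name="gate_logits", shape=(len(dilations),), initializer="zeros")

    def call(self, x):  # x: (batch, seq, channels)
        outs = []
        for conv, d in zip(self.branches, self.dilations):
            left_pad = (self.d_conv - 1) * d  # pad past timesteps only
            outs.append(conv(tf.pad(x, [[0, 0], [left_pad, 0], [0, 0]])))
        gates = tf.nn.softmax(self.gate_logits)  # (n_branches,)
        return tf.einsum("n,nbtc->btc", gates, tf.stack(outs, axis=0))

# Every branch preserves the sequence length, so the gated mix is shape-preserving:
y = GatedDilatedConvFrontEnd()(tf.random.normal((2, 128, 64)))
print(y.shape)  # (2, 128, 64)
```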
All variants expose the same constructor signature (d_model, d_state, conv_dilations, etc.) and streaming helpers (allocate_inference_cache, step). Swap them simply by changing the imported class name:
```python
from lite_mamba import TFSTCNMamba

m = TFSTCNMamba(d_model=512, d_state=16, conv_dilations=(1, 2, 4))
```
API quick reference
```python
TFMamba(
    d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2,
    dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random",
    dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False,
    layer_idx=None,
)
```
- `d_model` (int, required): input/output embedding size.
- `d_state` (int, default 16): SSM state dimension per channel; larger values give longer memory but cost more compute.
- `d_conv` (int, default 4): depthwise conv kernel size for each branch.
- `conv_dilations` (tuple[int], default `(1,)`): dilation per branch. Multiple values create parallel dilated convs; the effective receptive field per branch is `(d_conv - 1) * dilation` (see the example below).
- `expand` (float, default 2): inner width multiplier; sets `d_inner = expand * d_model`.
- `dt_rank` (int or `"auto"`, default `"auto"`): rank of the delta projection; `"auto"` sets `ceil(d_model / 16)`.
- `dt_min`, `dt_max` (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
- `dt_init` (`"random"` | `"constant"`, default `"random"`), `dt_scale`, `dt_init_floor`: control delta init magnitude and stability.
- `conv_bias` (bool, default True): include bias in the depthwise convs.
- `bias` (bool, default False): include bias in the input/output linear projections.
- `layer_idx` (int | None): identifier for streaming cache registration.
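For concreteness, here is the arithmetic from the `conv_dilations` and `dt_rank` entries above, evaluated for the dilations used in the usage example (plain Python, not library code):

```python
import math

# Per-branch effective receptive field, per the (d_conv - 1) * dilation rule.
d_model, d_conv, dilations = 512, 3, (1, 2, 4, 8)
for d in dilations:
    print(f"dilation {d}: receptive field {(d_conv - 1) * d}")
# dilation 1: receptive field 2 ... dilation 8: receptive field 16

# dt_rank="auto" resolves to ceil(d_model / 16):
print(math.ceil(d_model / 16))  # 32
```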
Inference / streaming helpers
- `allocate_inference_cache(batch_size, max_seqlen, dtype=None)`: preallocates conv and SSM state buffers for step-wise decoding.
- `step(hidden_states, conv_state, ssm_state)`: single-token forward; expects `hidden_states` with shape `(B, 1, d_model)`. A decoding-loop sketch follows below.
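A minimal decoding loop might look like the sketch below. It assumes `allocate_inference_cache` returns a `(conv_state, ssm_state)` pair and that `step` returns `(output, conv_state, ssm_state)`; check the source if your version differs.

```python
import tensorflow as tf
from lite_mamba import TFPTCNMamba

m = TFPTCNMamba(d_model=512, layer_idx=0)  # layer_idx registers the cache
conv_state, ssm_state = m.allocate_inference_cache(batch_size=2, max_seqlen=256)

x = tf.random.normal((2, 16, 512))  # stand-in for 16 already-decoded tokens
for t in range(16):
    # Feed one token at a time; the states carry the conv window and SSM memory.
    y, conv_state, ssm_state = m.step(x[:, t:t + 1, :], conv_state, ssm_state)
```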
Highlights
- Multi-branch Architecture: Parallel or stacked causal dilated convolutions with learned gating.
- Pure TensorFlow: No custom C++/CUDA kernels required; compatible with XLA compilation via `tf.function(jit_compile=True)` (sketch below).
- Streaming Support: Per-branch conv states and SSM state caching for autoregressive generation.
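Because everything is built from standard TensorFlow ops, the forward pass can be staged through XLA; a minimal sketch:

```python
import tensorflow as tf
from lite_mamba import TFPTCNMamba

m = TFPTCNMamba(d_model=512)

@tf.function(jit_compile=True)  # XLA-compile the whole forward pass
def forward(x):
    return m(x)

y = forward(tf.random.normal((2, 128, 512)))  # compiled on first call
```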
License
Apache-2.0
Download files
File details
Details for the file lite_mamba-1.0.2.tar.gz.
File metadata
- Download URL: lite_mamba-1.0.2.tar.gz
- Size: 7.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `4699343f68c9f3ffc0ab0862cd0e98112d46173bebdfb435c189dc9e2641bf8d` |
| MD5 | `a272b5d8c95c036453078cddc549acb7` |
| BLAKE2b-256 | `76820c13147fa3012a601438f153235bc22dd15d3274ecf4d2b6eceaec16ef51` |
File details
Details for the file lite_mamba-1.0.2-py3-none-any.whl.
File metadata
- Download URL: lite_mamba-1.0.2-py3-none-any.whl
- Size: 7.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `944477f159d0eae9ec94c06ab5a0d2144868ad6dda6edbb9d4c0da8b510177b8` |
| MD5 | `f1d8df6908652fce3a61481d8fa1d2d5` |
| BLAKE2b-256 | `f5c31c340ea6a81bde8998383cf686e5b1d424d1f1ea9990de4d8871dc596cbd` |