Pure-TensorFlow lightweight Mamba with multi-dilated causal conv front-end

Project description

lite-mamba

A minimal, pure-TensorFlow implementation of Mamba with a multi-dilated causal depthwise conv front-end. No custom C++ or Triton kernels are needed; it runs on CPU, GPU, or TPU using only standard TensorFlow ops.

Install

pip install lite-mamba

Usage

from lite_mamba import TFPTCNMamba
import tensorflow as tf

x = tf.random.normal((2, 128, 512))  # (batch, seq, d_model)
m = TFPTCNMamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
y = m(x)
print(y.shape)  # (2, 128, 512)

Conv front-end variants

  • TFPTCNMamba: default Mamba variant, mixes parallel dilated depthwise conv branches via learned softmax gates.
  • TFSTCNMamba: runs the same depthwise conv layers in sequence (no gating); each branch output feeds the next to create a deterministic dilation stack.
  • TFDPWCMamba: pairs each depthwise branch with a pointwise (1x1) conv before the gating mix, adding extra channel mixing without stacking more layers.
  • TFBaselineMamba: single-branch baseline matching the reference Mamba architecture from state-spaces.

All variants expose the same constructor signature (d_model, d_state, conv_dilations, etc.) and streaming helpers (allocate_inference_cache, step). Swap them simply by changing the imported class name:

from lite_mamba import TFSTCNMamba

m = TFSTCNMamba(d_model=512, d_state=16, conv_dilations=(1, 2, 4))

API quick reference

TFMamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, layer_idx=None)

  • d_model (int, required): input/output embedding size.
  • d_state (int, default 16): SSM state dimension per channel. Larger values give longer-range memory but increase compute.
  • d_conv (int, default 4): depthwise conv kernel size for each branch.
  • conv_dilations (tuple[int], default (1,)): dilation per branch. Multiple values create parallel dilated convs; each branch's effective receptive field is (d_conv - 1) * dilation + 1.
  • expand (float, default 2): inner width multiplier; sets d_inner = expand * d_model.
  • dt_rank (int or "auto", default "auto"): rank of delta projection. "auto" sets ceil(d_model/16).
  • dt_min, dt_max (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
  • dt_init ("random" | "constant", default "random") and dt_scale, dt_init_floor: control delta init magnitude/stability.
  • conv_bias (bool, default True): include bias in depthwise convs.
  • bias (bool, default False): include bias in input/output linear projections.
  • layer_idx (int | None): identifier for streaming cache registration.
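
Putting the reference together, a minimal construction sketch (every argument below comes from the signature above; the receptive-field comment uses the (d_conv - 1) * dilation + 1 formula):

import tensorflow as tf
from lite_mamba import TFPTCNMamba

# With d_conv=4 and dilation 8, the widest branch spans
# (4 - 1) * 8 + 1 = 25 timesteps.
m = TFPTCNMamba(
    d_model=512,                   # input/output embedding size
    d_state=16,                    # SSM state dim per channel
    d_conv=4,                      # depthwise kernel size per branch
    conv_dilations=(1, 2, 4, 8),   # four parallel dilated branches
    expand=2,                      # d_inner = 2 * 512 = 1024
    dt_rank="auto",                # ceil(512 / 16) = 32
    layer_idx=0,                   # identifies this layer for streaming caches
)

y = m(tf.random.normal((2, 128, 512)))  # -> (2, 128, 512)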

Inference / streaming helpers

  • allocate_inference_cache(batch_size, max_seqlen, dtype=None): preallocates conv and SSM state buffers for step-wise decoding.
  • step(hidden_states, conv_state, ssm_state): single-token forward (expects hidden_states with shape (B, 1, d_model)); see the decoding sketch below.
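
A sketch of token-by-token decoding with the helpers above. The call signatures are as documented; the return conventions (a (conv_state, ssm_state) pair from allocate_inference_cache, and output plus updated states from step) are assumptions, mirroring the reference Mamba implementation:

import tensorflow as tf
from lite_mamba import TFPTCNMamba

m = TFPTCNMamba(d_model=512, layer_idx=0)
batch_size, max_seqlen = 2, 64

# Assumed return convention: a (conv_state, ssm_state) pair of buffers.
conv_state, ssm_state = m.allocate_inference_cache(batch_size, max_seqlen)

token = tf.random.normal((batch_size, 1, 512))  # (B, 1, d_model)
for _ in range(max_seqlen):
    # Assumed return convention: output plus updated states,
    # as in the reference Mamba step().
    token, conv_state, ssm_state = m.step(token, conv_state, ssm_state)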

Highlights

  • Multi-branch Architecture: Parallel causal dilated convolutions with learned gating, or a deterministic stacked variant.
  • Pure TensorFlow: No custom C++/CUDA kernels required. Compatible with XLA compilation via tf.function(jit_compile=True), as sketched below.
  • Streaming Support: Per-branch conv states and SSM state caching for autoregressive generation.
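
As a minimal sketch of the XLA path (standard TensorFlow APIs only; nothing here is specific to lite-mamba beyond the constructor shown earlier):

import tensorflow as tf
from lite_mamba import TFPTCNMamba

m = TFPTCNMamba(d_model=512)

@tf.function(jit_compile=True)  # compile the forward pass with XLA
def forward(x):
    return m(x)

y = forward(tf.random.normal((2, 128, 512)))  # -> (2, 128, 512)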

License

Apache-2.0

Download files

Download the file for your platform.

Source Distribution

lite_mamba-1.0.2.tar.gz (7.9 kB)

Built Distribution

lite_mamba-1.0.2-py3-none-any.whl (7.2 kB)

File details

Details for the file lite_mamba-1.0.2.tar.gz.

File metadata

  • Download URL: lite_mamba-1.0.2.tar.gz
  • Size: 7.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for lite_mamba-1.0.2.tar.gz:

  • SHA256: 4699343f68c9f3ffc0ab0862cd0e98112d46173bebdfb435c189dc9e2641bf8d
  • MD5: a272b5d8c95c036453078cddc549acb7
  • BLAKE2b-256: 76820c13147fa3012a601438f153235bc22dd15d3274ecf4d2b6eceaec16ef51

File details

Details for the file lite_mamba-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: lite_mamba-1.0.2-py3-none-any.whl
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for lite_mamba-1.0.2-py3-none-any.whl:

  • SHA256: 944477f159d0eae9ec94c06ab5a0d2144868ad6dda6edbb9d4c0da8b510177b8
  • MD5: f1d8df6908652fce3a61481d8fa1d2d5
  • BLAKE2b-256: f5c31c340ea6a81bde8998383cf686e5b1d424d1f1ea9990de4d8871dc596cbd
