
Pure-TensorFlow lightweight Mamba with multi-dilated causal conv front-end

Project description

lite-mamba


A minimal, pure-TensorFlow implementation of Mamba with a multi-dilated causal depthwise conv front-end. No custom C++ or Triton kernels needed; works seamlessly on CPU, GPU, or TPU with standard TensorFlow ops.

Install

pip install lite-mamba

Usage

from lite_mamba import TFPTCNMamba
import tensorflow as tf

x = tf.random.normal((2, 128, 512))  # (batch, seq, d_model)
m = TFPTCNMamba(d_model=512, d_conv=3, conv_dilations=(1, 2, 4, 8))
y = m(x)
print(y.shape)  # (2, 128, 512)

Conv front-end variants

  • TFPTCNMamba: the default variant; mixes parallel dilated depthwise conv branches via learned softmax gates.
  • TFSTCNMamba: runs the same depthwise conv layers in sequence (no gating); each branch output feeds the next to create a deterministic dilation stack.
  • TFDPWCMamba: pairs each depthwise branch with a pointwise (1x1) conv before the gating mix, adding extra channel mixing without stacking more layers.
  • TFBaselineMamba: single-branch baseline matching the reference Mamba architecture from state-spaces.

All variants expose the same constructor signature (d_model, d_state, conv_dilations, etc.) and streaming helpers (allocate_inference_cache, step). Swap them simply by changing the imported class name:

from lite_mamba import TFSTCNMamba

m = TFSTCNMamba(d_model=512, d_state=16, conv_dilations=(1, 2, 4))

API quick reference

TFMamba(d_model, d_state=16, d_conv=4, conv_dilations=(1,), expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, layer_idx=None)

  • d_model (int, required): input/output embedding size.
  • d_state (int, default 16): SSM state dimension per channel. Larger gives longer memory; increases compute.
  • d_conv (int, default 4): depthwise conv kernel size for each branch.
  • conv_dilations (tuple[int], default (1,)): dilation per branch. Multiple values create parallel dilated convs; each branch's effective receptive field is (d_conv - 1) * dilation + 1 timesteps (e.g. d_conv=3 at dilation 8 spans 17 steps).
  • expand (float, default 2): inner width multiplier; sets d_inner = expand * d_model.
  • dt_rank (int or "auto", default "auto"): rank of delta projection. "auto" sets ceil(d_model/16).
  • dt_min, dt_max (float, defaults 1e-3 / 1e-1): log-uniform range for delta initialization.
  • dt_init ("random" | "constant", default "random") and dt_scale, dt_init_floor: control delta init magnitude/stability.
  • conv_bias (bool, default True): include bias in depthwise convs.
  • bias (bool, default False): include bias in input/output linear projections.
  • layer_idx (int | None): identifier for streaming cache registration.
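
Putting the main constructor arguments together, here is a minimal configuration sketch (values chosen purely for illustration):

from lite_mamba import TFPTCNMamba
import tensorflow as tf

# Illustrative values: d_inner = expand * d_model = 2 * 256 = 512,
# and dt_rank="auto" resolves to ceil(256 / 16) = 16.
layer = TFPTCNMamba(
    d_model=256,
    d_state=16,
    d_conv=3,
    conv_dilations=(1, 2, 4),  # three parallel branches, lookbacks 2 / 4 / 8 steps
    expand=2,
    dt_rank="auto",
    layer_idx=0,  # set if you plan to register a streaming cache
)

x = tf.random.normal((4, 64, 256))  # (batch, seq, d_model)
y = layer(x)                        # output keeps the input shape: (4, 64, 256)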

Inference / streaming helpers

  • allocate_inference_cache(batch_size, max_seqlen, dtype=None): preallocates conv and SSM state buffers for step-wise decoding.
  • step(hidden_states, conv_state, ssm_state): single-token forward (expects hidden_states with shape (B, 1, d_model)).
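
For token-by-token generation, a hedged sketch of the streaming loop follows. The exact return values of these helpers aren't documented above, so this assumes allocate_inference_cache returns a (conv_state, ssm_state) pair and step returns (output, conv_state, ssm_state), mirroring the reference Mamba implementation; check the source if your version differs.

from lite_mamba import TFPTCNMamba
import tensorflow as tf

layer = TFPTCNMamba(d_model=512, layer_idx=0)

# Assumption: allocate_inference_cache returns (conv_state, ssm_state) and
# step returns (output, conv_state, ssm_state), as in reference Mamba.
conv_state, ssm_state = layer.allocate_inference_cache(batch_size=1, max_seqlen=64)

h = tf.random.normal((1, 1, 512))  # one token at a time: (B, 1, d_model)
for _ in range(8):
    h, conv_state, ssm_state = layer.step(h, conv_state, ssm_state)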

Highlights

  • Multi-branch Architecture: Parallel or stacked causal dilated convolutions with learned gating.
  • Pure TensorFlow: No custom C++/CUDA kernels required. Compatible with XLA compilation via tf.function(jit_compile=True); see the sketch after this list.
  • Streaming Support: Per-branch conv states and SSM state caching for autoregressive generation.
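
As a quick illustration of the XLA note above, the forward pass can be wrapped in a compiled tf.function (a minimal sketch reusing the shapes from the usage example):

from lite_mamba import TFPTCNMamba
import tensorflow as tf

layer = TFPTCNMamba(d_model=512)

@tf.function(jit_compile=True)  # XLA-compile the forward pass
def forward(x):
    return layer(x)

y = forward(tf.random.normal((2, 128, 512)))  # traced and compiled on first call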

License

Apache-2.0

Download files

Download the file for your platform.

Source Distribution

lite_mamba-1.0.1.tar.gz (7.9 kB)


Built Distribution


lite_mamba-1.0.1-py3-none-any.whl (7.2 kB)


File details

Details for the file lite_mamba-1.0.1.tar.gz.

File metadata

  • Download URL: lite_mamba-1.0.1.tar.gz
  • Upload date:
  • Size: 7.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for lite_mamba-1.0.1.tar.gz

  • SHA256: 7f0a31e929ae61cd2635159c077bd1220bf517d27c17f0ab08b70befbbc20c1b
  • MD5: f66a6523861ca32138045c12ec962230
  • BLAKE2b-256: 1924a66e0a63cab9899c212e69136162e5e8b59bac8aab7b1c5ca65fd10cd541


File details

Details for the file lite_mamba-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: lite_mamba-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.2

File hashes

Hashes for lite_mamba-1.0.1-py3-none-any.whl

  • SHA256: d42731658604dae5404c814d51c276fcedbfba9d6f7845a32bc82795a4b17c84
  • MD5: 720a24c8ea93efab5c96da31392dc7f6
  • BLAKE2b-256: 91a3dd30cd08fac606a26993f8e8115ef9f6c3c086e13739298a8e1f596cfe35

