
MLX Vision Mamba

First Metal-accelerated Vision Mamba for Apple Silicon with 2D/3D/4D support and training VJP.

A complete Mamba (Selective State Space Model) implementation for MLX that supports multi-dimensional vision inputs. Includes custom Metal kernels for the selective scan that provide 3.9x training speedup over Python loops.

Features

  • Multi-dimensional Vision Mamba: supports 2D images, 3D volumes (CT/MRI), and 4D spatiotemporal data (cine MRI, video)
  • Multi-directional scanning: K=4 (2D), K=6 (3D), K=8 (4D) scan directions with VMamba-style spatial permutations
  • Custom Metal kernels: fused selective scan in Metal Shading Language with full training VJP via mx.custom_function
  • 3.9x training speedup: over the Python sequential scan on Apple Silicon (benchmarked on an M5 Pro)
  • Drop-in ViT replacement: same (B, H, W, C) -> (B, N, D) interface as Vision Transformer
  • Bidirectional scanning: forward + backward Mamba blocks for non-causal spatial data

Benchmark

Tested on Apple M5 Pro, MLX 0.31, batch=2, seq_len=196, d_model=768, d_state=16:

Mode                   Scan Time   Full Training Step   vs Python
Python sequential      16.4 ms     0.393 s              1.0x
Metal (forward only)   2.0 ms      0.173 s              2.3x
Metal (fwd + VJP)      1.5 ms      0.106 s              3.7x
Metal fused (BEST)     1.2 ms      0.035 s              ~4x

All gradients verified correct against Python reference (max diff < 1e-6).
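A gradient check of this kind can be sketched with a toy linear recurrence in NumPy (this is an illustrative stand-in, not the project's actual test code): the analytic gradient from a reverse scan is compared against central finite differences.

```python
import numpy as np

def scan(a, b, c):
    """Toy recurrence: h_t = a_t*h_{t-1} + b_t, y_t = c*h_t."""
    h, ys = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        ys.append(c * h)
    return np.array(ys)

def scan_vjp(a, b, c):
    """Analytic gradient of sum(y) w.r.t. b via a reverse scan (adjoint)."""
    T = len(a)
    grad_b = np.zeros(T)
    lam = 0.0  # adjoint of h_t: lam_t = c + a_{t+1} * lam_{t+1}
    for t in range(T - 1, -1, -1):
        lam = c + (a[t + 1] * lam if t + 1 < T else 0.0)
        grad_b[t] = lam  # dh_t/db_t = 1
    return grad_b

rng = np.random.default_rng(0)
T = 8
a, b, c = rng.uniform(0.5, 0.9, T), rng.normal(size=T), 1.3

# Central finite differences, same spirit as the "max diff < 1e-6" check.
eps = 1e-6
fd = np.array([
    (scan(a, b + eps * np.eye(T)[t], c).sum()
     - scan(a, b - eps * np.eye(T)[t], c).sum()) / (2 * eps)
    for t in range(T)
])
max_diff = np.max(np.abs(fd - scan_vjp(a, b, c)))
print(max_diff)  # near machine precision: the recurrence is linear in b
```

The reverse scan mirrors the forward one, which is why the backward pass of the Metal kernel can reuse the same dispatch pattern.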

Installation

pip install mlx
# Clone this repo
git clone https://github.com/YOUR_USERNAME/mlx-vision-mamba.git
cd mlx-vision-mamba

Quick Start

Basic Mamba Block

import mlx.core as mx
from mlx_vision_mamba import MambaBlock

# Single Mamba block (1D sequence)
block = MambaBlock(d_model=384, d_state=16, scan_mode="metal_fused")
x = mx.random.normal((2, 196, 384))  # (batch, seq_len, dim)
y = block(x)  # (2, 196, 384)

Vision Mamba Encoder (Drop-in ViT Replacement)

from mlx_vision_mamba import VisionMamba

# 2D images (X-ray, natural images)
encoder_2d = VisionMamba(img_size=224, patch_size=16, embed_dim=384, depth=6, input_dim=2)
images = mx.random.normal((2, 224, 224, 1))  # (B, H, W, C)
features = encoder_2d(images)  # (2, 196, 384)

# 3D volumes (CT, MRI) — K=6 multi-directional scanning
encoder_3d = VisionMamba(img_size=64, patch_size=16, embed_dim=384, depth=6, input_dim=3)
volumes = mx.random.normal((1, 64, 64, 64, 1))  # (B, D, H, W, C)
features_3d = encoder_3d(volumes)  # (1, 64, 384)

# 4D spatiotemporal (cine MRI, video) — K=8 directions
encoder_4d = VisionMamba(
    img_size=32, patch_size=16, embed_dim=384, depth=6,
    input_dim=4, num_temporal_frames=4, temporal_patch_size=1,
)
video = mx.random.normal((1, 4, 32, 32, 32, 1))  # (B, T, D, H, W, C)
features_4d = encoder_4d(video)  # (1, 32, 384)

Metal Acceleration

from mlx_vision_mamba import MambaBlock

# Inference (fastest, no autograd)
block = MambaBlock(d_model=384, scan_mode="metal")

# Training with gradients (Metal forward + Metal backward VJP)
block = MambaBlock(d_model=384, scan_mode="metal_fused")

Scan Modes

Mode          Use Case                                 Speed
sequential    Safe default, debugging                  1.0x
chunked       Slightly faster Python scan              ~1.1x
metal         Inference / frozen layers (no grad)      ~9x scan
metal_train   Training (Metal fwd + Metal bwd VJP)     ~3.7x
metal_fused   Training (fused discretization + scan)   ~3.9x

Architecture

Multi-Directional Scanning

Standard Mamba processes sequences left-to-right (causal). For spatial data with no causal ordering, we use multi-directional scanning:

  • 2D (K=4): forward-row, backward-row, forward-column, backward-column
  • 3D (K=6): scan along +X, -X, +Y, -Y, +Z, -Z
  • 4D (K=8): scan along +T, -T, +X, -X, +Y, -Y, +Z, -Z

Outputs from all directions are summed with a residual connection.

When processing a subset of patches (e.g., as a JEPA context encoder), the model automatically falls back to K=2 (forward + backward), since the spatial permutations require the full grid.
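The 2D case can be sketched as index permutations over a flattened patch grid (a minimal NumPy illustration with hypothetical names, not the library's API): each direction is just a reordering, so a single 1D scan serves all K directions, and the outputs are mapped back and summed with a residual.

```python
import numpy as np

def scan_orders_2d(H, W):
    """K=4 VMamba-style orderings of a row-major flattened H x W patch grid."""
    idx = np.arange(H * W).reshape(H, W)
    return np.stack([
        idx.reshape(-1),          # forward-row (row-major)
        idx.reshape(-1)[::-1],    # backward-row
        idx.T.reshape(-1),        # forward-column (column-major)
        idx.T.reshape(-1)[::-1],  # backward-column
    ])

def multi_directional(x, scan_1d, H, W):
    """Run a 1D scan along each direction, undo the permutation, sum, add residual."""
    out = np.zeros_like(x)
    for order in scan_orders_2d(H, W):
        inv = np.argsort(order)         # inverse permutation back to grid order
        out += scan_1d(x[order])[inv]   # scan in that direction, map back
    return out + x                      # residual connection

# Example with a cumulative sum standing in for the Mamba scan:
orders = scan_orders_2d(2, 2)
y = multi_directional(np.ones(4), np.cumsum, 2, 2)
```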

Metal Kernel Design

The selective scan recurrence h = Ā·h + B̄·x; y = C·h + D·x is the bottleneck: it is inherently sequential over time. Our Metal kernel:

  1. One GPU thread per (batch, channel) — embarrassingly parallel across B*D_inner
  2. SSM state in thread registers — N=16 floats (64 bytes) per thread, no shared memory needed
  3. Single kernel launch — replaces ~200 Python loop iterations with one Metal dispatch
  4. Fused discretization — exp(deltaA) and deltaB*x computed inline, no intermediate tensors
  5. Custom VJP via mx.custom_function — backward pass is a reverse scan in the same pattern
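The computation the kernel fuses can be written as a short NumPy reference for one (batch, channel) pair (a sketch for intuition, with toy shapes; the real kernel operates on MLX arrays across all B*D_inner threads):

```python
import numpy as np

def selective_scan_ref(x, delta, A, B, C, D):
    """NumPy reference of the per-(batch, channel) selective scan.

    Toy shapes for one batch element and one inner channel:
      x, delta : (L,)   input and step size per timestep
      A        : (N,)   diagonal of the state matrix for this channel
      B, C     : (L, N) input-dependent projections
      D        : scalar skip connection
    """
    L, N = B.shape
    h = np.zeros(N)                   # SSM state (lives in registers on GPU)
    y = np.empty(L)
    for t in range(L):                # sequential over time, by construction
        dA = np.exp(delta[t] * A)     # fused discretization: exp(ΔA) ...
        dBx = delta[t] * B[t] * x[t]  # ... and ΔB·x, no intermediate tensors
        h = dA * h + dBx              # h = Ā·h + B̄·x
        y[t] = C[t] @ h + D * x[t]    # y = C·h + D·x
    return y

# Sanity check: with A = 0, delta = 1, B = C = 1, D = 0, the state is a
# running sum of ones, so y_t = N * (t + 1).
L, N = 5, 3
y = selective_scan_ref(np.ones(L), np.ones(L), np.zeros(N),
                       np.ones((L, N)), np.ones((L, N)), 0.0)
```

The Metal kernel runs this loop once per thread; since threads share nothing, no synchronization or shared memory is needed.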

Files

mlx_vision_mamba/
  __init__.py              — Public API
  mamba_block.py           — MambaBlock, BidirectionalMambaBlock, MultiDirectionalMambaBlock
  vision_mamba.py          — VisionMamba encoder (2D/3D/4D)
  mamba_metal.py           — Metal kernel (forward + VJP)
  mamba_metal_fused.py     — Fused Metal kernel (discretize + scan in one kernel)
  mamba_fast.py            — Python-level optimizations (chunked, parallel)
  vit.py                   — PatchEmbed, PatchEmbed3D (shared with ViT)

Citation

If you use this code in your research, please cite:

@software{kang2026mlx_vision_mamba,
  author = {Kang, Shinyoung},
  title = {MLX Vision Mamba: Metal-Accelerated State Space Models for Multi-Dimensional Vision},
  year = {2026},
  url = {https://github.com/YOUR_USERNAME/mlx-vision-mamba},
  license = {Apache-2.0},
}

Acknowledgments

  • Mamba by Albert Gu and Tri Dao
  • MLX by Apple
  • VMamba for multi-directional scanning
  • Built with assistance from Claude (Anthropic)

License

Apache License 2.0. See LICENSE.
