mlx-diffusion-kit
Inference optimizations for diffusion and video super-resolution models on MLX / Apple Silicon. Training-free techniques that reduce compute by 2-5x without quality loss.
Current version: 0.2.1 — 23 optimization components, 284+ tests.
Foreword
This library was born from the same frustration that drove mlx-mfa: VSR inference on Apple Silicon is painfully slow. mlx-mfa tackles the attention kernel; mlx-diffusion-kit tackles everything else — step caching, token merging, cross-attention gating, VAE optimization, scheduling, and orchestration.
The two libraries are complementary:
```text
        mlx-mfa                         mlx-diffusion-kit
┌──────────────────────────┐        ┌──────────────────────────┐
│ Flash Attention kernels  │        │ Step-level caching       │
│ Sparse / GNA / Paged     │◄───────│ Token merging / pruning  │
│ KV cache management      │        │ Cross-attention gating   │
│ attn_bias (native Metal) │        │ VAE optimization         │
│ TurboQuant KV            │        │ Scheduling               │
│ SVDQuantLinear           │        │ Orchestrator (PISA)      │
└──────────────────────────┘        └──────────────────────────┘
        Kernel layer                    Optimization layer
```
Installation
```shell
pip install -e ".[dev]"
```

With mlx-mfa integration (for proportional attention via attn_bias):

```shell
pip install -e ".[mfa]"
```

Requirements: Python >= 3.10, MLX >= 0.25.0, Apple Silicon Mac.
Current Status
- 9 STABLE components — tested, integrated, production-ready API.
- 11 BETA components — functional and tested, API may evolve.
- 1 EXPERIMENTAL — functional, use with caution.
- 2 STUB — interface defined, implementation pending.
- 276+ tests pass, 0 failures.
- Primary validation hardware: Apple M1 Max.
Component Overview
| ID | Component | Section | Maturity | Applies to |
|---|---|---|---|---|
| B1 | TeaCache + WorldCache | Step Cache | Stable | 6 multi-step |
| B2 | First-Block Cache | Step Cache | Beta | 6 multi-step |
| B3 | SpectralCache | Step Cache | Beta | 6 multi-step |
| B4 | SmoothCache + Taylor | Step Cache | Stable | 6 multi-step |
| B5 | DeepCache (+ MosaicDiff layer-redundancy tool) | Step Cache | Beta | 5 UNet multi-step |
| B6 | Multi-Granular Cache | Step Cache | Beta | 6 multi-step |
| B7 | ToCa (Token Cache) | Tokens | Beta | multi-step DiT |
| B8 | ToMe + ToPi | Tokens | Stable / Beta | ALL 11 |
| B9 | DiffSparse | Tokens | Stub | DiT models |
| B10 | DDiT Scheduling | Tokens | Beta | multi-step DiT |
| B11 | T-GATE | Gating | Stable | 6 multi-step |
| B12 | DiTFastAttn (4 strategies) | Attention | Beta | multi-step DiT |
| B13 | FreeU | Quality | Stable | 5 UNet |
| B14 | DPM-Solver-v3 / Adaptive | Scheduler | Stable / Beta | 6 multi-step |
| B15 | Text Embedding Cache | Encoder | Stable | ALL 11 |
| B17 | WF-VAE Causal Cache | VAE | Stable | SeedVR2 + CogVideoX |
| B18 | Separable Conv3D + SVD utility | VAE | Beta | SeedVR2 VAE |
| B22 | Encoder Sharing | Cache | Beta | multi-step DiT |
| B23 | Orchestrator + PISA | Orchestrator | Stable | ALL 11 |
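To make the simplest component above concrete, the pattern behind a text-embedding cache like B15 can be sketched in a few lines of plain Python. This is an illustrative stand-in, not the library's implementation — `PromptEmbeddingCache` and its internals are hypothetical names:

```python
import hashlib

import numpy as np


class PromptEmbeddingCache:
    """Memoize encoder outputs keyed by (encoder_id, prompt).

    Hypothetical sketch of the caching pattern: identical prompts sent to
    the same encoder are computed once and reused on every later step.
    """

    def __init__(self):
        self._store = {}

    def get_or_compute(self, prompt, encoder, encoder_id):
        # Key on both the prompt and the encoder identity, so the same
        # prompt encoded by two different text encoders does not collide.
        key = hashlib.sha256(f"{encoder_id}\x00{prompt}".encode()).hexdigest()
        if key not in self._store:
            self._store[key] = encoder(prompt)
        return self._store[key]


# Toy encoder that records how often it is actually invoked.
calls = []

def toy_encoder(prompt):
    calls.append(prompt)
    return np.ones(4)

cache = PromptEmbeddingCache()
a = cache.get_or_compute("enhance 4x", toy_encoder, encoder_id="t5")
b = cache.get_or_compute("enhance 4x", toy_encoder, encoder_id="t5")
assert len(calls) == 1  # second lookup was served from the cache
```

Because video-restoration prompts are often fixed strings ("enhance 4x"), this trivial memoization removes the text encoder from the per-clip cost entirely.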
Quick Start
```python
import mlx_diffusion_kit as mdk

# 1. Cache text embeddings (all models)
emb_cache = mdk.TextEmbeddingCache()
embedding = emb_cache.get_or_compute("enhance 4x", my_t5_encoder, encoder_id="t5-xxl")

# 2. Step caching for multi-step models
from mlx_diffusion_kit.cache import TeaCacheConfig, load_coefficients
config = load_coefficients("cogvideox")  # Pre-calibrated

# 3. Token merging (all models)
merged, info = mdk.tome_merge(tokens, mdk.ToMeConfig(merge_ratio=0.5))
# ... run attention on merged tokens ...
output = mdk.tome_unmerge(merged_output, info)

# 4. Full orchestration
from mlx_diffusion_kit.orchestrator import DiffusionOptimizer, OrchestratorConfig

opt = DiffusionOptimizer(OrchestratorConfig(
    teacache=config,
    tome=mdk.ToMeConfig(merge_ratio=0.5),
    tgate=mdk.TGateConfig(gate_step=5),
    is_single_step=False,
))
```
See docs/API_MANUAL.md for complete API reference.
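To make the merge/unmerge round trip in step 3 concrete, here is a minimal NumPy sketch of ToMe-style bipartite token merging. It is a simplified stand-in, not the library's implementation: even-indexed tokens may merge into their most similar odd-indexed token, and unmerging scatters merged values back to the original positions.

```python
import numpy as np


def tome_merge(tokens, merge_ratio=0.5):
    """Simplified bipartite soft matching (ToMe-style), illustrative only."""
    n = tokens.shape[0]
    src_idx, dst_idx = np.arange(0, n, 2), np.arange(1, n, 2)
    src, dst = tokens[src_idx], tokens[dst_idx]

    # Cosine similarity from every source token to every destination token.
    sn = src / np.linalg.norm(src, axis=-1, keepdims=True)
    dn = dst / np.linalg.norm(dst, axis=-1, keepdims=True)
    sim = sn @ dn.T
    best = sim.argmax(axis=1)   # most similar destination per source
    score = sim.max(axis=1)

    # Merge the r most redundant sources into their matched destination.
    r = int(len(src_idx) * merge_ratio)
    merged_src = np.argsort(-score)[:r]
    kept_src = np.setdiff1d(np.arange(len(src_idx)), merged_src)

    out_dst = dst.copy()
    counts = np.ones(len(dst_idx))
    for s in merged_src:
        d = best[s]  # running mean of all tokens merged into destination d
        out_dst[d] = (out_dst[d] * counts[d] + src[s]) / (counts[d] + 1)
        counts[d] += 1

    merged = np.concatenate([src[kept_src], out_dst])
    info = {"n": n, "src_idx": src_idx, "dst_idx": dst_idx,
            "kept_src": kept_src, "merged_src": merged_src, "best": best}
    return merged, info


def tome_unmerge(merged, info):
    """Scatter merged tokens back to the original sequence length."""
    out = np.empty((info["n"], merged.shape[1]))
    k = len(info["kept_src"])
    out[info["src_idx"][info["kept_src"]]] = merged[:k]
    out[info["dst_idx"]] = merged[k:]
    # Merged-away sources inherit the value of the token they merged into.
    out[info["src_idx"][info["merged_src"]]] = merged[k + info["best"][info["merged_src"]]]
    return out
```

With `merge_ratio=0.5` half the source tokens disappear before attention, so attention cost drops roughly in proportion; at `merge_ratio=0.0` the round trip is exact.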
Target Models
Single-step (no inter-step caching)
| Model | Backbone | Key trait |
|---|---|---|
| SeedVR2 | DiT 48b | Production ref. DiT=22%, VAE=77% |
| DOVE | DiT CogVideoX1.5-5B | Single-step DiT |
| FlashVSR | DiT Wan2.1, LCSA | Sparse attention |
| DLoRAL | UNet SD, Dual-LoRA | ~1B params |
| UltraVSR | UNet SD + RTS | ~1B params |
Multi-step (step caching applicable)
| Model | Backbone | Steps |
|---|---|---|
| SparkVSR | DiT CogVideoX1.5-5B-I2V | ~20-30 |
| STAR | DiT CogVideoX-5B | Multi |
| Vivid-VR | DiT CogVideoX1.5-5B + CN | Multi |
| DAM-VSR | SVD UNet + CN | ~30 |
| DiffVSR | UNet SD | 20-50 |
| VEnhancer | ControlNet + ModelScope UNet | 15-50 |
Scripts
```shell
# Calibrate TeaCache coefficients for a new model
python scripts/calibrate_teacache.py --features-dir ./features/ --output coefficients.json

# Analyze layer redundancy for DeepCache
python scripts/analyze_layer_redundancy.py --weights model.npz --output scores.json
```
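For intuition about what TeaCache calibration produces, here is a hedged NumPy sketch of the underlying idea: fit a polynomial that rescales the relative change of the model's inputs between consecutive steps into a prediction of the relative change of its outputs, then skip a step whenever the accumulated predicted drift stays under a threshold. Function names and signatures are illustrative, not the script's actual interface.

```python
import numpy as np


def calibrate_rescale_poly(input_feats, output_feats, degree=4):
    """Fit relative input change -> relative output change across steps.

    input_feats / output_feats: per-step feature arrays captured from a
    reference run of the model (one array per denoising step).
    """
    xs, ys = [], []
    for t in range(1, len(input_feats)):
        di = (np.abs(input_feats[t] - input_feats[t - 1]).mean()
              / np.abs(input_feats[t - 1]).mean())
        do = (np.abs(output_feats[t] - output_feats[t - 1]).mean()
              / np.abs(output_feats[t - 1]).mean())
        xs.append(di)
        ys.append(do)
    return np.polyfit(xs, ys, degree)


def should_recompute(accum, coeffs, rel_input_change, threshold=0.1):
    """At inference: accumulate the rescaled drift estimate per step.

    Returns (recompute, new_accum): recompute the full model (and reset
    the accumulator) once predicted drift crosses the threshold,
    otherwise reuse the cached output and carry the accumulator forward.
    """
    accum += np.polyval(coeffs, rel_input_change)
    if accum >= threshold:
        return True, 0.0
    return False, accum
```

The threshold is the quality/speed knob: a larger value reuses cached outputs for more consecutive steps at the cost of more accumulated approximation error.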
Documentation
- docs/API_MANUAL.md — Complete API reference for all 89 exports
- docs/ARCHITECTURE.md — Module structure and design principles
- CHANGELOG.md — Version history
License
MIT — see LICENSE for details.
File details
Details for the file mlx_diffusion_kit-0.2.1.tar.gz.
File metadata
- Download URL: mlx_diffusion_kit-0.2.1.tar.gz
- Upload date:
- Size: 132.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 339a47350f000b284d860a6315ec982739661a8db0a999fe0a8f9e6d878bc714 |
| MD5 | cc8c8c82c9bd3df569cf73ba004b64b5 |
| BLAKE2b-256 | 4643ab531b25128bdbcd6edcb9887b1b36a56cff1b5838941e44219791510b37 |
File details
Details for the file mlx_diffusion_kit-0.2.1-py3-none-any.whl.
File metadata
- Download URL: mlx_diffusion_kit-0.2.1-py3-none-any.whl
- Upload date:
- Size: 72.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | ec67ea84811583182d8fd575483953937ee1fb43c436ed056c600b49384d69de |
| MD5 | 6b11e54fdfc9af64cb07c11b5711bf00 |
| BLAKE2b-256 | 1663357a87bd34a0c173602cff5ed8e84926fbc806fa62938ee608ec87b9e111 |