TeaCache step-skipping for FLUX diffusion on Apple Silicon, in pure MLX

mlx-teacache

mlx-teacache is the first MLX port of TeaCache, a training-free inference optimization that skips ~20-50% of denoising steps in FLUX-family diffusion models by predicting which steps contribute little to the final image. On FLUX.1-dev at 25 steps, the default threshold gives a measured ~1.48× wall-clock speedup with visually equivalent output (SSIM ≥ 0.80 on a 5-prompt suite).

What it does

Diffusion models run the same large transformer 20-50 times in a loop. Between consecutive steps the output changes very little, so TeaCache uses a small polynomial fit to predict which steps can reuse the previous step's output. On an M1 Max, FLUX.1-dev at 25 steps with the default threshold (rel_l1_thresh=0.20) skips 6 of 25 steps for a ~1.48× speedup.

from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
from mflux.models.common.config.model_config import ModelConfig
from mlx_teacache import apply_teacache

flux = Flux2Klein(quantize=4, model_config=ModelConfig.flux2_klein_4b())
with apply_teacache(flux):  # default rel_l1_thresh=0.20
    flux.generate_image(prompt="a red apple", seed=42, num_inference_steps=25)

Install

pip install "mlx-teacache[mflux]"
# or with uv:
uv add "mlx-teacache[mflux]"

Requires Python ≥ 3.11 and Apple Silicon. The [mflux] extra pulls in mflux>=0.17,<0.18.

pip install "mlx-teacache[mflux]==0.1.0"  # pin for reproducibility

Quick start

FLUX.1 dev:

from mflux.models.flux.variants.txt2img.flux import Flux1
from mlx_teacache import apply_teacache

flux = Flux1.from_name("dev", quantize=4)
with apply_teacache(flux) as handle:  # default rel_l1_thresh=0.20
    flux.generate_image(prompt="...", seed=42, num_inference_steps=25, guidance=3.5)
    print(f"Speedup: {handle.stats.speedup_estimate:.2f}×")

Threshold guide

Measured on M1 Max 32GB, FLUX.1-dev @ 25 steps, bf16, seed=42, guidance=3.5, red-apple prompt:

| rel_l1_thresh | Skipped steps | Speedup | SSIM vs vanilla | Recommended use |
|---|---|---|---|---|
| 0.10 | 0 / 25 | 1.07× | 1.0000 | Cache never engages |
| 0.15 | 0 / 25 | 1.13× | 1.0000 | Cache never engages |
| 0.20 (default) | 6 / 25 | 1.48× | ≥ 0.80 (5-prompt suite) | Visually-lossless sweet spot |
| 0.25 | 11 / 25 | 1.96× | 0.57-0.93 | Visible style changes on text/synthetic prompts |

The 0.20 default was chosen after side-by-side visual comparison: at 0.25 a text prompt that vanilla renders as neon tubes can come out as dot-matrix; at 0.20 the output is indistinguishable from vanilla while still skipping ~25% of steps. SSIM is a conservative metric on high-frequency-detail prompts (text, synthetic patterns).

Supported models

| Variant id | mflux class + config | Coefficient source |
|---|---|---|
| flux1-dev | Flux1(model_config=ModelConfig.dev()) | upstream ali-vilab/TeaCache |
| flux1-schnell | Flux1(model_config=ModelConfig.schnell()) | upstream (shared with dev) |
| flux2-klein-4b | Flux2Klein(model_config=ModelConfig.flux2_klein_4b()) | in-repo (see docs/calibration.md) |

Combining with mlx-taef

from mlx_taef.integrations.mflux import LivePreviewCallback
from mlx_teacache import apply_teacache

preview = LivePreviewCallback(variant="taef2", every=5, save_to="preview.png",
                              latent_height=32, latent_width=32)
flux.callbacks.register(preview)

with apply_teacache(flux):  # default rel_l1_thresh=0.20
    flux.generate_image(prompt="...", seed=42, num_inference_steps=25)

Inspecting stats

# `flux` is an mflux model constructed as in Quick start.
handle = apply_teacache(flux)
flux.generate_image(prompt="apple", seed=42, num_inference_steps=25)
print(handle.stats.computed_count, handle.stats.skipped_count)
print(handle.stats.last_generation.decisions[5])  # per-step record
handle.restore()  # remove the patch when finished

Custom coefficients

custom = [...]  # length 5, all finite
apply_teacache(flux, coefficients=custom)
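
The comment above states two requirements for custom coefficients: length 5 and all values finite. A minimal pre-flight check might look like the sketch below. The helper name `validate_coefficients` is hypothetical and not part of mlx-teacache; the library may perform its own validation and raise different errors.

```python
import math
from collections.abc import Sequence


def validate_coefficients(coefficients: Sequence[float]) -> list[float]:
    """Hypothetical helper: check the two documented constraints.

    Returns the coefficients as a list of floats, or raises ValueError.
    """
    values = [float(c) for c in coefficients]
    if len(values) != 5:
        raise ValueError(f"expected 5 coefficients, got {len(values)}")
    bad = [v for v in values if not math.isfinite(v)]
    if bad:
        raise ValueError(f"coefficients must be finite, got {bad}")
    return values


# Placeholder values for illustration only, not calibrated coefficients.
checked = validate_coefficients([0, 0, 0, 1, 0])
print(checked)
```

Running a check like this before calling apply_teacache keeps a NaN or a wrong-length list from surfacing later as a confusing gating decision.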

How it works

TeaCache observes that, during denoising, the transformer output changes very little between most pairs of adjacent steps. The expensive transformer body (all the joint and single attention blocks) produces a residual that is added to the input, and that residual stays roughly stable for stretches of the denoising trajectory.

TeaCache fits a small polynomial that predicts how much the output will change given how much the input has changed (measured as the relative L1 distance of the modulated block-0 input). When the predicted change accumulated since the last actual compute step is below the threshold, TeaCache reuses the cached residual instead of running the full transformer body. Only the cheap prelude (embeddings) and tail (norm + projection) still run.
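
The gating rule described above can be sketched in plain Python. This is an illustration under stated assumptions, not the mlx-teacache implementation: the class name `TeaCacheGate`, the flat-list inputs, the identity polynomial, and the choice to force the first and last steps to compute are all assumptions made for the example.

```python
from dataclasses import dataclass, field
from typing import List, Optional


def rel_l1(current, previous):
    """Relative L1 distance: sum|cur - prev| / sum|prev|."""
    num = sum(abs(c - p) for c, p in zip(current, previous))
    den = sum(abs(p) for p in previous)
    return num / den if den else float("inf")


def polyval(coefficients, x):
    """Evaluate a polynomial, highest-order coefficient first (Horner's rule)."""
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result


@dataclass
class TeaCacheGate:
    """Hypothetical gate: decides per step whether to run the transformer body."""
    coefficients: List[float]
    rel_l1_thresh: float = 0.20
    accumulated: float = 0.0
    previous: Optional[List[float]] = None
    decisions: List[bool] = field(default_factory=list)

    def should_compute(self, modulated_input, step, num_steps):
        forced = step == 0 or step == num_steps - 1  # assumption: ends always compute
        if forced or self.previous is None:
            compute = True
        else:
            # Accumulate the predicted output change since the last compute step.
            self.accumulated += polyval(
                self.coefficients, rel_l1(modulated_input, self.previous)
            )
            compute = self.accumulated >= self.rel_l1_thresh
        if compute:
            self.accumulated = 0.0  # a real compute resets the running estimate
        self.previous = list(modulated_input)
        self.decisions.append(compute)
        return compute


# Identity polynomial (p(x) = x), for illustration only, not calibrated.
identity = [0.0, 0.0, 0.0, 1.0, 0.0]
gate = TeaCacheGate(coefficients=identity, rel_l1_thresh=0.20)
inputs = [[1.0, 1.0], [1.1, 1.0], [1.2, 1.0], [1.5, 1.0], [1.5, 1.0]]
decisions = [gate.should_compute(x, step, len(inputs)) for step, x in enumerate(inputs)]
print(decisions)  # [True, False, False, True, True]
```

In the toy run, steps 1 and 2 accumulate about 0.05 and 0.10 and are skipped; step 3 pushes the total past 0.20, so the body runs and the accumulator resets.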

mlx-teacache implements this for mflux on Apple Silicon. For FLUX.1 it replaces flux.transformer with a per-instance proxy; for FLUX.2 it replaces flux._predict with an instance-level closure, so that gating stays live even where mflux would normally wrap that function in mx.compile. The polynomial coefficients for FLUX.1 are vendored from upstream; the FLUX.2 Klein 4b coefficients are derived in-repo via scripts/calibrate.py. See the TeaCache paper at https://liewfeng.github.io/TeaCache/ and docs/superpowers/spikes/2026-05-14-mlx-teacache-phase-0-spike.md for the Apple-Silicon-specific design notes.
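
The per-instance patching idea can be illustrated without mflux. The sketch below is a generic pattern, not the mlx-teacache source: `PatchHandle` and `ToyModel` are hypothetical names, and the real handle also carries stats and model-specific logic.

```python
class PatchHandle:
    """Hypothetical handle: swaps one attribute on ONE object, restorable."""

    def __init__(self, target, name, replacement):
        self._target = target
        self._name = name
        # Remember whether the attribute lived on the instance or on the class.
        self._had_instance_attr = name in vars(target)
        self._original = getattr(target, name)
        self._active = True
        setattr(target, name, replacement)

    def restore(self):
        if not self._active:
            return
        if self._had_instance_attr:
            setattr(self._target, self._name, self._original)
        else:
            # Drop the instance-level override so the class method shows through.
            delattr(self._target, self._name)
        self._active = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.restore()
        return False


class ToyModel:
    def __init__(self):
        self.transformer = lambda x: x * 2  # stands in for the expensive body

    def _predict(self, x):
        return self.transformer(x) + 1


model, other = ToyModel(), ToyModel()
original_transformer = model.transformer
calls = []


def proxy(x):
    calls.append(x)  # a real proxy would decide here whether to skip
    return original_transformer(x)


with PatchHandle(model, "transformer", proxy):
    inside = model._predict(3)
    other_patched = other.transformer is proxy
restored = model.transformer is original_transformer

handle = PatchHandle(model, "_predict", lambda x: -1)  # instance-level closure
shadowed = model._predict(3)
handle.restore()
after = model._predict(3)
```

Patching the instance rather than the class means a second model object in the same process is left untouched, and restoring is a matter of putting one attribute back.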

Benchmarks

M1 Max 32GB, default threshold (rel_l1_thresh=0.20):

| Model | Steps | Vanilla | With TeaCache | Speedup | SSIM (PR-gate prompt) |
|---|---|---|---|---|---|
| FLUX.1-dev @ 512² | 25 | ~5:18 | ~3:35 | 1.48× | ≥ 0.90 |
| FLUX.2 Klein 4b @ 512² | 8 | ~37s | ~30s | ~1.2× | ≥ 0.85 |

FLUX.2's smaller speedup at 8 steps reflects the short non-distilled denoising trajectory at Klein's recommended step count. Longer schedules (num_inference_steps>=15) give larger gains.
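
The speedup column is the ratio of the two wall-clock times. The short sketch below reproduces it from the approximate timings in the table; the helper names are illustrative only.

```python
def to_seconds(clock: str) -> int:
    """Convert 'M:SS' or 'Ns' (optionally prefixed with '~') to seconds."""
    clock = clock.strip().lstrip("~")
    if clock.endswith("s"):
        return int(clock[:-1])
    minutes, seconds = clock.split(":")
    return int(minutes) * 60 + int(seconds)


def speedup(vanilla: str, cached: str) -> float:
    """Wall-clock speedup = vanilla time / cached time."""
    return to_seconds(vanilla) / to_seconds(cached)


print(round(speedup("~5:18", "~3:35"), 2))  # 318 s / 215 s -> 1.48
print(round(speedup("~37s", "~30s"), 2))    # 37 s / 30 s -> 1.23
```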

Limitations

  • v0.1 is txt2img only. img2img inputs raise Img2ImgNotSupportedError. v0.2 planned.
  • FLUX.2 CFG (guidance > 1.0) auto-falls-back to vanilla mflux. Output is bit-exact. v0.2 will add per-branch caching.
  • Distilled schedules see no speedup. FLUX.1 schnell 4-step and Klein 4-step defaults have too few non-forced steps. Use num_inference_steps >= 10 for benefit.
  • Klein variants other than flux2-klein-4b are not supported in v0.1. 9b and base configs planned for v0.2.
  • M3+ users lose mflux's mx.compile of _predict. v0.1 provides a manual benchmark recipe (docs/m3-plus-tradeoff.md) and does not claim a measured M3+ speedup.
  • FLUX.2 parity is numerical, not bit-exact. Because the wrapper replaces a function mflux wraps in mx.compile, vanilla-compiled vs wrapper-eager differ by ~1 ULP per element from Metal kernel-dispatch noise (compounds across steps; cosine similarity stays ≥ 0.99). The CFG-fallback path remains bit-exact. End-to-end image quality (SSIM ≥ 0.85 on Klein 4b) is the user-facing guarantee.
  • mflux pin is strict (>=0.17,<0.18). Bumping requires a deliberate release.
  • Parent-level flux.parameters() may miss transformer parameters while patched. Use flux.transformer.parameters() directly, or handle.restore() first.
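
The FLUX.2 parity bullet above describes a cosine-similarity floor of 0.99 between vanilla and wrapped outputs. A check of that kind can be sketched in plain Python as follows. This is not the project's test code: the vectors are made up, and float64 with math.nextafter stands in for the roughly 1-ULP per-element differences mentioned above.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = math.fsum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(math.fsum(x * x for x in a))
    norm_b = math.sqrt(math.fsum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Made-up reference output and a copy nudged by one float64 ULP per element.
reference = [0.25, -1.5, 3.0, 0.125]
perturbed = [math.nextafter(x, math.inf) for x in reference]

similarity = cosine_similarity(reference, perturbed)
parity_ok = similarity >= 0.99  # the floor quoted in the limitation above
print(parity_ok)
```

A numerical check like this tolerates tiny element-wise differences that an exact equality comparison would reject.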

Contributing

Open an issue at https://github.com/IonDen/mlx-teacache/issues.

License + acknowledgements

Apache-2.0. See LICENSE and NOTICE.
