ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-12% on LFM2-8B-A1B), custom GPU kernels in one line, 70+ kernel catalog.
ZMLX — Faster MoE inference on Apple Silicon
ZMLX patches MLX models with fused Metal kernels for faster Mixture-of-Experts decode on Apple Silicon. Supported models need only stock MLX: no model conversion and no config changes. Install ZMLX and call `patch(model)`.
Quick start (model patching; requires zmlx[train] or mlx-lm):
```python
import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/LFM2-8B-A1B-4bit")
patch(model)

text = mlx_lm.generate(model, tokenizer,
                       prompt="Explain mixture-of-experts in one paragraph.",
                       max_tokens=200)
```
Install (recommended)
```bash
pip install "zmlx[train]"
```
That includes mlx-lm and everything needed for model patching. For kernel authoring only:
```bash
pip install zmlx
```
Supported mode
| Mode | What you need | What you get |
|---|---|---|
| Stable (default) | Stock MLX + zmlx | Token-identical output, real decode speedups |
patch() auto-detects which patterns are safe for your model family.
Benchmarks — Stable mode (stock MLX)
LFM2-8B-A1B: +5-12% decode, token-identical, measured on M1 Pro and M4 Max. Prefill is neutral by design: fused kernels activate only when the sequence length M <= 32.
LFM2-8B-A1B on M1 Pro 16 GB
macOS 14.6.1 · MLX 0.30.4 · ZMLX 0.7.12 · Python 3.10.0 · commit 7de879e
Repro capsule: `benchmarks/repro_capsules/lfm2_m1pro_20260131.json` · Print report: `python -m zmlx.bench.report <capsule.json>`
4-bit (mlx-community/LFM2-8B-A1B-4bit)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 105.5 tok/s | 115.3 tok/s | +9.3% |
| Prefill | 225.4 tok/s | 227.1 tok/s | +0.8% (neutral) |
| Fidelity | — | 430/430 | token-identical |
| Peak memory | — | 5.3 GB | |
8-bit (mlx-community/LFM2-8B-A1B-8bit-MLX)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 72.8 tok/s | 76.4 tok/s | +5.0% |
| Prefill | 180.5 tok/s | 182.8 tok/s | +1.3% (neutral) |
| Fidelity | — | 500/500 | token-identical |
| Peak memory | — | 9.5 GB | |
LFM2-8B-A1B on M4 Max 36 GB
macOS 26.1 · MLX 0.30.1 · ZMLX 0.7.12 · Python 3.12 · commit 139993e
Repro capsule: `benchmarks/repro_capsules/lfm2_m4max_20260131.json` · Print report: `python -m zmlx.bench.report <capsule.json>`
4-bit (mlx-community/LFM2-8B-A1B-4bit)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 223.7 tok/s | 250.3 tok/s | +11.9% |
| Prefill | 737.4 tok/s | 755.4 tok/s | +2.4% (neutral) |
| Fidelity | — | 430/430 | token-identical |
| Peak memory | — | 5.30 GB | |
8-bit (mlx-community/LFM2-8B-A1B-8bit-MLX)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 152.5 tok/s | 164.3 tok/s | +7.7% |
| Prefill | 557.6 tok/s | 564.4 tok/s | +1.2% (neutral) |
| Fidelity | — | 500/500 | token-identical |
| Peak memory | — | 9.45 GB | |
Full methodology and raw data: docs/BENCHMARKS.md.
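The Change column is the relative throughput difference, (patched / baseline - 1) × 100. A minimal sketch that recomputes it from the tok/s values printed in the tables above. Because those inputs are rounded to 0.1 tok/s, a recomputed figure can differ from the table by about 0.1 percentage point (the M1 Pro 8-bit decode row recomputes to +4.9% rather than +5.0%), presumably because the published column was computed from unrounded measurements.

```python
def percent_change(baseline: float, patched: float) -> float:
    """Relative throughput change in percent (positive means patched is faster)."""
    return (patched / baseline - 1.0) * 100.0

# Decode rows from the LFM2-8B-A1B tables (tok/s, as published, rounded to 0.1).
decode_rows = {
    "M1 Pro 4-bit": (105.5, 115.3),
    "M1 Pro 8-bit": (72.8, 76.4),
    "M4 Max 4-bit": (223.7, 250.3),
    "M4 Max 8-bit": (152.5, 164.3),
}

for name, (base, patched) in decode_rows.items():
    print(f"{name}: {percent_change(base, patched):+.1f}%")
```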
How It Works
The problem: dispatch overhead in MoE decode
In Mixture-of-Experts models, each token is routed to a subset of expert networks. During decode (generating one token at a time), the computation per expert is small: a few matrix multiplies on a single row vector. But the standard inference path dispatches many small Metal kernels in every MoE layer:
- Gating: `softmax(logits)` → `argpartition` → `gather` → normalize: 4 dispatches
- Expert execution: gate projection, up projection, SwiGLU activation, down projection, for each selected expert
- Combine: element-wise multiply by gating weights → reduce-sum across experts: 2 dispatches
On Apple Silicon, each Metal kernel dispatch has fixed overhead (command buffer encoding, GPU scheduling). When the actual compute per dispatch is small — as it is for M=1 decode — this overhead dominates. The GPU spends more time waiting between kernels than doing math.
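For reference, the unfused gating sequence can be written step by step. This sketch uses NumPy only so that it runs anywhere; MLX offers analogous ops (for example `mx.softmax` and `mx.argpartition`). Each numbered step stands for one of the four gating dispatches listed above. It illustrates the math, not ZMLX's code.

```python
import numpy as np

def reference_topk_gating(logits: np.ndarray, k: int):
    """Unfused top-k gating, one op per step.

    logits: (rows, n_experts). Returns (weights, indices), each (rows, k).
    """
    # 1. softmax over experts (subtract the row max for numerical stability)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    # 2. argpartition: indices of the k largest probabilities (unordered)
    idx = np.argpartition(-probs, k - 1, axis=-1)[:, :k]
    # 3. gather the selected probabilities
    w = np.take_along_axis(probs, idx, axis=-1)
    # 4. renormalize so the k selected weights sum to 1
    w = w / w.sum(axis=-1, keepdims=True)
    return w, idx

# One decode row (M=1) routed over 8 experts, keeping the top 2.
logits = np.random.default_rng(0).standard_normal((1, 8)).astype(np.float32)
weights, experts = reference_topk_gating(logits, k=2)
```

At M=1 each of these steps touches only 8 numbers, which is why per-dispatch overhead, rather than arithmetic, dominates.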
What ZMLX fuses
ZMLX replaces the multi-dispatch sequences with single Metal kernels that do the same math in one pass. All fused kernels are generated from Python via mx.fast.metal_kernel — no changes to MLX core required.
Fused top-k gating softmax (topk_gating_softmax):
Replaces the 4-dispatch gating sequence with a single kernel. For small expert counts (D <= 32, where D is the number of experts; common in MoE), the kernel uses SIMD-group operations: each row is processed by one SIMD group (32 threads), using simd_max and simd_sum for the softmax reduction and a register-based insertion sort for top-k selection. For larger D, it uses a threadgroup reduction in shared memory instead. Either way, the kernel computes softmax probabilities and selects the top-k experts with their normalized weights in one pass.
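The single-pass structure can be emulated sequentially for one row. This is a hypothetical Python sketch, not the Metal source: on the GPU, the max and the sum are `simd_max`/`simd_sum` reductions across the 32 lanes, whereas here they are plain loops. The top-k list plays the role of the per-thread registers. Because renormalizing cancels the shared softmax denominator, the final weights equal a softmax over only the selected logits.

```python
import math

def fused_topk_gating_row(row: list[float], k: int) -> tuple[list[float], list[int]]:
    """Softmax + top-k + renormalize for one gating row, in one logical pass."""
    # Reduction 1: row max (simd_max on the GPU)
    m = max(row)
    # Reduction 2: softmax denominator (simd_sum on the GPU)
    denom = sum(math.exp(x - m) for x in row)
    # Insertion sort that keeps the k best (prob, index) pairs, descending
    top: list[tuple[float, int]] = []
    for i, x in enumerate(row):
        p = math.exp(x - m) / denom
        if len(top) < k or p > top[-1][0]:
            if len(top) == k:
                top.pop()  # evict the current smallest entry
            top.append((p, i))
            j = len(top) - 1
            # shift the new entry left until the list is sorted descending
            while j > 0 and top[j - 1][0] < top[j][0]:
                top[j - 1], top[j] = top[j], top[j - 1]
                j -= 1
    # Renormalize the selected probabilities so they sum to 1
    s = sum(p for p, _ in top)
    return [p / s for p, _ in top], [i for _, i in top]
```

For example, `fused_topk_gating_row([0.1, 2.0, -1.0, 1.5], 2)` selects experts 1 and 3, with weights equal to a two-way softmax over the logits 2.0 and 1.5.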
Fused expert combine (moe_combine):
Replaces the separate element-wise multiply and reduce-sum with a single kernel that reads each expert output once, multiplies by its gating weight, and accumulates the weighted sum in float32. Output shape goes directly from (B, K, D) to (B, D) without materializing the intermediate weights * expert_outputs tensor.
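A NumPy reference for the combine step, useful for checking the fused output. It streams over the K experts with a float32 accumulator and never builds the (B, K, D) product tensor, matching the description above. It is an illustration, not ZMLX's kernel source.

```python
import numpy as np

def moe_combine_reference(expert_out: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Weighted sum over K experts: (B, K, D) outputs, (B, K) weights -> (B, D)."""
    B, K, D = expert_out.shape
    acc = np.zeros((B, D), dtype=np.float32)  # float32 accumulator
    for k in range(K):  # read each expert's output once
        acc += weights[:, k:k + 1].astype(np.float32) * expert_out[:, k, :].astype(np.float32)
    return acc.astype(expert_out.dtype)  # cast back to the input dtype at the end

rng = np.random.default_rng(0)
y = rng.standard_normal((1, 4, 16)).astype(np.float16)  # B=1, K=4 experts, D=16
w = rng.random((1, 4)).astype(np.float16)
out = moe_combine_reference(y, w)
```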
Why prefill is unaffected
All fused kernels are guarded with a sequence length check (M <= 32). During prefill, M equals the prompt length (typically hundreds or thousands of tokens). At this scale, the compute-to-dispatch ratio is high and the standard MLX path is already efficient. The guards ensure ZMLX never regresses prefill performance.
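The guard amounts to a shape check in front of two implementations. A hypothetical sketch: the names `use_fused` and `guarded` are illustrative, not ZMLX's API, and inputs are assumed to be shaped (batch, seq_len, hidden).

```python
from typing import Any, Callable

FUSED_MAX_SEQ = 32  # threshold from the text: fused kernels run only when M <= 32

def use_fused(seq_len: int, max_seq: int = FUSED_MAX_SEQ) -> bool:
    return seq_len <= max_seq

def guarded(fused: Callable[[Any], Any], reference: Callable[[Any], Any],
            max_seq: int = FUSED_MAX_SEQ) -> Callable[[Any], Any]:
    """Pick the fused path for short sequences (decode), the stock path otherwise."""
    def run(x: Any) -> Any:
        return fused(x) if use_fused(x.shape[1], max_seq) else reference(x)
    return run

class _FakeArray:  # stand-in with a .shape so the sketch runs without MLX
    def __init__(self, shape):
        self.shape = shape

op = guarded(lambda x: "fused", lambda x: "stock")
decode_path = op(_FakeArray((1, 1, 2048)))     # M=1: decode step
prefill_path = op(_FakeArray((1, 700, 2048)))  # M=700: prompt prefill
```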
Correctness guarantee
Token fidelity is a first-class requirement. patch() auto-detects the model family and excludes patterns with known fidelity issues. The fused gating kernel reproduces the exact same top-k selection and softmax normalization as the reference MLX ops. The combine kernel accumulates in float32 (or dtype-matched for Qwen3's moe_combine_exact). python -m zmlx.validate compares every generated token ID between patched and unpatched models under greedy decoding.
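The check `python -m zmlx.validate` performs can be pictured as greedy-decoding with both models and comparing token IDs position by position. A hypothetical sketch of that comparison: the helper names are illustrative, indices are zero-based, and this is not ZMLX's implementation.

```python
from typing import Optional, Sequence

def first_divergence(reference: Sequence[int], patched: Sequence[int]) -> Optional[int]:
    """Index of the first differing token ID, or None if the sequences match."""
    for i, (a, b) in enumerate(zip(reference, patched)):
        if a != b:
            return i
    if len(reference) != len(patched):
        return min(len(reference), len(patched))  # one run stopped early
    return None

def fidelity_report(reference: Sequence[int], patched: Sequence[int]) -> str:
    idx = first_divergence(reference, patched)
    n = len(reference)
    return f"{n}/{n} token-identical" if idx is None else f"diverges at token {idx}"
```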
Patching options
```python
import mlx.core as mx
from zmlx.patch import patch, smart_patch

# `model` and `tokenizer` come from mlx_lm.load(...) as in the quick start.
# Alternatives (pick one):
patch(model)                          # auto-detect, apply safe defaults
patch(model, patterns=["moe_mlp"])    # force specific pattern (overrides safety)
patch(model, mode="training")         # add norm fusions for backward pass

# Auto-benchmark: apply only patterns that actually help
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
```
Model Support
Stable
Token-identical output, measurable decode improvement. Safe to use without further validation.
| Model | Decode speedup | Fidelity | Patterns |
|---|---|---|---|
| LFM2-8B-A1B-4bit | +9-12% | token-identical | moe_mlp + swiglu_mlp |
| LFM2-8B-A1B-8bit | +5-8% | token-identical | moe_mlp + swiglu_mlp |
| Qwen3-30B-A3B-4bit | +7% | token-identical | moe_mlp |
| GPT-OSS-20B-MXFP4-Q4 | +1% | token-identical | moe_mlp |
Tested (no gain)
| Model | Status | Notes |
|---|---|---|
| Nemotron-3-Nano-30B-A3B-NVFP4 | 0.999x, PASS | Hybrid Mamba-MoE, bandwidth-limited at 19.4 GB |
| LFM2.5-1.2B-Thinking-MLX-8bit | 0.997x, PASS | Dense model, no matched MoE patterns |
| Qwen3-4B-4bit (dense) | diverges at token 18 | Dense model, patches not expected to help |
| Llama-3.2-1B-4bit | 0.98x, PASS | Dense model, bandwidth-bound |
For unlisted models: python -m zmlx.validate <model>.
Toolkit
ZMLX is also a Metal kernel authoring toolkit for MLX:
- 70+ kernel catalog: SwiGLU, GeGLU, fused dropout, MoE gating, RMSNorm, RoPE, quantization
- One-line kernel authoring: `elementwise("x * tanh(log(1 + exp(x)))")` compiles to Metal
- Automatic gradients: custom VJP backward passes as Metal kernels via `mx.custom_function`
- Benchmarking: `zmlx.bench.compare()` for side-by-side timing, `zmlx.bench.report` for repro capsules
```python
from zmlx.api import elementwise
import mlx.core as mx

mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
```
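To spot-check a custom kernel, compare it against a plain reference. A NumPy sketch of Mish, assuming you convert the MLX result with `np.array(y)` first. `np.logaddexp(0, x)` computes softplus, log(1 + exp(x)), without overflowing `exp` for large x.

```python
import numpy as np

def mish_reference(x: np.ndarray) -> np.ndarray:
    """Mish(x) = x * tanh(softplus(x)), computed in float32, returned in x's dtype."""
    x32 = x.astype(np.float32)
    softplus = np.logaddexp(np.float32(0.0), x32)  # log(1 + exp(x)), overflow-safe
    return (x32 * np.tanh(softplus)).astype(x.dtype)
```

Usage: `np.allclose(np.array(y), mish_reference(np.array(x)), atol=1e-5)`, where `x` is the input you passed to `mish`.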
Op-level microbenchmarks
B=16, S=1024, D=1024, float16, on M4 Max. Reproduce with `python benchmarks/microbench.py`:
| Operation | MLX | ZMLX | Speedup |
|---|---|---|---|
| SwiGLU | 0.87 ms | 0.43 ms | 2.0x |
| Dropout | 3.08 ms | 0.41 ms | 7.5x |
| Top-K | 1.81 ms | 0.49 ms | 3.7x |
| Gather-Add | 0.55 ms | 0.42 ms | 1.3x |
| Softmax | 0.45 ms | 0.44 ms | ~1.0x |
| RMSNorm | 0.51 ms | 0.54 ms | 0.95x |
ZMLX helps most for fused operations that MLX doesn't provide as single ops. MLX built-ins (mx.fast.rms_norm, mx.softmax) are already highly optimized.
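SwiGLU shows why fusion helps. Computed with separate ops, it typically runs one kernel per element-wise step, and each step reads and writes a full tensor. A fused kernel does the whole expression per element in one pass. A NumPy sketch of the two shapes of computation; the `(gate, up)` argument convention is an assumption here, not the documented signature of ZMLX's `swiglu`.

```python
import numpy as np

def swiglu_unfused(gate: np.ndarray, up: np.ndarray) -> np.ndarray:
    # Separate element-wise steps, each producing a full intermediate tensor
    s = 1.0 / (1.0 + np.exp(-gate))  # sigmoid
    silu = gate * s                  # SiLU
    return silu * up                 # gating multiply

def swiglu_fused(gate: np.ndarray, up: np.ndarray) -> np.ndarray:
    # Per-element view of a fused kernel: out = gate * sigmoid(gate) * up,
    # computed in one pass with no intermediate tensors
    out = np.empty_like(gate)
    for i in np.ndindex(gate.shape):
        g = float(gate[i])
        out[i] = g / (1.0 + np.exp(-g)) * float(up[i])
    return out

rng = np.random.default_rng(0)
gate = rng.standard_normal((2, 3, 8)).astype(np.float32)
up = rng.standard_normal((2, 3, 8)).astype(np.float32)
```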
Kernel catalog
70+ kernels organized by domain. Full reference: docs/KERNELS.md.
| Module | Highlights |
|---|---|
| `moe` | `topk_gating_softmax`, `moe_dispatch`, `moe_combine`: fused expert routing |
| `transformer` | `swiglu`, `geglu`, `rmsnorm_residual`, `dropout`: genuine fusions |
| `loss` | `softmax_cross_entropy`: memory-efficient fused loss |
| `bits` | `pack_bits`, `unpack_bits`: no MLX equivalent |
| `quant` | FP8, NF4, int8, int4 dequantization |
| `norms` | `rmsnorm`, `layernorm`: float32 internal compute |
| `rope` | `apply_rope`, `apply_rope_interleaved`, `apply_gqa_rope` |
| `optimizers` | `adamw_step`: fused parameter update |
Install (full)
Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0
```bash
pip install "zmlx[train]"
```
Kernel authoring only:
```bash
pip install zmlx
```
From source:
```bash
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
```
Quick Start
```python
from zmlx.api import elementwise
import mlx.core as mx

# Non-differentiable
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable with custom VJP
from zmlx import msl

silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
```
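The `grad_expr` above encodes the SiLU derivative. With s = sigmoid(x), d/dx[x * s] = s + x * s' and s' = s * (1 - s), giving s + x * s * (1 - s), scaled by the upstream gradient g. A plain-Python finite-difference check of that formula, independent of ZMLX and Metal (g = 1 here):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def silu(x: float) -> float:
    return x * sigmoid(x)

def silu_grad(x: float, g: float = 1.0) -> float:
    s = sigmoid(x)                      # matches grad_prelude: T s = kk_sigmoid(x);
    return g * (s + x * s * (1.0 - s))  # matches grad_expr

def central_diff(f, x: float, h: float = 1e-5) -> float:
    """Numerical derivative (f(x+h) - f(x-h)) / 2h."""
    return (f(x + h) - f(x - h)) / (2.0 * h)
```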
More examples: reductions, softmax, testing
Custom reduction
```python
from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)
```
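Assuming `reduce` folds over the last axis (consistent with the (8, 1024) → (8,) shape above), the `init` and `update` strings describe an ordinary fold: start with `acc = init`, then apply `acc = update(acc, v)` for each element `v`. A sequential Python emulation of those per-row semantics; `row_reduce` is a hypothetical name, and how ZMLX splits the fold across GPU threads is not shown.

```python
from typing import Callable, Sequence

def row_reduce(rows: Sequence[Sequence[float]], init: float,
               update: Callable[[float, float], float]) -> list[float]:
    """Reduce each row to one value: acc = init; acc = update(acc, v) for every v."""
    out = []
    for row in rows:
        acc = init
        for v in row:
            acc = update(acc, v)
        out.append(acc)
    return out

sums = row_reduce([[1.0, 2.0, 3.0], [4.0, 5.0]], init=0.0, update=lambda acc, v: acc + v)
maxes = row_reduce([[1.0, 7.0, 3.0]], init=float("-inf"), update=max)
```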
Two-pass map-reduce (softmax pattern)
```python
from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
```
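Read sequentially, the three parts of that spec form a numerically stable softmax. `s1` is the row max from pass 1, `s2` is the sum of `exp(x - s1)` from pass 2, and `write` emits `exp(x - s1) / s2`. The `reduce` entries presumably combine partial results across threads. A plain-Python emulation for one row (not ZMLX code):

```python
import math

def two_pass_softmax(row: list[float]) -> list[float]:
    # pass1: s1 = max over the row (init -INFINITY, update max(acc1, x))
    s1 = -math.inf
    for x in row:
        s1 = max(s1, x)
    # pass2: s2 = sum of exp(x - s1) (init 0, update acc2 + exp(x - s1))
    s2 = 0.0
    for x in row:
        s2 += math.exp(x - s1)
    # write: exp(x - s1) / s2 for every element
    return [math.exp(x - s1) / s2 for x in row]
```

Subtracting `s1` is what makes the pattern safe: `two_pass_softmax([1000.0, 1000.0])` returns `[0.5, 0.5]`, while evaluating `math.exp(1000.0)` directly raises `OverflowError`.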
Test and benchmark
```python
import zmlx
import mlx.core as mx

zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)
```
Precision
All Python-level Metal kernels compute internally in float32 regardless of input dtype. When exact dtype behavior matters (e.g., bfloat16 accumulation order), ZMLX provides specialized kernels to match MLX’s semantics.
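A NumPy illustration (not ZMLX code) of why the accumulator precision matters. The sketch adds the float16 value of 0.1 to itself 4096 times, once with a float16 accumulator and once with a float32 accumulator. Beyond 256 the float16 spacing is 0.25, so adding about 0.1 rounds back to the same value and the float16 sum stalls at 256.0. The float32 accumulator reaches the exact 409.5.

```python
import numpy as np

x = np.float16(0.1)  # stored in float16 as 0.0999755859375
n = 4096

acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(n):
    acc16 = np.float16(acc16 + x)              # float16 accumulator
    acc32 = np.float32(acc32 + np.float32(x))  # same inputs, float32 accumulator

exact = float(x) * n  # 409.5
print(acc16, acc32, exact)
```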
Documentation
- `docs/TOUR.md`: Quick walkthrough and orientation
- `docs/QUICKSTART.md`: 5-minute tutorial
- `docs/COOKBOOK.md`: Recipes for common patterns
- `docs/KERNELS.md`: Complete kernel catalog
- `docs/BENCHMARKS.md`: Detailed benchmark methodology
- `docs/EXPERIMENTAL_MLX.md`: Optional custom-MLX experiments
- `docs/ARCHITECTURE.md`: Design philosophy
- `UPSTREAM_PLAN.md`: What belongs upstream in MLX
Acknowledgments
ZMLX is built on MLX by Apple machine learning research. If you use ZMLX in your work, please also cite MLX:
```bibtex
@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}
```
Contributing
See CONTRIBUTING.md for setup, testing, and conventions.
License
MIT. See LICENSE.