ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-12% on LFM2-8B-A1B), custom GPU kernels in one line, 70+ kernel catalog.
Project description
ZMLX — Faster MoE inference on Apple Silicon
ZMLX patches MLX models with fused Metal kernels for faster Mixture-of-Experts decode on Apple Silicon. No model conversion, no config changes — just pip install zmlx and patch(model).
import mlx_lm
from zmlx.patch import patch
model, tokenizer = mlx_lm.load("mlx-community/LFM2-8B-A1B-4bit")
patch(model)
text = mlx_lm.generate(model, tokenizer,
prompt="Explain mixture-of-experts in one paragraph.",
max_tokens=200)
Two modes
| Mode | What you need | What you get |
|---|---|---|
| Stable (default) | pip install zmlx + stock MLX |
Token-identical output, +5-12% decode on LFM2 |
| Fast (opt-in) | Custom MLX fork with mx.gather_qmm_swiglu |
Additional fused SwiGLU for MoE experts; faster on Qwen3 but may diverge on token fidelity |
patch() auto-detects which primitives are available and only enables what is safe. On stock MLX, Qwen3-MoE patterns are auto-excluded to prevent silent correctness loss.
Benchmarks — Stable mode (stock MLX)
LFM2-8B-A1B on M1 Pro 16 GB
macOS 14.6.1 · MLX 0.30.4 · ZMLX 0.7.11 · Python 3.10.0 · commit
7de879eMethod:
python -m zmlx.validate— greedy decode (temp=0), fixed prompt, 5 runs,max_tokens=500(4-bit runs terminated at 430 tokens due to EOS), median reported. Baseline is unpatchedmlx_lm; patched addspatch(model).Repro capsule:
benchmarks/repro_capsules/lfm2_m1pro_20260131.json· Print report:python -m zmlx.bench.report <capsule.json>
4-bit (mlx-community/LFM2-8B-A1B-4bit)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 105.5 tok/s | 115.3 tok/s | +9.3% |
| Prefill | 225.4 tok/s | 227.1 tok/s | +0.8% (neutral) |
| Fidelity | — | 430/430 | token-identical |
| Peak memory | — | 5.3 GB |
8-bit (mlx-community/LFM2-8B-A1B-8bit-MLX)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 72.8 tok/s | 76.4 tok/s | +5.0% |
| Prefill | 180.5 tok/s | 182.8 tok/s | +1.3% (neutral) |
| Fidelity | — | 500/500 | token-identical |
| Peak memory | — | 9.5 GB |
LFM2-8B-A1B on M4 Max 36 GB
macOS 26.1 · MLX 0.30.1 · ZMLX 0.7.11 · Python 3.12 · commit
139993eMethod:
python -m zmlx.validate— greedy decode (temp=0), fixed prompt, 5 runs,max_tokens=500(4-bit runs terminated at 430 tokens due to EOS), median reported. Baseline is unpatchedmlx_lm; patched addspatch(model).Repro capsule:
benchmarks/repro_capsules/lfm2_m4max_20260131.json· Print report:python -m zmlx.bench.report <capsule.json>
4-bit (mlx-community/LFM2-8B-A1B-4bit)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 223.7 tok/s | 250.3 tok/s | +11.9% |
| Prefill | 737.4 tok/s | 755.4 tok/s | +2.4% (neutral) |
| Fidelity | — | 430/430 | token-identical |
| Peak memory | — | 5.30 GB |
8-bit (mlx-community/LFM2-8B-A1B-8bit-MLX)
| Metric | Baseline | Patched | Change |
|---|---|---|---|
| Decode | 152.5 tok/s | 164.3 tok/s | +7.7% |
| Prefill | 557.6 tok/s | 564.4 tok/s | +1.2% (neutral) |
| Fidelity | — | 500/500 | token-identical |
| Peak memory | — | 9.45 GB |
Benchmarks — Fast mode (custom MLX fork)
Requires a local MLX build that exposes
mx.gather_qmm_swiglu. This fuses the gate+up projections and SwiGLU activation into a single kernel for quantized MoE experts. Check availability:python -c "import mlx.core as mx; print(hasattr(mx, 'gather_qmm_swiglu'))".
Qwen3-30B-A3B (max_tokens=500, runs=3):
| Config | Decode speedup | Fidelity | Notes |
|---|---|---|---|
| Stock MLX, auto-excluded patches | 1.007x (base) / 0.979x (instruct) | PASS (500/500) | Safe but no MoE gain — patch() auto-excludes moe_mlp on Qwen3 |
Dev MLX + moe_mlp forced |
+6.9% (base) / +8.9% (instruct) | FAIL | Diverges at token 6 (base) / token 146 (instruct) |
The Qwen3 fidelity failure comes from precision differences in the fused gather_qmm_swiglu kernel (likely float16 internal accumulation vs MLX's default primitives). This is a known issue we plan to fix by switching the kernel to float32 accumulation. Until then, Qwen3 gains require explicit opt-in and are not enabled by default.
Notes on methodology
Prefill throughput is neutral by design — fused kernels are guarded to activate only at sequence length M <= 32 (decode), so prefill takes the standard MLX code path. Raw per-run data is in the repro capsules (benchmarks/repro_capsules/).
How It Works
The problem: dispatch overhead in MoE decode
In Mixture-of-Experts models, each token is routed to a subset of expert networks. During decode (generating one token at a time), the computation per expert is small — a few matrix multiplies on a single row vector. But the standard inference path dispatches multiple Metal kernels per expert per layer:
- Gating:
softmax(logits)→argpartition→gather→normalize— 4 dispatches - Expert execution: gate projection, up projection, SwiGLU activation, down projection — per expert
- Combine: element-wise multiply by gating weights → reduce-sum across experts — 2 dispatches
On Apple Silicon, each Metal kernel dispatch has fixed overhead (command buffer encoding, GPU scheduling). When the actual compute per dispatch is small — as it is for M=1 decode — this overhead dominates. The GPU spends more time waiting between kernels than doing math.
What ZMLX fuses
ZMLX replaces the multi-dispatch sequences with single Metal kernels that do the same math in one pass. All fused kernels are generated from Python via mx.fast.metal_kernel — no changes to MLX core required.
Fused top-k gating softmax (topk_gating_softmax):
Replaces the 4-dispatch gating sequence with a single kernel. For small expert counts (D <= 32, common in MoE), the kernel uses SIMD group operations — each row is processed by one SIMD group (32 threads), with simd_max and simd_sum for the softmax reduction and a register-based insertion sort for top-k selection. For larger D, a threadgroup reduction with shared memory is used. The kernel computes softmax probabilities and selects the top-k experts with their normalized weights in one pass.
Fused expert combine (moe_combine):
Replaces the separate element-wise multiply and reduce-sum with a single kernel that reads each expert output once, multiplies by its gating weight, and accumulates the weighted sum in float32. Output shape goes directly from (B, K, D) to (B, D) without materializing the intermediate weights * expert_outputs tensor.
Fused gate+up+SwiGLU (gather_qmm_swiglu, fast mode only):
A C++ Metal primitive (in mlx_local/) that fuses the gate projection, up projection, and SwiGLU activation for quantized MoE experts into a single kernel launch. Instead of reading the input vector twice (once for gate, once for up), it reads once and produces the activated output directly. This requires access to MLX's internal quantized matmul infrastructure and is the only ZMLX optimization that needs a custom MLX build.
Why prefill is unaffected
All fused kernels are guarded with a sequence length check (M <= 32). During prefill, M equals the prompt length (typically hundreds or thousands of tokens). At this scale, the compute-to-dispatch ratio is high and the standard MLX path is already efficient. The guards ensure ZMLX never regresses prefill performance.
Correctness guarantee
Token fidelity is a first-class requirement. patch() auto-detects the model family and excludes patterns with known fidelity issues. The fused gating kernel reproduces the exact same top-k selection and softmax normalization as the reference MLX ops. The combine kernel accumulates in float32 (or dtype-matched for Qwen3's moe_combine_exact). python -m zmlx.validate compares every generated token ID between patched and unpatched models under greedy decoding.
Patching options
from zmlx.patch import patch, smart_patch
patch(model) # auto-detect, apply safe defaults
patch(model, patterns=["moe_mlp"]) # force specific pattern (overrides safety)
patch(model, mode="training") # add norm fusions for backward pass
# Auto-benchmark: apply only patterns that actually help
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
Model Support
Stable
Token-identical output, measurable decode improvement. Safe to use without further validation.
| Model | Decode speedup | Fidelity | Patterns |
|---|---|---|---|
| LFM2-8B-A1B-4bit | +9-12% | token-identical | moe_mlp + swiglu_mlp |
| LFM2-8B-A1B-8bit | +5-8% | token-identical | moe_mlp + swiglu_mlp |
Fast (requires custom MLX)
Requires a local MLX build with mx.gather_qmm_swiglu. Speedups exist but token parity is not guaranteed. Auto-excluded by patch() defaults — force with patch(model, patterns=["moe_mlp"]).
| Model | Decode speedup | Fidelity | Notes |
|---|---|---|---|
| Qwen3-30B-A3B-4bit | +6.9% | diverges at token 6 | Fused SwiGLU precision mismatch |
| Qwen3-30B-A3B-Instruct-2507-4bit | +8.9% | diverges at token 146 | Fused SwiGLU precision mismatch |
Tested (no gain)
| Model | Status | Notes |
|---|---|---|
| LFM2.5-1.2B-Thinking-MLX-8bit | 0.997x, PASS | Dense model, no matched MoE patterns |
| Qwen3-4B-4bit (dense) | diverges at token 18 | Dense model, patches not expected to help |
| Llama-3.2-1B-4bit | 0.98x, PASS | Dense model, bandwidth-bound |
For unlisted models: python -m zmlx.validate <model>.
Toolkit
ZMLX is also a Metal kernel authoring toolkit for MLX:
- 70+ kernel catalog — SwiGLU, GeGLU, fused dropout, MoE gating, RMSNorm, RoPE, quantization
- One-line kernel authoring —
elementwise("x * tanh(log(1 + exp(x)))")compiles to Metal - Automatic gradients — custom VJP backward passes as Metal kernels via
mx.custom_function - Benchmarking —
zmlx.bench.compare()for side-by-side timing,zmlx.bench.reportfor repro capsules
from zmlx.api import elementwise
import mlx.core as mx
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
Op-level benchmarks
B=16, S=1024, D=1024, float16, M4 Max. python benchmarks/microbench.py:
| Operation | MLX | ZMLX | Speedup |
|---|---|---|---|
| SwiGLU | 0.87 ms | 0.43 ms | 2.0x |
| Dropout | 3.08 ms | 0.41 ms | 7.5x |
| Top-K | 1.81 ms | 0.49 ms | 3.7x |
| Gather-Add | 0.55 ms | 0.42 ms | 1.3x |
| Softmax | 0.45 ms | 0.44 ms | ~1.0x |
| RMSNorm | 0.51 ms | 0.54 ms | 0.95x |
ZMLX helps most for fused operations that MLX doesn't provide as single ops. MLX built-ins (mx.fast.rms_norm, mx.softmax) are already highly optimized.
Kernel catalog
70+ kernels organized by domain. Full reference: docs/KERNELS.md.
| Module | Highlights |
|---|---|
moe |
topk_gating_softmax, moe_dispatch, moe_combine — fused expert routing |
transformer |
swiglu, geglu, rmsnorm_residual, dropout — genuine fusions |
loss |
softmax_cross_entropy — memory-efficient fused loss |
bits |
pack_bits, unpack_bits — no MLX equivalent |
quant |
FP8, NF4, int8, int4 dequantization |
norms |
rmsnorm, layernorm — float32 internal compute |
rope |
apply_rope, apply_rope_interleaved, apply_gqa_rope |
optimizers |
adamw_step — fused parameter update |
Install
Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0
pip install zmlx
From source:
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
Quick Start
Custom elementwise kernel
from zmlx.api import elementwise
import mlx.core as mx
# Non-differentiable
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))
# Differentiable with custom VJP
from zmlx import msl
silu = elementwise(
"kk_silu(x)",
name="my_silu",
grad_expr="g * (s + x * s * ((T)1 - s))",
grad_prelude="T s = kk_sigmoid(x);",
use_output=False,
header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
Custom reduction
from zmlx.api import reduce
import mlx.core as mx
my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024))) # shape (8,)
Two-pass map-reduce (softmax pattern)
from zmlx.api import map_reduce
import mlx.core as mx
my_softmax = map_reduce(
pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
write="exp(x - s1) / s2",
name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
Test and benchmark
import zmlx
import mlx.core as mx
zmlx.testing.assert_matches(
my_softmax, lambda x: mx.softmax(x, axis=-1),
shapes=[(8, 1024), (32, 4096)],
)
zmlx.bench.compare(
{"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
shapes=[(1024, 4096), (4096, 4096)],
)
Optimization Lab
ZMLX includes a local MLX fork (mlx_local/) for prototyping fused C++ Metal primitives that need access to MLX internals. These are intended for eventual upstream contribution — see UPSTREAM_PLAN.md.
| Primitive | Status | Description |
|---|---|---|
gather_qmm_swiglu |
Working | Fused gate+up+SwiGLU for quantized MoE experts |
gather_qmm_combine |
Working | Fused down projection + weighted expert sum |
add_rms_norm |
Planned | Fused residual add + RMSNorm |
Precision
All Python-level Metal kernels compute internally in float32 regardless of input dtype. The C++ gather_qmm_swiglu primitive currently uses the default MLX quantized matmul precision, which may differ — this is the source of the Qwen3 fidelity issue and a target for improvement.
Documentation
docs/TOUR.md— Quick walkthrough and orientationdocs/QUICKSTART.md— 5-minute tutorialdocs/COOKBOOK.md— Recipes for common patternsdocs/KERNELS.md— Complete kernel catalogdocs/ARCHITECTURE.md— Design philosophyUPSTREAM_PLAN.md— What belongs upstream in MLX
Acknowledgments
ZMLX is built on MLX by Apple machine learning research. If you use ZMLX in your work, please also cite MLX:
@software{mlx2023,
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
url = {https://github.com/ml-explore},
version = {0.0},
year = {2023},
}
Contributing
See CONTRIBUTING.md for setup, testing, and conventions.
License
MIT. See LICENSE.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file zmlx-0.7.11.tar.gz.
File metadata
- Download URL: zmlx-0.7.11.tar.gz
- Upload date:
- Size: 208.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
2c1876a28ed3acab684561413b3d22873c9437350aa190d7790e956f13f4a815
|
|
| MD5 |
b546b475035983234813c321c7363178
|
|
| BLAKE2b-256 |
9ee18890a59ae1b8e1509c4060535c9306a0d99b3ff521a5a5d385a7e9a6cc44
|
Provenance
The following attestation bundles were made for zmlx-0.7.11.tar.gz:
Publisher:
release.yml on Hmbown/ZMLX
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
zmlx-0.7.11.tar.gz -
Subject digest:
2c1876a28ed3acab684561413b3d22873c9437350aa190d7790e956f13f4a815 - Sigstore transparency entry: 887548872
- Sigstore integration time:
-
Permalink:
Hmbown/ZMLX@101279122f8676ef8486f6e3e6ee45e11023c94e -
Branch / Tag:
refs/tags/v0.7.11 - Owner: https://github.com/Hmbown
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
release.yml@101279122f8676ef8486f6e3e6ee45e11023c94e -
Trigger Event:
release
-
Statement type:
File details
Details for the file zmlx-0.7.11-py3-none-any.whl.
File metadata
- Download URL: zmlx-0.7.11-py3-none-any.whl
- Upload date:
- Size: 135.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
c194425b7a018b516068ddd27036b01dfe0b7db32b3686653e494b7189d390ea
|
|
| MD5 |
fb84b7fd9611a9979fa0b9fb3fc4bc23
|
|
| BLAKE2b-256 |
94ded15feca54283a0ef4de32d7c5b4e3551e356de58b103852ad9e1c5145aaf
|
Provenance
The following attestation bundles were made for zmlx-0.7.11-py3-none-any.whl:
Publisher:
release.yml on Hmbown/ZMLX
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
zmlx-0.7.11-py3-none-any.whl -
Subject digest:
c194425b7a018b516068ddd27036b01dfe0b7db32b3686653e494b7189d390ea - Sigstore transparency entry: 887548961
- Sigstore integration time:
-
Permalink:
Hmbown/ZMLX@101279122f8676ef8486f6e3e6ee45e11023c94e -
Branch / Tag:
refs/tags/v0.7.11 - Owner: https://github.com/Hmbown
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
release.yml@101279122f8676ef8486f6e3e6ee45e11023c94e -
Trigger Event:
release
-
Statement type: