
ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-12% on LFM2-8B-A1B), custom GPU kernels in one line, 70+ kernel catalog.


ZMLX — Faster MoE inference on Apple Silicon

PyPI · Python 3.10+ · License: MIT · Platform: macOS (Apple Silicon)

ZMLX patches MLX models with fused Metal kernels for faster Mixture-of-Experts (MoE) decode on Apple Silicon. The supported models run on stock MLX: no model conversion, no config changes. Just install and call patch(model).

Quick start (model patching; requires zmlx[train] or mlx-lm):

import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/LFM2-8B-A1B-4bit")
patch(model)

text = mlx_lm.generate(model, tokenizer,
    prompt="Explain mixture-of-experts in one paragraph.",
    max_tokens=200)

Install (recommended)

pip install "zmlx[train]"

That includes mlx-lm and everything needed for model patching. For kernel authoring only:

pip install zmlx

Supported mode

Mode | What you need | What you get
Stable (default) | Stock MLX + zmlx | Token-identical output, measurable decode speedups

patch() auto-detects which patterns are safe for your model family.


Benchmarks — Stable mode (stock MLX)

LFM2-8B-A1B: +5-12% decode throughput with token-identical output, measured on M1 Pro and M4 Max. Prefill is neutral by design: the fused kernels activate only when the sequence length M <= 32.

LFM2-8B-A1B on M1 Pro 16 GB

macOS 14.6.1 · MLX 0.30.4 · ZMLX 0.7.12 · Python 3.10.0 · commit 7de879e

Repro capsule: benchmarks/repro_capsules/lfm2_m1pro_20260131.json · Print report: python -m zmlx.bench.report <capsule.json>

4-bit (mlx-community/LFM2-8B-A1B-4bit)

Metric | Baseline | Patched | Change
Decode | 105.5 tok/s | 115.3 tok/s | +9.3%
Prefill | 225.4 tok/s | 227.1 tok/s | +0.8% (neutral)
Fidelity: 430/430 token-identical
Peak memory: 5.3 GB

8-bit (mlx-community/LFM2-8B-A1B-8bit-MLX)

Metric | Baseline | Patched | Change
Decode | 72.8 tok/s | 76.4 tok/s | +5.0%
Prefill | 180.5 tok/s | 182.8 tok/s | +1.3% (neutral)
Fidelity: 500/500 token-identical
Peak memory: 9.5 GB

LFM2-8B-A1B on M4 Max 36 GB

macOS 26.1 · MLX 0.30.1 · ZMLX 0.7.12 · Python 3.12 · commit 139993e

Repro capsule: benchmarks/repro_capsules/lfm2_m4max_20260131.json · Print report: python -m zmlx.bench.report <capsule.json>

4-bit (mlx-community/LFM2-8B-A1B-4bit)

Metric | Baseline | Patched | Change
Decode | 223.7 tok/s | 250.3 tok/s | +11.9%
Prefill | 737.4 tok/s | 755.4 tok/s | +2.4% (neutral)
Fidelity: 430/430 token-identical
Peak memory: 5.30 GB

8-bit (mlx-community/LFM2-8B-A1B-8bit-MLX)

Metric | Baseline | Patched | Change
Decode | 152.5 tok/s | 164.3 tok/s | +7.7%
Prefill | 557.6 tok/s | 564.4 tok/s | +1.2% (neutral)
Fidelity: 500/500 token-identical
Peak memory: 9.45 GB

Full methodology and raw data: docs/BENCHMARKS.md.
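The Change column is the relative difference between patched and baseline throughput. A minimal sketch for recomputing it from the tok/s values in the tables above (the helper name pct_change is ours, not part of ZMLX):

```python
def pct_change(baseline_tok_s: float, patched_tok_s: float) -> float:
    """Relative throughput change in percent: (patched - baseline) / baseline * 100."""
    if baseline_tok_s <= 0:
        raise ValueError("baseline throughput must be positive")
    return (patched_tok_s - baseline_tok_s) / baseline_tok_s * 100.0


# Decode rows from the 4-bit tables.
m1_pro = pct_change(105.5, 115.3)
m4_max = pct_change(223.7, 250.3)
print(f"M1 Pro 4-bit decode: {m1_pro:+.1f}%")  # +9.3%
print(f"M4 Max 4-bit decode: {m4_max:+.1f}%")  # +11.9%
```

Because the tables show rounded throughput, a recomputed value can differ from the published change by 0.1 point: the 8-bit M1 Pro decode row gives +4.9% from the rounded tok/s values, versus the reported +5.0%.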


How It Works

The problem: dispatch overhead in MoE decode

In Mixture-of-Experts models, each token is routed to a subset of expert networks. During decode (generating one token at a time), the computation per expert is small — a few matrix multiplies on a single row vector. But the standard inference path dispatches multiple Metal kernels per expert per layer:

  1. Gating: softmax(logits) → argpartition → gather → normalize (4 dispatches)
  2. Expert execution: gate projection, up projection, SwiGLU activation, down projection (per expert)
  3. Combine: element-wise multiply by gating weights → reduce-sum across experts (2 dispatches)

On Apple Silicon, each Metal kernel dispatch has fixed overhead (command buffer encoding, GPU scheduling). When the actual compute per dispatch is small — as it is for M=1 decode — this overhead dominates. The GPU spends more time waiting between kernels than doing math.
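The gating sequence in step 1 can be written out as four separate passes. The plain-Python sketch below is a CPU reference for one token's router logits, meant only to make the data flow concrete; the function name is hypothetical, a full sort stands in for argpartition, and in MLX each pass would be its own kernel dispatch.

```python
import math
from typing import List, Tuple


def unfused_topk_gating(logits: List[float], k: int) -> Tuple[List[int], List[float]]:
    """Reference top-k gating as four separate passes:
    softmax -> argpartition -> gather -> normalize."""
    # Pass 1: numerically stable softmax over all experts.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Pass 2: indices of the k largest probabilities
    # (a full sort stands in for argpartition here).
    idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

    # Pass 3: gather the selected probabilities.
    picked = [probs[i] for i in idx]

    # Pass 4: renormalize so the k gate weights sum to 1.
    s = sum(picked)
    weights = [p / s for p in picked]
    return idx, weights
```

Each pass reads the output of the previous one, which is why an unfused implementation needs one dispatch (and one round trip through memory) per step.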

What ZMLX fuses

ZMLX replaces the multi-dispatch sequences with single Metal kernels that do the same math in one pass. All fused kernels are generated from Python via mx.fast.metal_kernel — no changes to MLX core required.

Fused top-k gating softmax (topk_gating_softmax):

Replaces the 4-dispatch gating sequence with a single kernel. For small expert counts (D <= 32, common in MoE), the kernel uses SIMD group operations — each row is processed by one SIMD group (32 threads), with simd_max and simd_sum for the softmax reduction and a register-based insertion sort for top-k selection. For larger D, a threadgroup reduction with shared memory is used. The kernel computes softmax probabilities and selects the top-k experts with their normalized weights in one pass.
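The fused approach can be emulated serially to show the algorithm: two reductions per row (the roles played by simd_max and simd_sum in the kernel) followed by a top-k buffer kept sorted by insertion, standing in for the register-resident array. This is an illustrative sketch of the technique described above, not ZMLX's Metal source; the function name is hypothetical.

```python
import math
from typing import List, Tuple


def fused_topk_gating(logits: List[float], k: int) -> Tuple[List[int], List[float]]:
    """Serial emulation of a fused top-k gating softmax for one row."""
    # Reduction 1 (simd_max in the kernel): row maximum for numerical stability.
    m = -math.inf
    for x in logits:
        m = max(m, x)

    # Reduction 2 (simd_sum in the kernel): softmax denominator.
    denom = 0.0
    for x in logits:
        denom += math.exp(x - m)

    # Top-k selection: a buffer of at most k entries kept sorted (descending)
    # by insertion, analogous to a small per-thread register array.
    top_p: List[float] = []
    top_i: List[int] = []
    for i, x in enumerate(logits):
        p = math.exp(x - m) / denom
        pos = len(top_p)
        while pos > 0 and top_p[pos - 1] < p:
            pos -= 1
        if pos < k:
            top_p.insert(pos, p)
            top_i.insert(pos, i)
            if len(top_p) > k:  # drop the smallest entry to keep k items
                top_p.pop()
                top_i.pop()

    # Renormalize the selected probabilities so they sum to 1.
    s = sum(top_p)
    return top_i, [p / s for p in top_p]
```

Since softmax is monotonic, ranking by probability selects the same experts as ranking by raw logit; the probabilities are only needed for the final weights.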

Fused expert combine (moe_combine):

Replaces the separate element-wise multiply and reduce-sum with a single kernel that reads each expert output once, multiplies it by its gating weight, and accumulates the weighted sum in float32. The output goes directly from shape (B, K, D) to (B, D), where K is the number of selected experts and D is the hidden size, without materializing the intermediate weights * expert_outputs tensor.
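A plain-Python sketch of what the combine step computes, using the (B, K, D) and (B, K) shapes described above. Python floats are double precision, so this mirrors only the structure (read each element once, scale, accumulate into a running sum), not the kernel's float32 arithmetic; combine_reference is a hypothetical name and not ZMLX's moe_combine signature.

```python
from typing import List, Sequence


def combine_reference(expert_outputs: Sequence[Sequence[Sequence[float]]],
                      weights: Sequence[Sequence[float]]) -> List[List[float]]:
    """Weighted sum over the K selected experts: (B, K, D) with (B, K) -> (B, D).

    Each expert output element is read once, scaled by its gate weight and
    added into an accumulator; no (B, K, D) product tensor is built.
    """
    out: List[List[float]] = []
    for b, experts in enumerate(expert_outputs):
        d = len(experts[0])
        acc = [0.0] * d  # running sum (float64 here; the kernel uses float32)
        for kk, row in enumerate(experts):
            w = weights[b][kk]
            for j in range(d):
                acc[j] += w * row[j]
        out.append(acc)
    return out


# One token (B=1), two experts (K=2), hidden size D=2.
print(combine_reference([[[1.0, 2.0], [3.0, 4.0]]], [[0.25, 0.75]]))  # [[2.5, 3.5]]
```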

Why prefill is unaffected

All fused kernels are guarded with a sequence length check (M <= 32). During prefill, M equals the prompt length (typically hundreds or thousands of tokens). At this scale, the compute-to-dispatch ratio is high and the standard MLX path is already efficient. The guards ensure ZMLX never regresses prefill performance.
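A minimal sketch of such a guard, assuming M is the number of token rows in the input and using the 32-token threshold stated above. The wrapper and its names are hypothetical; ZMLX's actual check may be implemented differently.

```python
from typing import Callable, Sequence, TypeVar

R = TypeVar("R")

# Threshold from the text: fused kernels run only when M <= 32.
FUSED_MAX_M = 32


def with_length_guard(fused: Callable[[Sequence], R],
                      reference: Callable[[Sequence], R]) -> Callable[[Sequence], R]:
    """Wrap two implementations of the same op. M is taken as the number of
    rows (tokens): decode-sized inputs use the fused path, prefill-sized
    inputs fall back to the reference path."""
    def op(rows: Sequence) -> R:
        m = len(rows)
        return fused(rows) if m <= FUSED_MAX_M else reference(rows)
    return op


# Toy implementations that only report which path ran.
op = with_length_guard(lambda rows: "fused", lambda rows: "reference")
print(op([[0.0]] * 1))    # decode, M = 1: fused
print(op([[0.0]] * 512))  # prefill, M = 512: reference
```

Keeping the reference path for large M means prefill runs the same code as unpatched MLX, which is why it is expected to be neutral.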

Correctness guarantee

Token fidelity is a first-class requirement. patch() auto-detects the model family and excludes patterns with known fidelity issues. The fused gating kernel reproduces the exact same top-k selection and softmax normalization as the reference MLX ops. The combine kernel accumulates in float32 (or dtype-matched for Qwen3's moe_combine_exact). python -m zmlx.validate compares every generated token ID between patched and unpatched models under greedy decoding.
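The token-by-token comparison can be pictured with the hypothetical helpers below, which take the token IDs from an unpatched and a patched greedy decode. They are not the validator's code: they report a 0-based index (zmlx.validate's convention may differ), and the report strings only mimic the wording used in the model tables.

```python
from typing import Optional, Sequence


def first_divergence(reference_ids: Sequence[int], patched_ids: Sequence[int]) -> Optional[int]:
    """Return the 0-based index of the first differing token, or None if the
    two greedy decodes are token-identical (same IDs, same length)."""
    for i, (a, b) in enumerate(zip(reference_ids, patched_ids)):
        if a != b:
            return i
    if len(reference_ids) != len(patched_ids):
        return min(len(reference_ids), len(patched_ids))
    return None


def fidelity_report(reference_ids: Sequence[int], patched_ids: Sequence[int]) -> str:
    idx = first_divergence(reference_ids, patched_ids)
    if idx is None:
        n = len(reference_ids)
        return f"{n}/{n} token-identical"
    return f"diverges at token {idx}"


print(fidelity_report([5, 9, 2], [5, 9, 2]))  # 3/3 token-identical
print(fidelity_report([5, 9, 2], [5, 7, 2]))  # diverges at token 1
```

Under greedy decoding a single differing token changes every later context, so the first divergence is the meaningful number to report.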

Patching options

import mlx.core as mx
import mlx_lm
from zmlx.patch import patch, smart_patch

model, tokenizer = mlx_lm.load("mlx-community/LFM2-8B-A1B-4bit")

# The options below are alternatives; pick one.
patch(model)                       # auto-detect, apply safe defaults
patch(model, patterns=["moe_mlp"]) # force specific pattern (overrides safety)
patch(model, mode="training")      # add norm fusions for backward pass

# Auto-benchmark: apply only patterns that actually help
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
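The idea of applying only patterns that actually help can be sketched as a selection over measured throughput. This is a hypothetical illustration: the timing dictionary, the 1% threshold and the function name are our assumptions, and it does not describe how smart_patch measures or decides.

```python
from typing import Dict, List


def select_patterns(baseline_tok_s: float,
                    patched_tok_s: Dict[str, float],
                    min_gain: float = 0.01) -> List[str]:
    """Keep patterns whose measured tokens/second beat the baseline by at
    least min_gain (a fraction, e.g. 0.01 = 1%)."""
    keep = []
    for name, tok_s in sorted(patched_tok_s.items()):
        if tok_s >= baseline_tok_s * (1.0 + min_gain):
            keep.append(name)
    return keep


# Made-up measurements, one pattern applied at a time.
print(select_patterns(100.0, {"moe_mlp": 108.0, "swiglu_mlp": 100.5, "rmsnorm": 97.0}))
```

Measuring each pattern on its own ignores interactions between patterns; a noise margin such as min_gain keeps timing jitter from enabling patterns that do not really help.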

Model Support

Stable

Token-identical output, measurable decode improvement. Safe to use without further validation.

Model | Decode speedup | Fidelity | Patterns
LFM2-8B-A1B-4bit | +9-12% | token-identical | moe_mlp + swiglu_mlp
LFM2-8B-A1B-8bit | +5-8% | token-identical | moe_mlp + swiglu_mlp
Qwen3-30B-A3B-4bit | +7% | token-identical | moe_mlp
GPT-OSS-20B-MXFP4-Q4 | +1% | token-identical | moe_mlp

Tested (no gain)

Model | Status | Notes
Nemotron-3-Nano-30B-A3B-NVFP4 | 0.999x, PASS | Hybrid Mamba-MoE, bandwidth-limited at 19.4 GB
LFM2.5-1.2B-Thinking-MLX-8bit | 0.997x, PASS | Dense model, no matched MoE patterns
Qwen3-4B-4bit (dense) | diverges at token 18 | Dense model, patches not expected to help
Llama-3.2-1B-4bit | 0.98x, PASS | Dense model, bandwidth-bound

For unlisted models: python -m zmlx.validate <model>.


Toolkit

ZMLX is also a Metal kernel authoring toolkit for MLX:

  • 70+ kernel catalog: SwiGLU, GeGLU, fused dropout, MoE gating, RMSNorm, RoPE, quantization
  • One-line kernel authoring: elementwise("x * tanh(log(1 + exp(x)))") compiles to Metal
  • Automatic gradients: custom VJP backward passes as Metal kernels via mx.custom_function
  • Benchmarking: zmlx.bench.compare() for side-by-side timing, zmlx.bench.report for repro capsules

from zmlx.api import elementwise
import mlx.core as mx

mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))

Op-level microbenchmarks

Setup: B=16, S=1024, D=1024, float16, M4 Max. Reproduce with python benchmarks/microbench.py.

Operation | MLX | ZMLX | Speedup
SwiGLU | 0.87 ms | 0.43 ms | 2.0x
Dropout | 3.08 ms | 0.41 ms | 7.5x
Top-K | 1.81 ms | 0.49 ms | 3.7x
Gather-Add | 0.55 ms | 0.42 ms | 1.3x
Softmax | 0.45 ms | 0.44 ms | ~1.0x
RMSNorm | 0.51 ms | 0.54 ms | 0.95x

ZMLX helps most for fused operations that MLX doesn't provide as single ops. MLX built-ins (mx.fast.rms_norm, mx.softmax) are already highly optimized.

Kernel catalog

70+ kernels organized by domain. Full reference: docs/KERNELS.md.

Module | Highlights
moe | topk_gating_softmax, moe_dispatch, moe_combine (fused expert routing)
transformer | swiglu, geglu, rmsnorm_residual, dropout (genuine fusions)
loss | softmax_cross_entropy (memory-efficient fused loss)
bits | pack_bits, unpack_bits (no MLX equivalent)
quant | FP8, NF4, int8, int4 dequantization
norms | rmsnorm, layernorm (float32 internal compute)
rope | apply_rope, apply_rope_interleaved, apply_gqa_rope
optimizers | adamw_step (fused parameter update)

Install (full)

Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

pip install "zmlx[train]"

Kernel authoring only:

pip install zmlx

From source:

git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"

Quick Start (kernel authoring)

from zmlx.api import elementwise
import mlx.core as mx

# Non-differentiable
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable with custom VJP
from zmlx import msl

silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))

More examples: reductions, softmax, testing

Custom reduction

from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)

Two-pass map-reduce (softmax pattern)

from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
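Per row, the two passes and the write step of my_softmax correspond to the plain-Python computation below: pass 1 reduces the row to its maximum s1, pass 2 reduces the shifted exponentials to the denominator s2, and the write step emits exp(x - s1) / s2. It is a CPU reference for checking the semantics, not a description of how map_reduce schedules threads.

```python
import math
from typing import List, Sequence


def two_pass_softmax_row(row: Sequence[float]) -> List[float]:
    # Pass 1: init -INFINITY, update max(acc1, x) -> s1 = row maximum.
    s1 = -math.inf
    for x in row:
        s1 = max(s1, x)
    # Pass 2: init 0, update acc2 + exp(x - s1) -> s2 = softmax denominator.
    s2 = 0.0
    for x in row:
        s2 += math.exp(x - s1)
    # Write: exp(x - s1) / s2 for every element of the row.
    return [math.exp(x - s1) / s2 for x in row]


def two_pass_softmax(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Apply the row-wise softmax to every row (last axis)."""
    return [two_pass_softmax_row(r) for r in rows]


print(two_pass_softmax_row([1000.0, 1000.0]))  # [0.5, 0.5]
```

Subtracting s1 before exponentiating is what makes the first pass necessary: math.exp(1000.0) on its own raises OverflowError, while the shifted form stays finite.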

Test and benchmark

import zmlx
import mlx.core as mx

zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)

Precision

All Python-level Metal kernels compute internally in float32 regardless of input dtype. When exact dtype behavior matters (e.g., bfloat16 accumulation order), ZMLX provides specialized kernels to match MLX’s semantics.
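Why the accumulator precision matters can be seen with a small simulation that rounds a running sum after every addition, using Python's struct half- and single-precision formats. Python's struct has no bfloat16 format, so float16 stands in as the low-precision type; this illustrates the general effect only and does not model MLX's or ZMLX's actual reduction order.

```python
import struct
from typing import Callable, Iterable


def to_f16(x: float) -> float:
    """Round a Python float to IEEE half precision (round-to-nearest-even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to IEEE single precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(values: Iterable[float], rnd: Callable[[float], float]) -> float:
    acc = 0.0
    for v in values:
        acc = rnd(acc + v)  # round after every add, like a fixed-width accumulator
    return acc


ones = [1.0] * 4096
print(accumulate(ones, to_f16))  # 2048.0: in float16, 2048 + 1 rounds back to 2048
print(accumulate(ones, to_f32))  # 4096.0
```

Above 2048, consecutive float16 values are 2 apart, so adding 1 is lost to rounding and the sum stalls; a float32 accumulator keeps the exact result.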



Acknowledgments

ZMLX is built on MLX by Apple machine learning research. If you use ZMLX in your work, please also cite MLX:

@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}

Contributing

See CONTRIBUTING.md for setup, testing, and conventions.


License

MIT. See LICENSE.



Download files


Source Distribution

zmlx-0.7.12.tar.gz (210.5 kB)

Uploaded Source

Built Distribution


zmlx-0.7.12-py3-none-any.whl (135.2 kB)

Uploaded Python 3

File details

Details for the file zmlx-0.7.12.tar.gz.

File metadata

  • Download URL: zmlx-0.7.12.tar.gz
  • Size: 210.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for zmlx-0.7.12.tar.gz
Algorithm Hash digest
SHA256 ed0e3b8655acbfc7bb38f3dccc59652041afc2234648eed7e2e2580277f1ad15
MD5 2fb1b2daf5b2e53f6a5ecb28bb3d5a50
BLAKE2b-256 288ad7c052943e0df88b1310648d308584c6f7dd6c51651209e2d8a4c296a284


Provenance

The following attestation bundles were made for zmlx-0.7.12.tar.gz:

Publisher: release.yml on Hmbown/ZMLX

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file zmlx-0.7.12-py3-none-any.whl.

File metadata

  • Download URL: zmlx-0.7.12-py3-none-any.whl
  • Size: 135.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for zmlx-0.7.12-py3-none-any.whl
Algorithm Hash digest
SHA256 d8359b0d4b20e4991105f9be9fe43996277684d18195b57695218f23a1c09705
MD5 2867d72d4b5cecd34a7fd546019291bf
BLAKE2b-256 609b7596654fcfb376a31f7511c9cf99f7d81710071fc1cb7ac166baebeb7ac1


Provenance

The following attestation bundles were made for zmlx-0.7.12-py3-none-any.whl:

Publisher: release.yml on Hmbown/ZMLX

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
