
ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-26%), custom GPU kernels in one line, 70+ kernel catalog.


ZMLX - Triton-style kernels for Apple Silicon

Python 3.10+ · License: MIT · Platform: macOS (Apple Silicon)

A Metal kernel toolkit and upstream incubation lab for MLX — author custom GPU kernels in Python, apply fused model patches, and prototype C++ Metal primitives for upstream MLX.

Toolkit highlights (available via pip install zmlx on stock MLX):

  • SwiGLU 2.0x, Dropout 7.5x, Top-K 3.7x in op-level microbenchmarks
  • 70+ kernel catalog, autograd, benchmarking utilities, and model patching

Validated benchmarks (M4 Max 36 GB, MLX 0.30.4.dev, Jan 31, 2026):

| Result | Measurement | Notes |
|---|---|---|
| +5-8% decode on Qwen3-30B-A3B-4bit (MoE, E=128, K=8) | 119-122 vs 113-114 tok/s | Requires local MLX build with gather_qmm_swiglu |
| Neutral (1.01x) on Qwen3-4B-4bit (dense) | 125.3 vs 124.4 tok/s | Safe on dense models |
| +4% per MoE layer on LFM2-8B-A1B-4bit (MoE, E=32, K=4) | 293 us -> 282 us at M=1 decode | Kernel-level bench; E2E too noisy to report |

Fused SwiGLU requires a local MLX build with gather_qmm_swiglu. On stock MLX, gating+combine-only can be neutral to slightly negative (Qwen3-30B measured 0.95–0.98x), so use smart_patch or a fused build for MoE models.

pip install zmlx

MoE patching example:

import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/Qwen3-30B-A3B-4bit")
patch(model)  # best on MoE with fused MLX; use smart_patch on stock MLX

Custom kernel example:

from zmlx.api import elementwise
import mlx.core as mx

# Math formula → compiled Metal kernel → runs on GPU
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))

What's New in v0.7.0

Fused expert SwiGLU (gather_qmm_swiglu)

ZMLX prototypes fused C++ Metal primitives in a local MLX fork. The first: gather_qmm_swiglu — fuses gate projection + up projection + SwiGLU activation into a single kernel launch, reading the input tensor once instead of twice.

  • +5-8% decode on Qwen3-30B-A3B-4bit (119-122 vs 113-114 tok/s, E2E)
  • +4% per MoE layer on LFM2-8B-A1B-4bit (293 us → 282 us at M=1 decode, kernel-level)
  • Neutral (1.01x) on Qwen3-4B-4bit (dense, E2E)
  • Auto-enabled by patch(model) when available; falls back to two-pass otherwise
  • Threshold-guarded: fused kernel used for M<=32 (decode/small prefill), falls back at large M where it is slower
  • Added kernel-level + E2E benchmark scripts plus correctness tests

Requires building MLX from the local fork (mlx_local/). See Optimization Lab for details. Upstream PR to MLX planned.
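The plain-Python sketch below shows what the fusion saves for one expert. The two-pass version walks the input once per projection. The fused version reads each input element once and feeds both the gate and up accumulators. It applies SwiGLU in the same loop, so no intermediate gate/up tensors are materialized.

This is an unquantized, single-token illustration with hypothetical function names. It is not the Metal kernel in mlx_local/.

```python
import math
from typing import List

def silu(z: float) -> float:
    # Numerically stable SiLU(z) = z * sigmoid(z)
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def swiglu_two_pass(x: List[float], w_gate: List[List[float]], w_up: List[List[float]]) -> List[float]:
    # w_gate, w_up: N rows of length K (row-major, unquantized, hypothetical layout)
    gate = [sum(xi * wi for xi, wi in zip(x, row)) for row in w_gate]  # pass 1 over x
    up = [sum(xi * wi for xi, wi in zip(x, row)) for row in w_up]      # pass 2 over x
    return [silu(g) * u for g, u in zip(gate, up)]

def swiglu_fused(x: List[float], w_gate: List[List[float]], w_up: List[List[float]]) -> List[float]:
    out = []
    for row_g, row_u in zip(w_gate, w_up):
        acc_g = acc_u = 0.0
        for k, xk in enumerate(x):  # x[k] is loaded once and used for both dot products
            acc_g += xk * row_g[k]
            acc_u += xk * row_u[k]
        out.append(silu(acc_g) * acc_u)  # activation fused into the same loop
    return out
```

Both functions return the same values. The difference is that, per output element, the fused loop traverses `x` once instead of twice.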

Optimization lab

ZMLX is evolving into two things:

  1. A Metal kernel toolkit (stable) — elementwise(), reduce(), map_reduce(), autograd, model patching. This is what pip install zmlx gives you.
  2. An MLX optimization lab (experimental) — prototyping fused C++ Metal primitives that should eventually be upstreamed to MLX. These require a local MLX build and live in mlx_local/.

The lab exists because some fusions (quantized matmul + activation) need access to MLX's internal SIMD helpers (qdot, load_vector, QuantizedBlockLoader) which aren't exposed through the public metal_kernel API. The plan is: prototype here, validate with benchmarks, upstream to MLX, then ZMLX auto-detects the primitives when they land.

Current lab work:

  • gather_qmm_swiglu — fused expert gate+up+SwiGLU (done, working, benchmarked)
  • add_rms_norm — fused residual add + RMSNorm (planned, benefits all models)
  • gather_qmm_combine — fused down projection + weighted expert sum (planned, MoE-specific)

Previous highlights (v0.6.x)

  • SIMD-group top-k gating, bias-aware fused gating, dynamic num_experts_per_tok
  • topk_gating_softmax(x, k) kernel (fused Metal for k<=8, full-softmax + expert bias)
  • LFM2, GPT-OSS, GLM-4, Mixtral model support
  • mode parameter, validated benchmarks, router attribute support
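For readers new to MoE routing, here is a reference-semantics sketch of top-k softmax gating followed by the weighted expert combine. It makes three assumptions:

- a full softmax over all experts,
- selection of the k largest probabilities,
- renormalization of the selected weights.

Whether ZMLX's topk_gating_softmax renormalizes, and exactly how it applies expert bias, is not stated here. Treat those choices, and the function names, as assumptions.

```python
import math
from typing import List, Tuple

def topk_gating_softmax_ref(logits: List[float], k: int,
                            renormalize: bool = True) -> Tuple[List[int], List[float]]:
    # Numerically stable softmax over all experts
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k most probable experts, highest first
    idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weights = [probs[i] for i in idx]
    if renormalize:
        s = sum(weights)
        weights = [w / s for w in weights]
    return idx, weights

def moe_combine_ref(expert_outputs: List[List[float]], idx: List[int],
                    weights: List[float]) -> List[float]:
    # Weighted sum of the selected experts' output vectors
    dim = len(expert_outputs[idx[0]])
    out = [0.0] * dim
    for e, w in zip(idx, weights):
        for d in range(dim):
            out[d] += w * expert_outputs[e][d]
    return out
```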

Previous highlights (v0.4-0.5)

  • MoE patch (fused gating + combine), high-level API, JIT compiler
  • Smart patching, training pipeline, 70+ kernel catalog

Why ZMLX?

When you need a custom GPU op on Apple Silicon, your options today are:

  1. Write raw Metal source strings, manage caching, figure out threadgroups, wire up autodiff manually
  2. Use ZMLX

ZMLX wraps mx.fast.metal_kernel and mx.custom_function to provide Triton-like ergonomics:

  • One-line kernel authoring - define elementwise, reduction, and map-reduce ops from C expressions
  • Automatic gradients - custom VJP backward passes (themselves Metal kernels) via mx.custom_function
  • Define-once caching - kernels compile once, reused by source hash + config
  • Autotuning - threadgroup size search with persistent caching
  • Testing & benchmarking - verify against reference ops, compare timings side-by-side
  • Model patching - swap MLX layers for fused ZMLX kernels with patch(model)
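The define-once caching bullet can be sketched as follows: hash the generated source together with the launch configuration, and compile only on a cache miss. `KernelCache` and the `compile_fn` callback are hypothetical stand-ins. In ZMLX the compiled object would be built with mx.fast.metal_kernel, and the real key format is not documented here.

```python
import hashlib
import json
from typing import Any, Callable, Dict

class KernelCache:
    """Hypothetical define-once cache keyed by source hash + config."""

    def __init__(self, compile_fn: Callable[[str, Dict[str, Any]], Any]):
        self._compile_fn = compile_fn
        self._cache: Dict[str, Any] = {}
        self.compiles = 0  # number of actual compilations (cache misses)

    @staticmethod
    def key(source: str, config: Dict[str, Any]) -> str:
        # sort_keys makes the key independent of dict insertion order
        payload = source + "\0" + json.dumps(config, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def get(self, source: str, config: Dict[str, Any]) -> Any:
        k = self.key(source, config)
        if k not in self._cache:
            self._cache[k] = self._compile_fn(source, config)
            self.compiles += 1
        return self._cache[k]
```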

Install

Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

pip install zmlx

From source (development):

git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"

Quick Start

1. Custom elementwise kernel

from zmlx.api import elementwise
import mlx.core as mx

# Non-differentiable - just forward pass
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable - with custom VJP
from zmlx import msl

silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
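The grad_expr above encodes the SiLU derivative, scaled by the upstream gradient g:

d/dx[x·σ(x)] = σ(x) + x·σ(x)·(1 − σ(x))

Here s = σ(x) is computed once in grad_prelude. The plain-Python finite-difference check below verifies that formula independently of ZMLX and Metal.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def silu(x: float) -> float:
    return x * sigmoid(x)

def silu_vjp(x: float, g: float) -> float:
    # Same expression as grad_expr: g * (s + x * s * (1 - s)), with s = sigmoid(x)
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))

def central_diff(f, x: float, h: float = 1e-6) -> float:
    # Symmetric finite difference approximation of f'(x)
    return (f(x + h) - f(x - h)) / (2.0 * h)

for x in (-3.0, -0.5, 0.0, 0.7, 4.0):
    print(f"x={x:+.1f}  analytic={silu_vjp(x, 1.0):.8f}  numeric={central_diff(silu, x):.8f}")
```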

2. Custom reduction

from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)
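One reading of reduce() is a row-wise fold over the last axis. Under that reading:

- `init` is the starting accumulator,
- `update` is the step, where `acc` is the running value and `v` is the current element.

The plain-Python reference below mirrors that reading for the row_sum example. It is an assumption about the semantics, not ZMLX's generated kernel, which splits each row across a threadgroup.

```python
from typing import Callable, List

def reduce_rows_ref(rows: List[List[float]], init: float,
                    update: Callable[[float, float], float]) -> List[float]:
    out = []
    for row in rows:
        acc = init
        for v in row:
            acc = update(acc, v)  # plays the role of the "acc + v" string
        out.append(acc)
    return out

rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(reduce_rows_ref(rows, 0.0, lambda acc, v: acc + v))  # [6.0, 15.0]
```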

3. Two-pass map-reduce (softmax pattern)

from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
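The three strings map onto the classic two-pass softmax:

- pass 1 reduces each row to its maximum s1,
- pass 2 reduces exp(x - s1) to the normalizer s2,
- write emits exp(x - s1) / s2.

Each pass's `reduce` combines partial accumulators produced by different threads. The plain-Python sketch below follows that reading. It splits each row into hypothetical per-thread chunks to show why `reduce` must be associative.

```python
import math
from functools import reduce as fold
from typing import List

def softmax_two_pass_ref(row: List[float], n_threads: int = 4) -> List[float]:
    # Split the row into contiguous chunks, one per simulated thread
    size = max(1, math.ceil(len(row) / n_threads))
    chunks = [row[i:i + size] for i in range(0, len(row), size)]

    # Pass 1: per-thread update max(acc1, x), then reduce max(a, b) across threads
    partial_max = [fold(max, c, -math.inf) for c in chunks]
    s1 = fold(max, partial_max)

    # Pass 2: per-thread update acc2 + exp(x - s1), then reduce a + b across threads
    partial_sum = [fold(lambda acc, x: acc + math.exp(x - s1), c, 0.0) for c in chunks]
    s2 = fold(lambda a, b: a + b, partial_sum)

    # Write: exp(x - s1) / s2 for every element
    return [math.exp(x - s1) / s2 for x in row]
```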

4. Test and benchmark your kernel

import zmlx
import mlx.core as mx

# Verify correctness
zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

# Benchmark
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)

5. Lower-level building blocks

from zmlx import autograd, elementwise, msl
import mlx.core as mx

# Unary kernel (no gradient)
# compute_dtype is omitted: it is deprecated, and kernels always compute in float32
exp_kern = elementwise.unary(
    name="kk_exp", expr="metal::exp(x)",
    header=msl.DEFAULT_HEADER,
)

# Binary kernel with custom VJP
mul_op = autograd.binary_from_expr(
    name="safe_mul", fwd_expr="a * b",
    vjp_lhs_expr="g * b", vjp_rhs_expr="g * a",
)

Kernel Catalog

ZMLX includes 70+ kernels organized by domain. Some are genuinely useful for custom workloads (loss, GLU fusions, bit ops, MoE gating). Others are reference implementations showing codegen patterns - correct but not faster than MLX built-ins for standard transformer shapes.

Full reference: docs/KERNELS.md.

| Module | Highlights |
|---|---|
| loss | softmax_cross_entropy - memory-efficient fused loss |
| transformer | swiglu, geglu, rmsnorm_residual (with full weight gradients), dropout - genuine fusions |
| bits | pack_bits, unpack_bits - no MLX equivalent |
| moe | topk_gating_softmax, moe_dispatch, moe_combine - fused expert routing (k ≤ 8 fused, bias-aware) |
| quant | FP8 (E4M3/E5M2), NF4, int8, int4 dequantization - real bit-manipulation kernels |
| optimizers | adamw_step - fused AdamW parameter update in a single kernel |
| scan | cumsum_lastdim - differentiable prefix sum |
| norms | rmsnorm, layernorm - parallel reduction; all norms compute in float32 internally |
| softmax | softmax_lastdim - map-reduce codegen showcase |
| rope | apply_rope, apply_rope_interleaved, apply_gqa_rope |
| linear | Reference fused-linear patterns (naive matmul, not for production) |
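As an example of the bit manipulation behind the quant module, here is a plain-Python decoder for one FP8 E4M3 byte. It assumes the OCP "E4M3FN" convention:

- exponent bias 7,
- no infinities,
- only S.1111.111 encodes NaN,
- largest finite value 448.

It illustrates the technique only and is not ZMLX's kernel.

```python
import math

def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 byte (assumed OCP E4M3FN convention) to a float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F  # 4 exponent bits
    mant = byte & 0x07        # 3 mantissa bits
    if exp == 0x0F and mant == 0x07:
        return math.nan       # the only NaN encodings; E4M3FN has no infinities
    if exp == 0:
        # Subnormal: no implicit leading 1, fixed exponent 1 - bias = -6
        return sign * (mant / 8.0) * 2.0 ** -6
    # Normal: implicit leading 1, exponent biased by 7
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)
```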

Architecture

Three-layer design. Full details: docs/ARCHITECTURE.md.

  1. Metal kernel infrastructure - MetalKernel wrapper, in-process cache, stats tracking
  2. Code generation & helpers - MSL templates, elementwise/autograd/rowwise APIs, autotuning
  3. Kernel catalog - domain modules built on layers 1 and 2

Benchmarks

Three levels of measurement:

  1. Isolated primitives: benchmarks/bench_gather_qmm_swiglu.py
  2. Single MoE layer: benchmarks/bench_moe_layer.py (500 iters, p50 median, mx.synchronize())
  3. E2E model decode: benchmarks/bench_moe_e2e.py

Op-level (B=16, S=1024, D=1024, float16, M4 Max)

Run python benchmarks/microbench.py to reproduce on your hardware.

| Operation | MLX | ZMLX | Speedup |
|---|---|---|---|
| SwiGLU | 0.87 ms | 0.43 ms | 2.0x |
| Dropout | 3.08 ms | 0.41 ms | 7.5x |
| Top-K | 1.81 ms | 0.49 ms | 3.7x |
| Gather-Add | 0.55 ms | 0.42 ms | 1.3x |
| Softmax | 0.45 ms | 0.44 ms | ~1.0x |
| RMSNorm | 0.51 ms | 0.54 ms | 0.95x |
| MoE gating | 0.35 ms | 0.37 ms | 0.94x |
| Sum | 0.22 ms | 0.37 ms | 0.58x |
| CumSum | 0.32 ms | 0.62 ms | 0.52x |

ZMLX is most effective for fused operations that MLX does not provide as single ops (SwiGLU, fused-RNG dropout, fused gather-add). MLX built-ins (mx.fast.rms_norm, mx.softmax, reductions) are already highly optimized and remain the preferred choice for standard transformer shapes.

Model-level inference (E2E)

All baselines are unmodified mlx_lm. ZMLX rows add patch(model). Same model weights, same quantization, same prompt. Benchmarks on M4 Max 36 GB, MLX 0.30.4.dev, Jan 31, 2026 unless noted.

MoE models

Qwen3-30B-A3B-4bit (MoE, 48 layers, E=128, K=8) — python benchmarks/bench_moe_e2e.py
3 runs x 200 tokens, repeated 3 times.

| Config | Decode (tok/s) | vs baseline |
|---|---|---|
| Baseline (mlx_lm) | 113-114 | 1.00x |
| patch(model), gating + combine only | 109-111 | 0.95-0.98x |
| patch(model), with fused SwiGLU | 119-122 | +5-8% |

Fused SwiGLU requires a local MLX build. Gating+combine alone measured below baseline on this model.

LFM2-8B-A1B-4bit (MoE, 24 layers, E=32, K=4) — E2E is too noisy at ~20k tok/s.
Use the kernel-level MoE layer benchmark below as the authoritative measurement.

Dense models (neutral — expected)

Qwen3-4B-4bit (dense, 36 layers) — python benchmarks/bench_moe_e2e.py
3 runs x 1000 tokens.

| Config | Decode (tok/s) | vs baseline |
|---|---|---|
| Baseline (mlx_lm) | 124.4 | 1.00x |
| patch(model) | 125.3 | 1.01x (neutral) |

Dense decode is bandwidth-bound; patches are safe but not expected to help.

Overnight suites (sequential, cache on external drive):

python benchmarks/bench_moe_suite.py \
  --model-list benchmarks/moe_models.txt \
  --cache-dir /path/to/external/cache \
  --runs 3 --max-tokens 200 --resume

Kernel-level MoE layer timing (authoritative for fast models)

python benchmarks/bench_moe_layer.py isolates a single MoE layer forward pass, timed with mx.synchronize() brackets. 500 iterations, p50 median. Per-measurement stdev: 4-10%.
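The timing methodology can be sketched generically: warm up, bracket each timed call with a synchronize, and report the p50 over many iterations. `fn` and `sync` are placeholders. With MLX, `fn` would run the layer and call mx.eval on its output, and `sync` would be mx.synchronize. The warmup count is an assumption.

```python
import statistics
import time
from typing import Callable

def time_p50_us(fn: Callable[[], object], sync: Callable[[], None],
                iters: int = 500, warmup: int = 20) -> float:
    """Median wall-clock time of fn() in microseconds, bracketed by sync()."""
    for _ in range(warmup):  # let compilation and caches settle before timing
        fn()
    sync()
    samples = []
    for _ in range(iters):
        sync()               # ensure no earlier queued work leaks into this sample
        t0 = time.perf_counter()
        fn()
        sync()               # wait for this call's work to finish
        samples.append((time.perf_counter() - t0) * 1e6)
    return statistics.median(samples)  # p50 is robust to occasional outliers
```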

LFM2-8B-A1B-4bit (E=32, K=4, hidden=2048, intermediate=1792):

| seq_len | Baseline | Gating+combine | Fused SwiGLU | Fused vs baseline |
|---|---|---|---|---|
| 1 (decode) | 293 us | 288 us (1.02x) | 282 us | 1.04x |
| 4 | 531 us | 500 us (1.06x) | 498 us | 1.07x |
| 16 | 998 us | 943 us (1.06x) | 949 us | 1.05x |
| 64 | 2291 us | 2480 us (0.92x) | 2452 us | 0.93x |

Qwen3-30B-A3B-4bit (E=128, K=8, hidden=2048, intermediate=1024):

| seq_len | Baseline | Gating+combine | Fused SwiGLU | Fused vs baseline |
|---|---|---|---|---|
| 1 (decode) | 613 us | 621 us (0.99x) | 613 us | 1.00x |
| 4 | 808 us | 724 us (1.12x) | 756 us | 1.07x |
| 16 | 1269 us | 1264 us (1.00x) | 1284 us | 0.99x |
| 64 | 3565 us | 3557 us (1.00x) | 3508 us | 1.02x |

Key finding: fused SwiGLU saves ~4% per MoE layer on LFM2 at decode, but is neutral on Qwen3-30B at the single-layer level. The +5-8% E2E gain on Qwen3-30B comes from compound system effects: 48 MoE layers x one eliminated kernel dispatch is ~240 us saved per token (~2.7% of 8.85 ms/token), plus reduced intermediate memory pressure and fewer graph nodes.
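The dispatch arithmetic in the key finding works out as follows. The ~5 us per-dispatch figure is simply 240 us divided across the 48 layers, not a separate measurement.

```latex
\begin{aligned}
t_{\text{token}} &\approx \frac{1\ \text{s}}{113\ \text{tok}} \approx 8.85\ \text{ms} \\
\Delta t &\approx 48\ \text{layers} \times 5\ \mu\text{s} = 240\ \mu\text{s} \\
\frac{\Delta t}{t_{\text{token}}} &\approx \frac{0.240\ \text{ms}}{8.85\ \text{ms}} \approx 2.7\%
\end{aligned}
```

The remainder of the measured +5-8% is what the text attributes to reduced intermediate memory pressure and fewer graph nodes.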

Kernel-level primitive microbenchmarks

python benchmarks/bench_gather_qmm_swiglu.py — fused vs naive (2x gather_qmm + SwiGLU):

| Config | Naive (2x gather_qmm + SwiGLU) | Fused | Speedup |
|---|---|---|---|
| M=1, K=512, N=512 | 170 us | 133 us | 1.28x |
| M=1, K=2048, N=1024 | 144 us | 129 us | 1.11x |
| M=1, K=2048, N=2048 | 148 us | 134 us | 1.10x |
| M=16, K=2048, N=1024 | 243 us | 229 us | 1.06x |
| M=64, K=2048, N=1024 | 240 us | 534 us | 0.45x |

Fused kernel improves small M (decode) and is slower at large M (prefill). The patch auto-selects: fused for M<=32, two-pass fallback otherwise.
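The selection rule amounts to a guard on the number of rows M. In the sketch below, hypothetical callables stand in for the fused gather_qmm_swiglu path and the two-pass fallback. The threshold of 32 comes from the text.

```python
from typing import Any, Callable, Optional

FUSED_MAX_M = 32  # fused kernel wins at decode / small prefill (M <= 32)

def expert_swiglu(x: Any, m: int,
                  fused: Optional[Callable[[Any], Any]],
                  two_pass: Callable[[Any], Any]) -> Any:
    """Route to the fused primitive when it exists and M is small enough."""
    if fused is not None and m <= FUSED_MAX_M:
        return fused(x)
    return two_pass(x)  # stock MLX, or large M where the fused kernel is slower
```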

Next steps (roadmap)

  • add_rms_norm — fused residual add + RMSNorm in one kernel (benefits all models, ~20 line Metal diff)
  • gather_qmm_combine — fused down projection + weighted expert sum (MoE-specific, eliminates intermediate tensor)
  • Upstream gather_qmm_swiglu to MLX — PR with benchmarks so everyone benefits via pip install mlx
  • Per-device autotune profiles (better defaults by chip family)

When do patches help?

  • MoE Models (4-bit): Best case with fused SwiGLU on a local MLX build. Qwen3-30B gets +5-8% E2E; LFM2 shows +4% per MoE layer at decode (E2E too noisy to report). On stock MLX, gating+combine can be slower — use smart_patch or exclude moe_mlp.
  • MoE Models (8-bit): Fused SwiGLU is less impactful; benchmark before enabling.
  • MoE Models (pre-computed gating): GLM-4, DeepSeek-V3 — neutral. Gate is already @mx.compile-optimized.
  • Dense Models (any size): Neutral. Decode is bandwidth-bound; Qwen3-4B is 1.01x.
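For MoE models on stock MLX, smart_patch measures each pattern and keeps only the ones that help:

from zmlx.patch import smart_patch
import mlx.core as mx

# Auto-benchmark each pattern, keep only what helps
# (model and tokenizer come from mlx_lm.load, as in the MoE example above)
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)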

Or use mode/presets if you know your workload:

from zmlx.patch import patch

patch(model)                    # inference default; use smart_patch on stock MLX for MoE
patch(model, mode="training")   # training: adds norm fusions for backward pass savings

# Or explicit presets for full control:
from zmlx.patch import ALL_PATTERNS, FUSED_ACTIVATIONS, TRAINING_RECOMMENDED
patch(model, patterns=FUSED_ACTIVATIONS)       # same as default
patch(model, patterns=TRAINING_RECOMMENDED)    # same as mode="training"
patch(model, patterns=ALL_PATTERNS)            # WARNING: can be slower on inference

Smart patching

smart_patch applies each candidate pattern, benchmarks the model's forward pass, and automatically reverts patterns that make things slower. It supports custom forward functions for realistic benchmarks:

import mlx_lm
from zmlx.patch import smart_patch

# Basic: benchmark raw forward pass
# (sample is a tokenized prompt, e.g. mx.array([tokenizer.encode("Hello")]))
model = smart_patch(model, sample)

# Advanced: benchmark with actual generation
def gen_fn(model, sample):
    return mlx_lm.generate(model, tokenizer, prompt="Hello", max_tokens=20)

model = smart_patch(model, sample, forward_fn=gen_fn, threshold=0.99)

# Result includes per-pattern speedups
result = model._zmlx_patch_result
print(result.benchmarks)    # {'swiglu_mlp': 1.012, 'residual_norm': 0.971}
print(result.summary())     # what was kept and why
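The apply/benchmark/revert loop can be sketched as a greedy search. Everything here is hypothetical: `apply`, `revert`, and `bench` stand in for ZMLX internals.

The keep rule, baseline_time / patched_time >= threshold, is an assumed reading of the `threshold` parameter. Under it, a pattern at 1.012 is kept and one at 0.971 is reverted when the threshold is 0.99.

```python
from typing import Callable, Dict, List, Tuple

def greedy_smart_patch(patterns: List[str],
                       apply: Callable[[str], None],
                       revert: Callable[[str], None],
                       bench: Callable[[], float],
                       threshold: float = 1.0) -> Tuple[List[str], Dict[str, float]]:
    """Try each pattern in turn; keep it only if it does not slow the model down."""
    kept: List[str] = []
    speedups: Dict[str, float] = {}
    current = bench()  # seconds per forward pass with the patterns kept so far
    for name in patterns:
        apply(name)
        t = bench()
        speedups[name] = current / t  # > 1.0 means faster
        if speedups[name] >= threshold:
            kept.append(name)
            current = t  # new baseline for the next candidate
        else:
            revert(name)
    return kept, speedups
```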

Autotuning

Replacement modules support threadgroup="auto" to search for the best threadgroup size on first invocation:

from zmlx.patch import patch
patch(model, threadgroup="auto")  # autotunes each kernel on first call

The map_reduce() API also supports autotuning:

from zmlx.api import map_reduce
my_softmax = map_reduce(..., threadgroup="auto")  # autotunes per-shape
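A first-call autotuner with a persistent cache can be sketched in three steps:

1. Time each candidate threadgroup size for a given (kernel, shape) key.
2. Keep the fastest.
3. Store the choice in a JSON file so later processes skip the search.

The candidate list, cache-file format, and `run` callback are assumptions for illustration. This is not ZMLX's actual autotuner.

```python
import json
import os
import time
from typing import Callable, Dict, Sequence

def autotune_threadgroup(key: str, run: Callable[[int], None],
                         candidates: Sequence[int] = (32, 64, 128, 256, 512, 1024),
                         cache_path: str = "autotune_cache.json",
                         reps: int = 5) -> int:
    cache: Dict[str, int] = {}
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    if key in cache:  # hit: reuse the earlier choice, no timing needed
        return cache[key]
    best_tg, best_t = candidates[0], float("inf")
    for tg in candidates:
        run(tg)  # warm-up launch (e.g. triggers compilation)
        t0 = time.perf_counter()
        for _ in range(reps):
            run(tg)  # run must block until the GPU work is done
        t = (time.perf_counter() - t0) / reps
        if t < best_t:
            best_tg, best_t = tg, t
    cache[key] = best_tg
    with open(cache_path, "w") as f:
        json.dump(cache, f)
    return best_tg
```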

Where ZMLX genuinely helps

  • MoE model inference — best results with fused expert SwiGLU (gather_qmm_swiglu, local MLX build). Qwen3-30B gets +5-8% E2E; LFM2 shows +4% per layer at decode. On stock MLX, use smart_patch to avoid gating+combine slowdowns. Supports Qwen3-MoE, LFM2, Mixtral, GPT-OSS.
  • Prototyping MLX-level optimizations — ZMLX's optimization lab incubates C++ Metal primitives. Prove value with benchmarks here, then upstream to MLX for everyone.
  • Custom ops that MLX doesn't have — SwiGLU, GeGLU, fused dropout, fused MoE gating, bit packing
  • Training — fused softmax_cross_entropy loss, correct weight gradients for rmsnorm_residual
  • Authoring new kernels: elementwise(), reduce(), map_reduce() APIs turn a math formula into a compiled Metal kernel in one line
  • Quantization — FP8 (E4M3/E5M2), NF4, int8, int4 dequantization with real bit-manipulation kernels

MoE performance notes

  • Expert matmuls are the bottleneck — for M=1 decode, MoE layers do 3 gather_qmm calls per layer (gate, up, down). Fusing gate+up+SwiGLU into one kernel (gather_qmm_swiglu) saves one full read of the input tensor.
  • E2E gains are a compound effect — on Qwen3-30B, 48 MoE layers x one eliminated dispatch saves ~240 us per token (~2.7% of 8.85 ms/token), plus reduced intermediate memory pressure and fewer graph nodes.
  • Fast models need kernel-level timing — LFM2 runs at ~20k tok/s, so E2E variance is 20-40%. The kernel-level MoE layer benchmark is the reliable measurement (+4% per layer at decode).

Where ZMLX won't help

  • Dense model inference — batch-1 decode is dominated by weight reads and bandwidth-bound. patch(model) is generally neutral (Qwen3-4B: 1.01x).
  • Replacing MLX built-in norms/softmax: mx.fast.rms_norm, mx.softmax, mx.fast.scaled_dot_product_attention are Apple-optimized fast paths. Custom kernels add dispatch overhead.

Optimization Lab

ZMLX includes a local MLX fork (mlx_local/) where we prototype fused C++ Metal primitives that need access to MLX's internal quantized matmul infrastructure. These are experimental and require building MLX from source.

What's in the lab

| Primitive | Status | What it does |
|---|---|---|
| gather_qmm_swiglu | Working, benchmarked | Fused gate+up+SwiGLU for MoE experts |
| add_rms_norm | Planned | Fused residual add + RMSNorm |
| gather_qmm_combine | Planned | Fused down projection + weighted expert sum |
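For reference, the planned add_rms_norm computes two things:

- h = x + residual
- RMSNorm(h) = h / sqrt(mean(h²) + eps) · weight

It returns both, so the updated residual stream can be reused. The sketch below shows the single-pass structure a fused kernel can exploit: the sum of squares accumulates while the add is performed. The function name, eps default, and return order are assumptions.

```python
import math
from typing import List, Tuple

def add_rms_norm_ref(x: List[float], residual: List[float], weight: List[float],
                     eps: float = 1e-6) -> Tuple[List[float], List[float]]:
    h = []
    sum_sq = 0.0
    for xi, ri in zip(x, residual):  # one pass: residual add + sum of squares
        hi = xi + ri
        h.append(hi)
        sum_sq += hi * hi
    inv_rms = 1.0 / math.sqrt(sum_sq / len(h) + eps)
    y = [hi * inv_rms * wi for hi, wi in zip(h, weight)]
    return y, h  # normalized output, updated residual stream
```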

Building the local MLX fork

# From the ZMLX repo root. Double quotes make each trailing backslash a line
# continuation; inside single quotes the backslash-newline pairs stay in the value.
CMAKE_ARGS="-DMETAL_CPP_URL=file:///path/to/ZMLX/mlx_local/third_party/metal-cpp_26.zip \
  -DMLX_METAL_MODULE_CACHE=/tmp/clang_module_cache \
  -DFETCHCONTENT_SOURCE_DIR_JSON=/path/to/ZMLX/mlx_local/third_party/json \
  -DFETCHCONTENT_SOURCE_DIR_FMT=/path/to/ZMLX/mlx_local/third_party/fmt \
  -DFETCHCONTENT_SOURCE_DIR_NANOBIND=/path/to/ZMLX/mlx_local/third_party/nanobind \
  -DMLX_BUILD_GGUF=OFF" \
pip install -e mlx_local --no-build-isolation

ZMLX auto-detects the fused primitives at runtime. If you're on stock MLX, patch(model) uses the standard two-pass path — for MoE models, prefer smart_patch to avoid gating+combine slowdowns.
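Runtime detection of an optional primitive typically reduces to a capability check with a fallback. This minimal sketch assumes the fused op appears as an attribute on some module. Where the fork actually exposes gather_qmm_swiglu is not stated here, so the module is a parameter, and `find_primitive`/`pick_impl` are hypothetical names.

```python
from typing import Any, Callable, Optional

def find_primitive(module: Any, name: str) -> Optional[Callable[..., Any]]:
    """Return module.<name> if it exists and is callable, else None."""
    fn = getattr(module, name, None)
    return fn if callable(fn) else None

def pick_impl(module: Any, name: str, fallback: Callable[..., Any]) -> Callable[..., Any]:
    # Prefer the fused primitive when the installed build provides it
    fn = find_primitive(module, name)
    return fn if fn is not None else fallback
```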

The plan

Prototype here, validate with benchmarks, upstream to MLX via PR. Once primitives land in stock MLX, ZMLX auto-detects them and everyone benefits via pip install mlx. ZMLX is the incubator, MLX is the distribution.


Precision

All ZMLX Metal kernels compute internally in float32 regardless of input dtype. The compute_dtype parameter accepted by many kernel functions is deprecated and will be removed in a future release. Passing a non-None value will emit a DeprecationWarning.
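A quick way to see why float32 accumulation matters: the plain-Python sketch below sums 10,000 float16 copies of 0.1 twice. The first run rounds the running total to float16 after every add. The second rounds it to float32. struct's "e" and "f" formats do the rounding.

The float16 total stalls at 256, where half the spacing between representable values exceeds the addend. The float32 total stays close to 1000. This illustrates the motivation only; it is not how ZMLX kernels are written.

```python
import struct

def round_f16(x: float) -> float:
    # Round a Python float to the nearest IEEE half-precision value
    return struct.unpack("<e", struct.pack("<e", x))[0]

def round_f32(x: float) -> float:
    # Round a Python float to the nearest IEEE single-precision value
    return struct.unpack("<f", struct.pack("<f", x))[0]

def accumulate(values, round_acc):
    acc = 0.0
    for v in values:
        acc = round_acc(acc + v)  # running total stored at the accumulator's precision
    return acc

values = [round_f16(0.1)] * 10_000          # float16 inputs (0.1 is not exact in float16)
fp16_total = accumulate(values, round_f16)  # stalls at 256.0
fp32_total = accumulate(values, round_f32)  # about 999.76
```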



Contributing

See CONTRIBUTING.md for setup, testing, and conventions.


License

MIT. See LICENSE.


Download files


Source Distribution

zmlx-0.7.0.tar.gz (191.6 kB)

Uploaded Source

Built Distribution


zmlx-0.7.0-py3-none-any.whl (128.5 kB)

Uploaded Python 3

File details

Details for the file zmlx-0.7.0.tar.gz.

File metadata

  • Download URL: zmlx-0.7.0.tar.gz
  • Upload date:
  • Size: 191.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for zmlx-0.7.0.tar.gz
Algorithm Hash digest
SHA256 3d350c15045fda126c533f1b809310d80a3342439cb5492c57577e1b47abde05
MD5 353982e1078707879b79aedc949dd57c
BLAKE2b-256 6a101a4e3914dce1b331296fb369f1b6f15673beea68fa6e89113686d78ab82e


Provenance

The following attestation bundles were made for zmlx-0.7.0.tar.gz:

Publisher: release.yml on Hmbown/ZMLX

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file zmlx-0.7.0-py3-none-any.whl.

File metadata

  • Download URL: zmlx-0.7.0-py3-none-any.whl
  • Upload date:
  • Size: 128.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for zmlx-0.7.0-py3-none-any.whl
Algorithm Hash digest
SHA256 13a9aaa8b0a929c52e3951bef32238cd598aa50e40abb39d32e9a8c8f85e9d18
MD5 16e943b4ff92d03af1b412f40017f2a8
BLAKE2b-256 707a074fb16da1cc916c99a325f2af0d2511af9d0ff89577e134f8cd0c7d9005


Provenance

The following attestation bundles were made for zmlx-0.7.0-py3-none-any.whl:

Publisher: release.yml on Hmbown/ZMLX

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
