ZMLX: Triton for Apple Silicon. Custom Metal GPU kernels for MLX with Triton-like ergonomics. Killer feature: 1.3-1.6x MoE inference speedup via fused expert routing.
ZMLX — Triton for Apple Silicon
The Triton-like toolkit for MLX — write custom Metal GPU kernels from Python with one-line ergonomics, automatic gradients, and built-in autotuning.
MoE models get a 1.3–1.6x inference speedup with `patch(model)`: fused expert routing eliminates kernel launches that MLX doesn't optimize. Dense models are neutral (safe, no regressions).
Benchmarks
| Model | Prompt speedup | Decode speedup | Status |
|---|---|---|---|
| Qwen3-30B-A3B (MoE) | +61% | +37% | Validated |
| Qwen3-30B-A3B-Instruct (MoE) | +60% | +31% | Validated |
| GLM-4.7-Flash (MoE, pre-computed gating) | +1% | -1% | Neutral: gating already `@mx.compile`-optimized; expanding support |
| Dense models (8B–32B) | neutral | neutral | Safe, no regressions |
```bash
pip install zmlx
```
Speed up MoE inference in 3 lines:
```python
import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/Qwen3-30B-A3B-Instruct-2507-4bit")
patch(model)  # +60% prompt, +31% decode; safe on all models, MoE gets the win
```
Or write custom GPU kernels in one line:
```python
from zmlx.api import elementwise
import mlx.core as mx

# Math formula → compiled Metal kernel → runs on GPU
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
```
What's New in v0.6.1
- Benchmark default fixed: `inference_benchmark.py` now uses `FUSED_ACTIVATIONS` by default (matching `patch(model)` behavior). It previously defaulted to `ALL_PATTERNS`, which gave misleadingly low decode numbers. Use `--all-patterns` to opt in to norms/softmax.
- Qwen3-30B-A3B (base) validated: +61% prompt / +37% decode, confirmed on v0.6.1.
- GLM-4.7-Flash: neutral (gating is already `@mx.compile`-optimized); expanding support is in progress.
Previous highlights (v0.6.0)
- `mode` parameter: `patch(model, mode="training")` for workload-aware preset selection.
- Validated benchmarks: MoE models get 1.3–1.6x; dense models are neutral; `ALL_PATTERNS` causes a 3–5% regression, and docstrings now warn about this explicitly.
Previous highlights (v0.4–0.5)
- MoE patch: fused `top2_gating_softmax` + `moe_combine` eliminates 4+ kernel launches in expert routing
- High-level API: `elementwise()`, `reduce()`, `map_reduce()` for kernel authoring in one line
- JIT compiler: `@jit` decorator compiles Python scalar expressions to Metal
- Smart patching: `smart_patch()` auto-benchmarks each pattern and keeps only what helps
- Training pipeline: `zmlx train` CLI for LoRA fine-tuning with ZMLX patches
- 70+ kernel catalog: activations, norms, RoPE, attention, MoE, quantization, loss, bit ops
Why ZMLX?
When you need a custom GPU op on Apple Silicon, your options today are:
- Write raw Metal source strings, manage caching, figure out threadgroups, wire up autodiff manually
- Use ZMLX
ZMLX wraps `mx.fast.metal_kernel` and `mx.custom_function` to provide Triton-like ergonomics:
- One-line kernel authoring: define elementwise, reduction, and map-reduce ops from C expressions
- Automatic gradients: custom VJP backward passes (themselves Metal kernels) via `mx.custom_function`
- Define-once caching: kernels compile once and are reused, keyed by source hash + config
- Autotuning: threadgroup size search with persistent caching
- Testing & benchmarking: verify against reference ops, compare timings side-by-side
- Model patching: swap MLX layers for fused ZMLX kernels with `patch(model)`
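The define-once caching idea can be illustrated in plain Python. This is a minimal sketch, not ZMLX's implementation. `KernelCache`, `get_or_compile`, and the `compile_fn` callback are hypothetical names. A real cache would store compiled `mx.fast.metal_kernel` objects rather than strings.

```python
import hashlib
import json
from typing import Any, Callable, Dict, Tuple


class KernelCache:
    """Hypothetical in-process cache keyed by (source hash, config)."""

    def __init__(self) -> None:
        self._kernels: Dict[Tuple[str, str], Any] = {}
        self.compiles = 0  # how many times compile_fn actually ran

    @staticmethod
    def key(source: str, config: Dict[str, Any]) -> Tuple[str, str]:
        # Hash the kernel source and serialize the config deterministically,
        # so {"a": 1, "b": 2} and {"b": 2, "a": 1} produce the same key.
        src_hash = hashlib.sha256(source.encode("utf-8")).hexdigest()
        return src_hash, json.dumps(config, sort_keys=True)

    def get_or_compile(
        self,
        source: str,
        config: Dict[str, Any],
        compile_fn: Callable[[str, Dict[str, Any]], Any],
    ) -> Any:
        k = self.key(source, config)
        if k not in self._kernels:
            self._kernels[k] = compile_fn(source, config)  # compile once
            self.compiles += 1
        return self._kernels[k]  # reused on every later call


def fake_compile(source: str, config: Dict[str, Any]) -> str:
    # Stand-in for building a Metal kernel object (hypothetical).
    return f"kernel<tg={config['threadgroup']}>:{len(source)}"


cache = KernelCache()
src = "out[elem] = metal::exp(inp[elem]);"
k1 = cache.get_or_compile(src, {"threadgroup": 256, "dtype": "float32"}, fake_compile)
k2 = cache.get_or_compile(src, {"dtype": "float32", "threadgroup": 256}, fake_compile)
k3 = cache.get_or_compile(src, {"threadgroup": 128, "dtype": "float32"}, fake_compile)
print(cache.compiles)  # 2: same source+config hits the cache, a new threadgroup recompiles
```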
Install
Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0
```bash
pip install zmlx
```
From source (development):
```bash
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
```
Quick Start
1. Custom elementwise kernel
```python
import mlx.core as mx
from zmlx import msl
from zmlx.api import elementwise

# Non-differentiable: forward pass only
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable: custom VJP, d/dx[x * s] = s + x * s * (1 - s) with s = sigmoid(x)
silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
```
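The `grad_expr` above encodes the SiLU derivative d/dx[x·σ(x)] = σ(x) + x·σ(x)(1 - σ(x)), scaled by the incoming cotangent `g`. A cheap way to sanity-check a hand-written VJP expression before turning it into Metal is a finite-difference comparison. The pure-Python sketch below does that with `g = 1`. It is independent of ZMLX and only checks the math.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    return x * sigmoid(x)


def silu_grad(x: float, g: float = 1.0) -> float:
    # Same formula as grad_expr: g * (s + x * s * (1 - s)), with s = sigmoid(x)
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))


def central_diff(f, x: float, h: float = 1e-5) -> float:
    # Second-order accurate numerical derivative
    return (f(x + h) - f(x - h)) / (2.0 * h)


max_err = max(
    abs(silu_grad(x) - central_diff(silu, x))
    for x in [-6.0, -2.5, -1.0, 0.0, 0.3, 1.0, 4.0]
)
print(f"max abs error: {max_err:.2e}")
```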
2. Custom reduction
```python
from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # reduces the last axis: shape (8,)
```
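Conceptually, `reduce(init=..., update=...)` folds the `update` expression over the last axis with one accumulator per row. That is why an `(8, 1024)` input yields shape `(8,)`. The plain-Python model below mirrors those semantics and names `acc` and `v` as in the example. It illustrates the reduction contract only; the helper `row_reduce` is hypothetical and says nothing about how the Metal kernel schedules threads.

```python
import random
from typing import Callable, List


def row_reduce(
    rows: List[List[float]],
    init: float,
    update: Callable[[float, float], float],
) -> List[float]:
    out = []
    for row in rows:  # one output per row -> shape (num_rows,)
        acc = init
        for v in row:  # fold update(acc, v) across the last axis
            acc = update(acc, v)
        out.append(acc)
    return out


random.seed(0)
x = [[random.gauss(0.0, 1.0) for _ in range(1024)] for _ in range(8)]
y = row_reduce(x, init=0.0, update=lambda acc, v: acc + v)  # mirrors "acc + v"
print(len(y))  # 8
```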
3. Two-pass map-reduce (softmax pattern)
```python
from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
```
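The three parts of this `map_reduce` call correspond to the classic numerically stable softmax:

- pass 1 finds the row max `s1`;
- pass 2 sums `exp(x - s1)` into `s2`;
- `write` normalizes.

Each pass needs a `reduce` entry because every thread accumulates a partial result over its slice of the row, and those partials must then be combined. The pure-Python sketch below simulates that split with a hypothetical `num_threads` parameter. It illustrates the math, not ZMLX's generated kernel.

```python
import math
from typing import List


def two_pass_softmax(row: List[float], num_threads: int = 4) -> List[float]:
    if not row:
        return []
    # Split the row into contiguous chunks, one per simulated thread.
    chunk = (len(row) + num_threads - 1) // num_threads
    slices = [row[i:i + chunk] for i in range(0, len(row), chunk)]

    # Pass 1: per-thread update max(acc1, x), then reduce max(a, b) -> s1.
    partial_max = []
    for s in slices:
        acc1 = -math.inf  # init "-INFINITY"
        for x in s:
            acc1 = max(acc1, x)
        partial_max.append(acc1)
    s1 = max(partial_max)

    # Pass 2: per-thread update acc2 + exp(x - s1), then reduce a + b -> s2.
    partial_sum = []
    for s in slices:
        acc2 = 0.0  # init "0.0f"
        for x in s:
            acc2 += math.exp(x - s1)
        partial_sum.append(acc2)
    s2 = sum(partial_sum)

    # Write: exp(x - s1) / s2 for every element.
    return [math.exp(x - s1) / s2 for x in row]


# Large inputs would overflow a naive exp(x) / sum(exp(x)); subtracting s1 avoids that.
probs = two_pass_softmax([1000.0, 1001.0, 1002.0, 999.0, 0.5])
print(sum(probs))
```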
4. Test and benchmark your kernel
```python
import zmlx
import mlx.core as mx

# my_softmax is the map_reduce kernel from step 3

# Verify correctness against the MLX reference
zmlx.testing.assert_matches(
    my_softmax,
    lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

# Benchmark side-by-side
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)
```
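`zmlx.bench.compare` handles timing for you. The sketch below only shows the shape of a fair comparison in plain Python: warm-up runs, repeated timings, and the median. The `compare` helper and its `sync` hook are hypothetical. Because MLX evaluates lazily, a real MLX timing loop must force evaluation inside the timed region, for example with `mx.eval`. Otherwise it measures graph construction rather than the kernel.

```python
import statistics
import time
from typing import Any, Callable, Dict, Optional


def compare(
    impls: Dict[str, Callable[[Any], Any]],
    make_input: Callable[[], Any],
    warmup: int = 3,
    repeats: int = 20,
    sync: Optional[Callable[[Any], None]] = None,
) -> Dict[str, float]:
    """Median wall time in milliseconds per implementation (hypothetical helper)."""
    x = make_input()
    results: Dict[str, float] = {}
    for name, fn in impls.items():
        for _ in range(warmup):  # let compilation/caching happen before timing
            out = fn(x)
            if sync is not None:
                sync(out)
        samples = []
        for _ in range(repeats):
            t0 = time.perf_counter()
            out = fn(x)
            if sync is not None:
                sync(out)  # with MLX this would be mx.eval(out)
            samples.append((time.perf_counter() - t0) * 1e3)
        results[name] = statistics.median(samples)
    return results


def loop_sum(xs):
    total = 0.0
    for v in xs:
        total += v
    return total


timings = compare(
    {"builtin sum": sum, "loop_sum": loop_sum},
    make_input=lambda: [float(i) for i in range(10_000)],
)
for name, ms in timings.items():
    print(f"{name}: {ms:.3f} ms")
```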
5. Lower-level building blocks
```python
from zmlx import autograd, elementwise, msl

# compute_dtype is omitted: it is deprecated (see Precision) and kernels always compute in float32.

# Unary kernel (no gradient)
exp_kern = elementwise.unary(
    name="kk_exp",
    expr="metal::exp(x)",
    header=msl.DEFAULT_HEADER,
)

# Binary kernel with custom VJP: d(a*b)/da = b, d(a*b)/db = a
mul_op = autograd.binary_from_expr(
    name="safe_mul",
    fwd_expr="a * b",
    vjp_lhs_expr="g * b",
    vjp_rhs_expr="g * a",
)
```
Kernel Catalog
ZMLX includes 70+ kernels organized by domain. Some are genuinely useful for custom workloads (loss, GLU fusions, bit ops, MoE gating). Others are reference implementations showing codegen patterns — correct but not faster than MLX built-ins for standard transformer shapes.
Full reference: docs/KERNELS.md.
| Module | Highlights |
|---|---|
| `loss` | `softmax_cross_entropy`: memory-efficient fused loss |
| `transformer` | `swiglu`, `geglu`, `rmsnorm_residual` (with full weight gradients), `dropout`: genuine fusions |
| `bits` | `pack_bits`, `unpack_bits`: no MLX equivalent |
| `moe` | `top2_gating_softmax`, `moe_dispatch`, `moe_combine`: fused expert routing (+36% decode on 30B MoE) |
| `quant` | FP8 (E4M3/E5M2), NF4, int8, int4 dequantization: real bit-manipulation kernels |
| `optimizers` | `adamw_step`: fused AdamW parameter update in a single kernel |
| `scan` | `cumsum_lastdim`: differentiable prefix sum |
| `norms` | `rmsnorm`, `layernorm`: parallel reduction; all norms compute in float32 internally |
| `softmax` | `softmax_lastdim`: map-reduce codegen showcase |
| `rope` | `apply_rope`, `apply_rope_interleaved`, `apply_gqa_rope` |
| `linear` | Reference fused-linear patterns (naive matmul, not for production) |
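Here is an example of the kind of bit manipulation behind the `quant` kernels: a pure-Python decoder for one FP8 E4M3 byte. It assumes the common OCP "E4M3FN" layout:

- 1 sign bit, 4 exponent bits with bias 7, and 3 mantissa bits;
- no infinities;
- `S.1111.111` is the only NaN pattern.

ZMLX's Metal kernels may differ in details such as NaN handling or block scaling. The function name `decode_e4m3` is hypothetical.

```python
import math


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 (OCP 'FN' variant) byte to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF  # 4 exponent bits
    man = byte & 0x7         # 3 mantissa bits
    if exp == 0xF and man == 0x7:
        return math.nan      # the only NaN encoding; E4M3FN has no infinities
    if exp == 0:
        # Subnormal: no implicit leading 1, fixed exponent 1 - bias = -6
        return sign * (man / 8.0) * 2.0 ** -6
    # Normal: implicit leading 1, exponent bias 7
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


print(decode_e4m3(0x38))  # 1.0
print(decode_e4m3(0x7E))  # 448.0, the largest finite E4M3 value
print(decode_e4m3(0xB8))  # -1.0
```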
Architecture
Three-layer design. Full details: docs/ARCHITECTURE.md.
1. Metal kernel infrastructure: `MetalKernel` wrapper, in-process cache, stats tracking
2. Code generation & helpers: MSL templates, elementwise/autograd/rowwise APIs, autotuning
3. Kernel catalog: domain modules built on layers 1 and 2
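To make the code-generation layer concrete: an elementwise API ultimately has to turn a C expression into a Metal kernel body for `mx.fast.metal_kernel`. The sketch below shows one plausible template expansion as plain string generation. The template, the `inp`/`out` buffer names, the `T` type parameter and the helper `render_elementwise_source` are all assumptions, loosely modeled on the `mx.fast.metal_kernel` documentation examples. They are not ZMLX's actual generated source.

```python
from typing import List


def render_elementwise_source(expr: str, prelude: str = "") -> str:
    """Hypothetical MSL kernel body for a unary elementwise op.

    Assumes one thread per element, an input buffer `inp`, an output
    buffer `out`, and a template type parameter `T`.
    """
    lines: List[str] = [
        "uint elem = thread_position_in_grid.x;",
        "T x = inp[elem];",
    ]
    if prelude:
        lines.append(prelude)  # helper values used by expr, e.g. "T s = kk_sigmoid(x);"
    lines.append(f"out[elem] = {expr};")
    return "\n".join(lines)


print(render_elementwise_source("x * tanh(log(1 + exp(x)))"))
```

A string like this would then be passed as `source=` to `mx.fast.metal_kernel(name=..., input_names=["inp"], output_names=["out"], source=...)`. Its hash, together with the launch config, is a natural cache key.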
Benchmarks
Op-level (B=16, S=1024, D=1024, float32, M4 Max)
Run python benchmarks/microbench.py to reproduce on your hardware.
| Operation | MLX | ZMLX | Speedup |
|---|---|---|---|
| SwiGLU | 0.85 ms | 0.40 ms | 2.1x |
| Dropout | 3.12 ms | 0.38 ms | 8.2x |
| Top-K | 1.82 ms | 0.49 ms | 3.7x |
| Gather-Add | 0.54 ms | 0.41 ms | 1.3x |
| Softmax | 0.36 ms | 0.41 ms | 0.90x |
| RMSNorm | 0.37 ms | 0.41 ms | 0.90x |
| Sum | 0.19 ms | 0.36 ms | 0.53x |
| CumSum | 0.30 ms | 0.59 ms | 0.51x |
ZMLX wins big on fused operations that MLX doesn't provide as single ops (SwiGLU, fused-RNG dropout, fused gather-add). MLX's built-in operations (mx.fast.rms_norm, mx.softmax, reductions) are already highly optimized and should not be replaced.
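A rough memory-traffic count shows why fusion pays off for SwiGLU at the benchmark shape above. The model makes three assumptions:

- the unfused path runs `silu(gate)` and the multiply as two separate kernels;
- each of those kernels reads its inputs from device memory and writes its output back;
- the fused kernel reads `gate` and `up` once and writes one output.

Real MLX graphs may launch more or fewer kernels, so treat the result as an estimate, not a measurement.

```python
B, S, D = 16, 1024, 1024   # shape from the op-level benchmark
BYTES_PER_ELEM = 4         # float32
tensor_bytes = B * S * D * BYTES_PER_ELEM  # one (B, S, D) tensor = 64 MiB

# Unfused (assumed 2 kernels):
#   k1: h = silu(gate)  -> read gate, write h
#   k2: y = h * up      -> read h, read up, write y
unfused = (1 + 1) * tensor_bytes + (2 + 1) * tensor_bytes

# Fused: y = silu(gate) * up in one kernel -> read gate, read up, write y
fused = (2 + 1) * tensor_bytes

print(f"one tensor: {tensor_bytes / 2**20:.0f} MiB")
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
print(f"traffic ratio: {unfused / fused:.2f}x")
```

This simple model predicts about 1.67x less traffic. The measured 2.1x is larger, which suggests the unfused MLX path does extra passes or pays launch overhead that this two-kernel model does not count.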
Model-level inference
All baselines are unmodified mlx_lm (mlx_lm.load() + mlx_lm.generate()) — the standard MLX inference stack. ZMLX rows add patch(model) (default: FUSED_ACTIVATIONS) or explicit ALL_PATTERNS on top of that same pipeline. Same model weights, same quantization, same prompt.
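LLM inference is memory-bandwidth-bound, so fused kernels help only where they remove memory round-trips or kernel launches that MLX's built-in fast paths don't already cover. In the results below, that means MoE expert routing. Dense models are neutral at every size tested (1B to 32B), while the softmax-gated MoE models see 1.3–1.6x speedups.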
Qwen3-32B-4bit (dense, 64 layers, ~19 GB), M4 Max, 36 GB

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (`mlx_lm`) | 149 | 18.0 | (baseline) |
| `patch(model)` (default) | 148 | 17.7 | 0.99x / 0.99x |
| `patch(model, patterns=ALL_PATTERNS)` | 147 | 17.2 | 0.99x / 0.97x |
Dense batch-1 decode spends ~95% of its time reading weights through quantized matmuls, which already run at the hardware bandwidth limit. Custom kernels can't beat MLX's built-in `mx.fast.rms_norm`, `mx.fast.rope`, and `mx.fast.scaled_dot_product_attention`. ZMLX shines on MoE models, where fused expert routing eliminates kernel launches that MLX has no fast path for.
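The bandwidth argument can be checked with back-of-the-envelope arithmetic on numbers from these tables. At batch 1, each decoded token must stream roughly all the weights it touches. Tokens/s × bytes read per token therefore approximates the effective read bandwidth. The sketch assumes about 0.5 bytes per parameter for 4-bit weights and ignores quantization scales, embeddings, and the KV cache, so the absolute numbers are approximate.

```python
GB = 1e9

# Dense Qwen3-32B-4bit: ~19 GB of weights read per token, 18.0 tok/s baseline decode.
dense_bytes_per_token = 19 * GB
dense_bw = dense_bytes_per_token * 18.0 / GB  # effective GB/s

# MoE Qwen3-30B-A3B-4bit: only ~3B active parameters per token.
# Assumption: ~0.5 bytes/param at 4-bit, ignoring scales and embeddings.
moe_bytes_per_token = 3e9 * 0.5
moe_bw_baseline = moe_bytes_per_token * 113 / GB  # 113 tok/s baseline
moe_bw_patched = moe_bytes_per_token * 154 / GB   # 154 tok/s with patch(model)

print(f"dense effective read bandwidth: ~{dense_bw:.0f} GB/s")
print(f"MoE baseline: ~{moe_bw_baseline:.0f} GB/s, patched: ~{moe_bw_patched:.0f} GB/s")
```

Under these assumptions, the dense model sustains roughly twice the effective bandwidth of the MoE baseline. That is consistent with dense decode being bandwidth-limited while MoE decode leaves headroom that routing overhead consumes.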
Qwen3-30B-A3B-4bit (MoE, 48 layers, 3B active / 30B total), M4 Max, 36 GB

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (`mlx_lm`) | 1,084 | 113 | (baseline) |
| `patch(model)` (default) | 1,749 | 154 | 1.61x / 1.37x |
Qwen3-30B-A3B-Instruct-2507-4bit (MoE), M4 Max, 36 GB

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (`mlx_lm`) | 1,093 | 106 | (baseline) |
| `patch(model)` (default) | 1,754 | 138 | 1.60x / 1.31x |
Fused gating (`top2_gating_softmax`) and combine (`moe_combine`) kernels eliminate multiple memory round-trips in the expert routing path. No regressions.
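For reference, the computation that a fused `top2_gating_softmax` + `moe_combine` pair replaces fits in a few lines. The pure-Python sketch below assumes Qwen3-MoE / Mixtral-style routing:

1. softmax over the raw router logits;
2. keep the top 2 experts;
3. renormalize their weights to sum to 1;
4. take the weighted sum of those experts' outputs.

The function names mirror the kernel names, but the bodies are hypothetical. ZMLX's kernels may differ in details such as whether the top-k weights are renormalized.

```python
import math
from typing import List, Sequence, Tuple


def top2_gating_softmax(logits: Sequence[float]) -> Tuple[List[int], List[float]]:
    """Softmax over raw router logits, keep the top-2 experts, renormalize."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:2]
    w = [probs[i] for i in top]
    norm = sum(w)
    return top, [wi / norm for wi in w]


def moe_combine(
    expert_outputs: Sequence[Sequence[float]], idx: List[int], weights: List[float]
) -> List[float]:
    """Weighted sum of the selected experts' output vectors."""
    out = [0.0] * len(expert_outputs[0])
    for i, w in zip(idx, weights):
        for d, val in enumerate(expert_outputs[i]):
            out[d] += w * val
    return out


# One token, 4 experts, hidden size 3 (toy numbers).
router_logits = [0.1, 2.0, -1.0, 1.0]
experts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
idx, weights = top2_gating_softmax(router_logits)
out = moe_combine(experts, idx, weights)
print(idx, [round(w, 4) for w in weights], out)
```

This also shows why the fused path needs the gate to return raw logits. The softmax, top-k selection, and renormalization all start from them, so a gate that already returns pre-computed indices leaves little for a fused gating kernel to do.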
Qwen3-8B-4bit (32 layers, ~5 GB), reproduce with `python benchmarks/inference_benchmark.py --models qwen3-8b --selective`

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (`mlx_lm`) | 676 | 75 | (baseline) |
| `patch(model)` (default) | 675 | 76 | 1.00x / 1.01x |
Llama-3.2-1B-Instruct-4bit (16 layers, ~0.8 GB), reproduce with `python benchmarks/llama_benchmark.py`

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (`mlx_lm`) | 3,913 | 377 | (baseline) |
| `patch(model)` (default) | 3,804 | 378 | 0.97x / 1.00x |
| `patch(model, patterns=ALL_PATTERNS)` | 3,705 | 366 | 0.95x / 0.97x |
GLM-4.7-Flash-4bit (MoE, 47 layers, 30B-A3B, sigmoid+group gating), M4 Max, 36 GB

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (`mlx_lm`) | 662 | 74.1 | (baseline) |
| `patch(model)` (default) | 672 | 73.6 | 1.01x / 0.99x |
GLM-4 uses sigmoid + group-based gating with `@mx.compile`, so the gate is already optimized. ZMLX preserves the original gating and fuses only the combine step. That is not enough to move the needle: the result is neutral, with no regression. Expanding GLM-4.7 support is in progress.
When do patches help?
- MoE models (softmax-gated): Qwen3-MoE, Mixtral. Fused gating provides 1.3–1.6x. The gate must return raw logits (not pre-computed indices).
- MoE models (pre-computed gating): GLM-4, DeepSeek-V3. Neutral; the gate is already `@mx.compile`-optimized.
- Large dense models (32B+): neutral. MLX built-ins are already at the bandwidth limit.
- Medium dense models (8B): neutral-to-positive, no regressions.
- Small models (< 3B): neutral. Use `smart_patch` to be sure.
```python
from zmlx.patch import smart_patch
import mlx.core as mx

# Assumes `model` and `tokenizer` from mlx_lm.load(...)
# Auto-benchmark each pattern, keep only what helps
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
```
Or use mode/presets if you know your workload:
```python
from zmlx.patch import patch

# Pick ONE of the following calls for a freshly loaded model:
patch(model)                   # inference (safe default: MoE gets 1.3–1.6x, dense neutral)
patch(model, mode="training")  # training: adds norm fusions for backward-pass savings

# Or explicit presets for full control:
from zmlx.patch import ALL_PATTERNS, FUSED_ACTIVATIONS, TRAINING_RECOMMENDED

patch(model, patterns=FUSED_ACTIVATIONS)     # same as default
patch(model, patterns=TRAINING_RECOMMENDED)  # same as mode="training"
patch(model, patterns=ALL_PATTERNS)          # WARNING: regresses 3–5% on inference
```
Smart patching
smart_patch applies each candidate pattern, benchmarks the model's forward pass, and automatically reverts patterns that make things slower. It supports custom forward functions for realistic benchmarks:
```python
import mlx_lm
from zmlx.patch import smart_patch

# Assumes `model`, `tokenizer`, and `sample` from the previous example.
# The two calls below are alternatives; use one.

# Basic: benchmark the raw forward pass
model = smart_patch(model, sample)

# Advanced: benchmark with actual generation
def gen_fn(model, sample):
    return mlx_lm.generate(model, tokenizer, prompt="Hello", max_tokens=20)

model = smart_patch(model, sample, forward_fn=gen_fn, threshold=0.99)

# Result includes per-pattern speedups
result = model._zmlx_patch_result
print(result.benchmarks)  # e.g. {'swiglu_mlp': 1.012, 'residual_norm': 0.971}
print(result.summary())   # what was kept and why
```
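The keep-or-revert behavior described for `smart_patch` can be modeled as a greedy loop. The sketch below is hypothetical in several ways:

- `measure` stands in for timing the model's forward pass;
- patterns are plain strings, and `"hypothetical_moe"` is an invented pattern name;
- `min_speedup` is a made-up parameter meaning the minimum ratio of current time to patched time needed to keep a pattern.

How this relates to `smart_patch`'s own `threshold` argument is not specified here; check its docstring.

```python
from typing import Callable, Dict, FrozenSet, List, Tuple


def greedy_select(
    candidates: List[str],
    measure: Callable[[FrozenSet[str]], float],
    min_speedup: float = 1.0,
) -> Tuple[List[str], Dict[str, float]]:
    """Try each candidate on top of the kept set; revert it if it is too slow."""
    kept: List[str] = []
    speedups: Dict[str, float] = {}
    current = measure(frozenset())  # unpatched baseline time
    for pattern in candidates:
        trial = measure(frozenset(kept + [pattern]))
        speedups[pattern] = current / trial  # > 1.0 means faster
        if speedups[pattern] >= min_speedup:
            kept.append(pattern)
            current = trial  # later candidates are judged against this
        # else: revert by simply not keeping the pattern
    return kept, speedups


# Toy cost model with made-up per-pattern time multipliers.
EFFECT = {"swiglu_mlp": 0.98, "residual_norm": 1.03, "hypothetical_moe": 0.70}


def fake_measure(active: FrozenSet[str]) -> float:
    t = 10.0  # baseline milliseconds per forward pass
    for p in sorted(active):
        t *= EFFECT[p]
    return t


kept, speedups = greedy_select(list(EFFECT), fake_measure)
print(kept)
print({k: round(v, 3) for k, v in speedups.items()})
```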
Autotuning
Replacement modules support threadgroup="auto" to search for the best threadgroup size on first invocation:
```python
from zmlx.patch import patch

patch(model, threadgroup="auto")  # autotunes each kernel on first call
```
The map_reduce() API also supports autotuning:
```python
from zmlx.api import map_reduce

my_softmax = map_reduce(..., threadgroup="auto")  # autotunes per shape
```
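Threadgroup autotuning with persistent caching comes down to three steps:

1. time each candidate size once per (kernel, shape);
2. record the winner on disk;
3. skip the search on later runs.

The sketch below models this in plain Python. The JSON cache file, the candidate list, and the `time_kernel` callback are assumptions, not ZMLX's on-disk format. `time_kernel` stands in for launching the kernel with a given threadgroup size and measuring it.

```python
import json
import os
import tempfile
from typing import Callable, Dict, Sequence, Tuple

CANDIDATES = (32, 64, 128, 256, 512, 1024)  # assumed candidate threadgroup sizes


def autotune(
    kernel_name: str,
    shape: Tuple[int, ...],
    time_kernel: Callable[[int], float],
    cache_path: str,
    candidates: Sequence[int] = CANDIDATES,
) -> int:
    """Return the fastest threadgroup size, persisting the choice to JSON."""
    key = f"{kernel_name}:{'x'.join(map(str, shape))}"
    cache: Dict[str, int] = {}
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    if key in cache:
        return cache[key]  # persistent hit: skip the search entirely

    # Search: time each candidate once and keep the fastest.
    best = min(candidates, key=time_kernel)

    cache[key] = best
    with open(cache_path, "w") as f:
        json.dump(cache, f)
    return best


calls = []


def fake_time(tg: int) -> float:
    """Stand-in for timing a real kernel launch; pretends 256 is fastest."""
    calls.append(tg)
    return abs(tg - 256) + 1.0


path = os.path.join(tempfile.mkdtemp(), "autotune.json")
first = autotune("softmax", (8, 1024), fake_time, path)
second = autotune("softmax", (8, 1024), fake_time, path)  # served from the cache file
print(first, second, len(calls))  # 256 256 6
```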
Where ZMLX genuinely helps
- MoE model inference (softmax-gated): 1.3–1.6x speedup on MoE models whose gate returns raw logits with softmax top-k routing (Qwen3-MoE, Mixtral). Fused gating eliminates 4+ kernel launches per MoE layer. Models with pre-computed gating (GLM-4, DeepSeek-V3 with `@mx.compile`) are neutral; the pattern detects this and fuses only the combine step.
- Custom ops that MLX doesn't have: SwiGLU, GeGLU, fused dropout, fused MoE gating, bit packing
- Training: fused `softmax_cross_entropy` loss, correct weight gradients for `rmsnorm_residual`
- Authoring new kernels: the `elementwise()`, `reduce()`, and `map_reduce()` APIs let you go from math formula to compiled Metal kernel in one line
- Quantization: FP8 (E4M3/E5M2), NF4, int8, int4 dequantization with real bit-manipulation kernels
Where ZMLX won't help
- Dense model inference: batch-1 decode is ~95% weight-reading through quantized matmuls, already at the hardware bandwidth limit. `patch(model)` is safe (neutral) but won't speed things up.
- Replacing MLX built-in norms/softmax: `mx.fast.rms_norm`, `mx.softmax`, and `mx.fast.scaled_dot_product_attention` are Apple-optimized fast paths. Custom kernels only add dispatch overhead.
Precision
All ZMLX Metal kernels compute internally in float32 regardless of input dtype. The compute_dtype parameter accepted by many kernel functions is deprecated and will be removed in a future release. Passing a non-None value will emit a DeprecationWarning.
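Internal float32 compute matters most for reductions over half-precision inputs. The pure-Python sketch below accumulates 4096 copies of the float16 value nearest 0.1 in two ways:

- rounding the accumulator to float16 after every add, using `struct`'s `'e'` format;
- rounding it to float32 instead, using `'f'`.

It illustrates the general numerical effect, not any specific ZMLX kernel.

```python
import struct


def to_f16(x: float) -> float:
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x: float) -> float:
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(values, round_fn):
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)  # accumulator stored at the given precision
    return acc


row = [to_f16(0.1)] * 4096  # inputs stored as float16
exact = sum(row)
f16_sum = accumulate(row, to_f16)  # stalls once 0.1 is below half an ulp
f32_sum = accumulate(row, to_f32)
print(f"exact={exact:.4f}  float16 acc={f16_sum:.4f}  float32 acc={f32_sum:.4f}")
```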
Documentation
- `docs/QUICKSTART.md`: 5-minute tutorial
- `docs/COOKBOOK.md`: recipes for common patterns
- `docs/KERNELS.md`: complete kernel catalog reference
- `docs/ARCHITECTURE.md`: design philosophy
Contributing
See CONTRIBUTING.md for setup, testing, and conventions.
License
MIT. See LICENSE.
Project details
Download files
Source Distribution
Built Distribution
File details
Details for the file zmlx-0.6.1.tar.gz.
File metadata
- Download URL: zmlx-0.6.1.tar.gz
- Upload date:
- Size: 147.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | c6eae4f0fdf91df6f4e30c40d2f1a0c972e1c22507bb73a76a6428ea1592a883 |
| MD5 | 383fddc07c58ac89953aad5b613f5e8d |
| BLAKE2b-256 | 3f55801feb6adac1397bdd447b091cdb1ef1c18220c6095864e938568871eb98 |

Provenance
The following attestation bundles were made for zmlx-0.6.1.tar.gz:
Publisher: release.yml on Hmbown/ZMLX
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: zmlx-0.6.1.tar.gz
- Subject digest: c6eae4f0fdf91df6f4e30c40d2f1a0c972e1c22507bb73a76a6428ea1592a883
- Sigstore transparency entry: 872465157
- Sigstore integration time:
- Permalink: Hmbown/ZMLX@518c664a6453b949164d3d0c92821a4dbfa29a98
- Branch / Tag: refs/tags/v0.6.1
- Owner: https://github.com/Hmbown
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@518c664a6453b949164d3d0c92821a4dbfa29a98
- Trigger Event: release
File details
Details for the file zmlx-0.6.1-py3-none-any.whl.
File metadata
- Download URL: zmlx-0.6.1-py3-none-any.whl
- Upload date:
- Size: 114.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 789e8a819d53ef487e9e575e7a9312ed1fce3f4b2ee59695b09f39a27fcf18c3 |
| MD5 | c3796f73336aaf24de12a8b483f109db |
| BLAKE2b-256 | b1496762d23d84bd922330bedfa353a335d3b13bcfbc03cee5a8b9f2025183dc |

Provenance
The following attestation bundles were made for zmlx-0.6.1-py3-none-any.whl:
Publisher: release.yml on Hmbown/ZMLX
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: zmlx-0.6.1-py3-none-any.whl
- Subject digest: 789e8a819d53ef487e9e575e7a9312ed1fce3f4b2ee59695b09f39a27fcf18c3
- Sigstore transparency entry: 872465160
- Sigstore integration time:
- Permalink: Hmbown/ZMLX@518c664a6453b949164d3d0c92821a4dbfa29a98
- Branch / Tag: refs/tags/v0.6.1
- Owner: https://github.com/Hmbown
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@518c664a6453b949164d3d0c92821a4dbfa29a98
- Trigger Event: release