ZMLX: Triton for Apple Silicon. Write custom Metal GPU kernels for MLX with one-line ergonomics, automatic gradients, and built-in autotuning.
ZMLX — Triton for Apple Silicon
The Triton-like toolkit for MLX — write custom Metal GPU kernels from Python with one-line ergonomics, automatic gradients, and built-in autotuning. No raw Metal, no manual threadgroups, no boilerplate.
+33% decode on Qwen3-32B (dense) and +51% prompt / +36% decode on Qwen3-30B-A3B (MoE) with fused kernel patches (see Benchmarks).
pip install zmlx
Speed up your model in 3 lines:
import mlx_lm
from zmlx.patch import patch
model, tokenizer = mlx_lm.load("mlx-community/Qwen3-30B-A3B-4bit")
patch(model) # +51% prompt, +36% decode — done
Or write custom GPU kernels in one line:
from zmlx.api import elementwise
import mlx.core as mx
# Math formula → compiled Metal kernel → runs on GPU
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
What's New in v0.4.0
- 1.33x decode on 32B dense, 1.51x prompt / 1.36x decode on 30B MoE: fused residual+RMSNorm and fused MoE gating patches (benchmarks)
- MoE patch: fused top2_gating_softmax + moe_combine eliminates multiple memory round-trips in expert routing
- High-level API: elementwise(), reduce(), map_reduce() for kernel authoring in one line
- JIT compiler: @jit decorator compiles Python scalar expressions to Metal
- Testing & benchmarking: zmlx.testing.assert_matches(), zmlx.bench.compare() for correctness verification and side-by-side timing
- Profiling: zmlx.profile.time_kernel(), dump_msl(), kernel_stats() for introspection
- Training pipeline: zmlx train CLI for LoRA fine-tuning with ZMLX patches
- Smart patching: smart_patch() auto-benchmarks each pattern and keeps only what helps
- Fused AdamW: single-kernel optimizer step reducing memory bandwidth
- Paged attention: zmlx.nn.PagedAttention for high-throughput serving
- 70+ kernel catalog: activations, norms, RoPE, attention, MoE, quantization, loss, bit ops
Why ZMLX?
When you need a custom GPU op on Apple Silicon, your options today are:
- Write raw Metal source strings, manage caching, figure out threadgroups, wire up autodiff manually
- Use ZMLX
ZMLX wraps mx.fast.metal_kernel and mx.custom_function to provide Triton-like ergonomics:
- One-line kernel authoring: define elementwise, reduction, and map-reduce ops from C expressions
- Automatic gradients: custom VJP backward passes (themselves Metal kernels) via mx.custom_function
- Define-once caching: kernels compile once, reused by source hash + config
- Autotuning: threadgroup size search with persistent caching
- Testing & benchmarking: verify against reference ops, compare timings side-by-side
- Model patching: swap MLX layers for fused ZMLX kernels with patch(model)
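The "compile once, reuse by source hash + config" idea from the list above can be pictured with a small stand-alone sketch. This is not ZMLX's implementation: `kernel_cache_key`, `get_or_compile`, the config fields, and the stub compiler are all hypothetical names chosen for illustration.

```python
import hashlib
import json

_CACHE = {}  # hypothetical in-process cache: key -> compiled kernel object


def kernel_cache_key(source: str, config: dict) -> str:
    """Derive a stable key from the kernel source text plus its launch config."""
    payload = json.dumps({"src": source, "cfg": config}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def get_or_compile(source: str, config: dict, compile_fn):
    """Return the cached kernel, calling compile_fn only on a cache miss."""
    key = kernel_cache_key(source, config)
    if key not in _CACHE:
        _CACHE[key] = compile_fn(source, config)
    return _CACHE[key]


compile_calls = []


def fake_compile(source, config):
    # Stand-in for a real compile step; records how often it runs.
    compile_calls.append(source)
    return ("compiled", source, tuple(sorted(config.items())))


k1 = get_or_compile("x * x", {"threadgroup": 256}, fake_compile)
k2 = get_or_compile("x * x", {"threadgroup": 256}, fake_compile)  # cache hit
k3 = get_or_compile("x * x", {"threadgroup": 128}, fake_compile)  # new config
print(len(compile_calls))
```

Keying on both source and config means the same expression compiled with a different launch configuration gets its own entry rather than silently reusing a mismatched kernel.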
Install
Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0
pip install zmlx
From source (development):
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
Quick Start
1. Custom elementwise kernel
from zmlx.api import elementwise
import mlx.core as mx
# Non-differentiable — just forward pass
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))
# Differentiable — with custom VJP
from zmlx import msl
silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
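When writing a grad_expr by hand, it can help to check the formula on the CPU before compiling anything. The sketch below mirrors the SiLU gradient expression above in plain Python and compares it with a central finite difference. It does not use ZMLX or MLX; the helper names are illustrative only.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    return x * sigmoid(x)


def silu_grad(x: float, g: float = 1.0) -> float:
    """Mirror of grad_expr: g * (s + x * s * (1 - s)), with s = sigmoid(x)."""
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))


def central_difference(f, x: float, eps: float = 1e-5) -> float:
    """Numerical derivative used as an independent reference."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


max_err = max(
    abs(silu_grad(x) - central_difference(silu, x))
    for x in (-4.0, -1.0, 0.0, 0.5, 3.0)
)
print(f"max |analytic - numeric| = {max_err:.2e}")
```

The analytic form follows from the product rule: d/dx [x * s(x)] = s + x * s * (1 - s), since the derivative of the sigmoid is s * (1 - s).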
2. Custom reduction
from zmlx.api import reduce
import mlx.core as mx
my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024))) # shape (8,)
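As a reading aid, the init and update strings above can be thought of as a fold over the last dimension of each row. The plain-Python sketch below shows that interpretation. It is an assumption about the semantics based on the example, not ZMLX code, and a GPU kernel would presumably combine partial results in parallel rather than serially.

```python
from functools import reduce as fold


def row_reduce(rows, init, update):
    """Apply acc = update(acc, v) across each row, starting from init."""
    return [fold(update, row, init) for row in rows]


rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

# Equivalent of init="0.0f", update="acc + v"
row_sums = row_reduce(rows, init=0.0, update=lambda acc, v: acc + v)

# The same pattern with a different init/update pair gives a row-wise max
row_maxes = row_reduce(rows, init=float("-inf"), update=lambda acc, v: max(acc, v))

print(row_sums, row_maxes)
```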
3. Two-pass map-reduce (softmax pattern)
from zmlx.api import map_reduce
import mlx.core as mx
my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
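The two passes above correspond to the usual numerically stable softmax: first a row maximum (s1), then a sum of shifted exponentials (s2), then the final write. A serial plain-Python sketch of that pattern, with variable names borrowed from the expressions above, might look like this. It only illustrates the arithmetic, not how the Metal kernel is scheduled.

```python
import math


def softmax_two_pass(row):
    # Pass 1: init = -inf, update = max(acc1, x)  ->  s1 (row max)
    s1 = float("-inf")
    for x in row:
        s1 = max(s1, x)

    # Pass 2: init = 0, update = acc2 + exp(x - s1)  ->  s2 (normalizer)
    s2 = 0.0
    for x in row:
        s2 = s2 + math.exp(x - s1)

    # Write: exp(x - s1) / s2
    return [math.exp(x - s1) / s2 for x in row]


# Subtracting the max keeps exp() from overflowing on large inputs.
probs = softmax_two_pass([1.0, 2.0, 3.0, 1000.0])
print(sum(probs))
```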
4. Test and benchmark your kernel
import zmlx
import mlx.core as mx
# Verify correctness
zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)
# Benchmark
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)
5. Lower-level building blocks
from zmlx import autograd, elementwise, msl
import mlx.core as mx
# Unary kernel (no gradient)
exp_kern = elementwise.unary(
    name="kk_exp", expr="metal::exp(x)",
    compute_dtype=mx.float32, header=msl.DEFAULT_HEADER,
)
# Binary kernel with custom VJP
mul_op = autograd.binary_from_expr(
    name="safe_mul", fwd_expr="a * b",
    vjp_lhs_expr="g * b", vjp_rhs_expr="g * a",
    compute_dtype=mx.float32,
)
Kernel Catalog
ZMLX includes 70+ kernels organized by domain. Some are genuinely useful for custom workloads (loss, GLU fusions, bit ops, MoE gating). Others are reference implementations showing codegen patterns — correct but not faster than MLX built-ins for standard transformer shapes.
Full reference: docs/KERNELS.md.
| Module | Highlights |
|---|---|
| loss | softmax_cross_entropy: memory-efficient fused loss |
| transformer | swiglu, geglu, rmsnorm_residual (with full weight gradients), dropout: genuine fusions |
| bits | pack_bits, unpack_bits: no MLX equivalent |
| moe | top2_gating_softmax, moe_dispatch, moe_combine: fused expert routing (+36% decode on 30B MoE) |
| quant | FP8 (E4M3/E5M2), NF4, int8, int4 dequantization: real bit-manipulation kernels |
| optimizers | adamw_step: fused AdamW parameter update in a single kernel |
| scan | cumsum_lastdim: differentiable prefix sum |
| norms | rmsnorm, layernorm: parallel reduction. All norms compute in float32 internally |
| softmax | softmax_lastdim: map-reduce codegen showcase |
| rope | apply_rope, apply_rope_interleaved, apply_gqa_rope |
| linear | Reference fused-linear patterns (naive matmul, not for production) |
Architecture
Three-layer design. Full details: docs/ARCHITECTURE.md.
- Metal kernel infrastructure: MetalKernel wrapper, in-process cache, stats tracking
- Code generation & helpers: MSL templates, elementwise/autograd/rowwise APIs, autotuning
- Kernel catalog: domain modules built on layers 1 and 2
Benchmarks
Op-level (B=16, S=1024, D=1024, float32, M4 Max)
Run python benchmarks/microbench.py to reproduce on your hardware.
| Operation | MLX | ZMLX | Speedup |
|---|---|---|---|
| SwiGLU | 0.85 ms | 0.40 ms | 2.1x |
| Dropout | 3.12 ms | 0.38 ms | 8.2x |
| Top-K | 1.82 ms | 0.49 ms | 3.7x |
| Gather-Add | 0.54 ms | 0.41 ms | 1.3x |
| Softmax | 0.36 ms | 0.41 ms | 0.90x |
| RMSNorm | 0.37 ms | 0.41 ms | 0.90x |
| Sum | 0.19 ms | 0.36 ms | 0.53x |
| CumSum | 0.30 ms | 0.59 ms | 0.51x |
ZMLX wins big on fused operations that MLX doesn't provide as single ops (SwiGLU, fused-RNG dropout, fused gather-add). MLX's built-in operations (mx.fast.rms_norm, mx.softmax, reductions) are already highly optimized and should not be replaced.
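The Speedup column in the table above reads as the MLX time divided by the ZMLX time, so values below 1.0 mean ZMLX is slower. A few rows recomputed in plain Python, using timings copied from the table:

```python
def speedup(mlx_ms: float, zmlx_ms: float) -> float:
    """Speedup of ZMLX over MLX; > 1 means ZMLX is faster."""
    return mlx_ms / zmlx_ms


# (MLX ms, ZMLX ms) pairs taken from the op-level table
timings_ms = {
    "SwiGLU": (0.85, 0.40),
    "Dropout": (3.12, 0.38),
    "Sum": (0.19, 0.36),
}

for op, (mlx_ms, zmlx_ms) in timings_ms.items():
    print(f"{op}: {speedup(mlx_ms, zmlx_ms):.2f}x")
```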
Model-level inference
LLM inference is memory-bandwidth-bound: fused kernels shine on large models where each saved memory round-trip matters. The effect scales with model size — small models see no benefit, large models see significant speedups.
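A rough way to reason about the bandwidth-bound claim is a back-of-envelope ceiling: if every decoded token has to stream all weights from memory once, decode speed cannot exceed bandwidth divided by model size. The sketch below uses the approximate model sizes quoted in this section; the bandwidth figure is a placeholder assumption, not a measurement, and the estimate ignores KV-cache traffic, activations, and launch overhead.

```python
GB = 1e9


def decode_ceiling_tok_per_s(weight_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Upper bound on decode speed if each token streams all weights once."""
    return bandwidth_bytes_per_s / weight_bytes


# ASSUMPTION: placeholder memory bandwidth; substitute your hardware's figure.
ASSUMED_BANDWIDTH = 400 * GB

ceiling_32b = decode_ceiling_tok_per_s(18 * GB, ASSUMED_BANDWIDTH)   # ~18 GB model
ceiling_1b = decode_ceiling_tok_per_s(0.8 * GB, ASSUMED_BANDWIDTH)   # ~0.8 GB model

print(f"~18 GB model: {ceiling_32b:.1f} tok/s ceiling")
print(f"~0.8 GB model: {ceiling_1b:.0f} tok/s ceiling")
```

Under this simple model, any extra memory round-trip per layer is paid on every token, which is why removing one can matter more as weights and layer counts grow.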
Qwen3-32B-4bit (64 layers, ~18 GB) — M4 Max, 36 GB
| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|---|---|---|---|
| Baseline (MLX) | 107 | 13.5 | — |
| ZMLX fused activations | 108 | 14.2 | 1.01x / 1.05x |
| ZMLX all patches | 127 | 18.0 | 1.19x / 1.33x |
+33% decode throughput on a 32B model — 64 layers of fused residual+RMSNorm, each saving a full memory round-trip.
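To make the "saved round-trip" concrete, here is a plain-Python reference for what a fused residual+RMSNorm computes. In the unfused form, the residual sum is written out and then read again by the norm; the fused form accumulates the sum of squares while producing the residual sum. The epsilon value and the choice to return both the updated residual and the normalized output are assumptions for this sketch, not details taken from ZMLX.

```python
import math


def rmsnorm(x, weight, eps=1e-6):
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def unfused(x, residual, weight):
    h = [a + b for a, b in zip(x, residual)]  # op 1: materialize h
    return h, rmsnorm(h, weight)              # op 2: read h back to normalize


def fused(x, residual, weight, eps=1e-6):
    h, sum_sq = [], 0.0
    for a, b in zip(x, residual):
        v = a + b
        h.append(v)
        sum_sq += v * v                        # statistics gathered in the same pass
    inv_rms = 1.0 / math.sqrt(sum_sq / len(h) + eps)
    return h, [v * inv_rms * w for v, w in zip(h, weight)]


x = [0.5, -1.0, 2.0, 0.25]
residual = [1.0, 1.0, -0.5, 0.0]
weight = [1.0, 0.5, 2.0, 1.0]

h_ref, y_ref = unfused(x, residual, weight)
h_fused, y_fused = fused(x, residual, weight)
print(max(abs(a - b) for a, b in zip(y_ref, y_fused)))
```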
Qwen3-30B-A3B-4bit (MoE, 48 layers, 3B active/30B total) — python benchmarks/inference_benchmark.py --models qwen3-30b-a3b --selective
| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|---|---|---|---|
| Baseline (MLX) | 1,083 | 116 | — |
| ZMLX fused activations | 1,635 | 158 | 1.51x / 1.36x |
+36% decode throughput on MoE models: fused gating (top2_gating_softmax) and combine (moe_combine) kernels eliminate multiple memory round-trips in the expert routing path.
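For readers unfamiliar with the routing path, the sketch below shows one common form of top-2 gating followed by a weighted combine, in plain Python. It is a generic reference, not ZMLX's kernels: the function names are reused only for readability, and normalizing the softmax over just the two selected experts is an assumption (some models normalize over all experts before selecting).

```python
import math


def top2_gating_softmax(logits):
    """Pick the two largest router logits and softmax over just those two."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    i0, i1 = order[0], order[1]
    e0 = 1.0                                   # exp(logits[i0] - max)
    e1 = math.exp(logits[i1] - logits[i0])
    z = e0 + e1
    return (i0, i1), (e0 / z, e1 / z)


def moe_combine(expert_outputs, weights):
    """Weighted sum of the selected experts' output vectors."""
    dim = len(expert_outputs[0])
    return [
        sum(w * out[d] for w, out in zip(weights, expert_outputs))
        for d in range(dim)
    ]


router_logits = [0.1, 2.0, -1.0, 2.0 + math.log(3.0)]
idx, w = top2_gating_softmax(router_logits)

# Toy expert outputs for the two selected experts
combined = moe_combine([[1.0, 0.0], [0.0, 1.0]], w)
print(idx, w, combined)
```

Done as separate array ops, the selection, normalization, and weighted sum each read and write intermediate buffers; fusing them is what removes those round-trips.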
Qwen3-8B-4bit (32 layers, ~5 GB) — python benchmarks/inference_benchmark.py --models qwen3-8b --selective
| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|---|---|---|---|
| Baseline (MLX) | 676 | 75 | — |
| ZMLX fused activations | 675 | 76 | 1.00x / 1.01x |
Llama-3.2-1B-Instruct-4bit (16 layers, ~0.8 GB) — python benchmarks/llama_benchmark.py
| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline |
|---|---|---|---|
| Baseline (MLX) | 3,913 | 377 | — |
| ZMLX fused activations | 3,804 | 378 | 0.97x / 1.00x |
| ZMLX all patches | 3,705 | 366 | 0.95x / 0.97x |
When do patches help?
- Large Dense Models (8B+): Use all patches. Bandwidth-bound, so fused residual+norm saves real throughput.
- MoE Models: Use fused activations (--selective). The moe_mlp patch provides the +36% decode boost shown above.
- Small Models (< 3B): Neutral. Overhead often outweighs fusion gains. Use smart_patch to be sure.
from zmlx.patch import smart_patch
import mlx.core as mx
# Auto-benchmark each pattern, keep only what helps
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
Or use presets if you know your workload:
from zmlx.patch import patch, FUSED_ACTIVATIONS, TRAINING_RECOMMENDED
patch(model) # large models (8B+): all patches
patch(model, patterns=FUSED_ACTIVATIONS) # MoE/small: activations + MoE gating
patch(model, patterns=TRAINING_RECOMMENDED) # training: activations + norms
Smart patching
smart_patch applies each candidate pattern, benchmarks the model's forward pass, and automatically reverts patterns that make things slower. It supports custom forward functions for realistic benchmarks:
import mlx_lm
from zmlx.patch import smart_patch
# Basic: benchmark raw forward pass (sample is a token array, e.g. mx.array([tokenizer.encode("Hello")]))
model = smart_patch(model, sample)
# Advanced: benchmark with actual generation
def gen_fn(model, sample):
    return mlx_lm.generate(model, tokenizer, prompt="Hello", max_tokens=20)
model = smart_patch(model, sample, forward_fn=gen_fn, threshold=0.99)
# Result includes per-pattern speedups
result = model._zmlx_patch_result
print(result.benchmarks) # {'swiglu_mlp': 1.012, 'residual_norm': 0.971}
print(result.summary()) # what was kept and why
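The apply, benchmark, revert loop described above can be sketched without any model at all. The following is a simplified illustration, not smart_patch itself: `smart_apply`, the apply/revert callables, and the interpretation of threshold as a minimum speedup ratio are assumptions. The pattern names are borrowed from the example output above.

```python
def smart_apply(candidates, benchmark, threshold=1.0):
    """Try each pattern; keep it only if baseline_time / new_time >= threshold.

    candidates: dict of name -> (apply_fn, revert_fn)
    benchmark:  callable returning the current cost (e.g. seconds per forward)
    """
    kept, speedups = [], {}
    baseline = benchmark()
    for name, (apply_fn, revert_fn) in candidates.items():
        apply_fn()
        t = benchmark()
        speedups[name] = baseline / t
        if speedups[name] >= threshold:
            kept.append(name)
            baseline = t          # later patterns are judged against the new state
        else:
            revert_fn()           # undo patterns that made things slower
    return kept, speedups


# Toy "model" whose cost changes as patterns are applied
state = {"cost": 10.0}


def make_pattern(delta):
    def apply_fn():
        state["cost"] += delta

    def revert_fn():
        state["cost"] -= delta

    return apply_fn, revert_fn


candidates = {
    "swiglu_mlp": make_pattern(-2.0),     # helps
    "residual_norm": make_pattern(+1.0),  # hurts, should be reverted
}
kept, speedups = smart_apply(candidates, lambda: state["cost"], threshold=0.99)
print(kept, speedups)
```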
Autotuning
Replacement modules support threadgroup="auto" to search for the best threadgroup size on first invocation:
from zmlx.patch import patch
patch(model, threadgroup="auto") # autotunes each kernel on first call
The map_reduce() API also supports autotuning:
from zmlx.api import map_reduce
my_softmax = map_reduce(..., threadgroup="auto") # autotunes per-shape
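Conceptually, a threadgroup autotuner times a handful of candidate sizes once and remembers the winner per kernel and shape. The sketch below shows that idea with a fake timing function. `autotune`, the candidate list, and the cache key layout are hypothetical, and a persistent version would additionally write the cache to disk.

```python
def autotune(run_fn, candidates, cache, key):
    """Return the fastest candidate for key, measuring only on a cache miss."""
    if key not in cache:
        timings = {tg: run_fn(tg) for tg in candidates}
        cache[key] = min(timings, key=timings.get)
    return cache[key]


measured = []


def fake_run(threadgroup):
    # Stand-in for timing a real kernel launch; fastest at 256 by construction.
    measured.append(threadgroup)
    return abs(threadgroup - 256) * 1e-6 + 1e-4


cache = {}
key = ("my_softmax", (8, 1024))           # tuned per kernel name and shape
sizes = [32, 64, 128, 256, 512, 1024]

first = autotune(fake_run, sizes, cache, key)    # measures all candidates
second = autotune(fake_run, sizes, cache, key)   # served from the cache
print(first, second, len(measured))
```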
Where ZMLX genuinely helps
- Large model inference: 1.33x decode on 32B dense (fused residual+norm), 1.51x prompt / 1.36x decode on 30B MoE (fused gating+combine)
- Custom ops that MLX doesn't have: SwiGLU, GeGLU, fused dropout, fused MoE gating, bit packing
- Training: fused softmax_cross_entropy loss, correct weight gradients for rmsnorm_residual
- Authoring new kernels: the elementwise(), reduce(), and map_reduce() APIs let you go from math formula to compiled Metal kernel in one line
- Quantization: FP8 (E4M3/E5M2), NF4, int8, int4 dequantization with real bit-manipulation kernels
Precision
All ZMLX Metal kernels compute internally in float32 regardless of input dtype. The compute_dtype parameter accepted by many kernel functions is deprecated and will be removed in a future release. Passing a non-None value will emit a DeprecationWarning.
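One common motivation for computing in float32 even when inputs are half precision is accumulation error. The stand-alone sketch below emulates float16 and float32 rounding with the standard struct module and sums the same float16 inputs with each accumulator width. It is a general numerical illustration under those assumptions and does not exercise ZMLX.

```python
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(values, round_fn):
    """Sum values, rounding the accumulator to the given precision each step."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


values = [to_f16(0.1)] * 4096            # half-precision inputs
exact = sum(values)                       # float64 reference

half_sum = accumulate(values, to_f16)     # float16 accumulator
single_sum = accumulate(values, to_f32)   # float32 accumulator

print(f"reference={exact}, float16 acc={half_sum}, float32 acc={single_sum}")
```

With a half-precision accumulator, additions stop registering once the running sum is large relative to each term, while the float32 accumulator tracks the reference closely.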
Documentation
- docs/QUICKSTART.md: 5-minute tutorial
- docs/COOKBOOK.md: Recipes for common patterns
- docs/KERNELS.md: Complete kernel catalog reference
- docs/ARCHITECTURE.md: Design philosophy
Contributing
See CONTRIBUTING.md for setup, testing, and conventions.
License
MIT. See LICENSE.