
ZMLX: Triton for Apple Silicon. Write custom Metal GPU kernels for MLX with one-line ergonomics, automatic gradients, and built-in autotuning.


ZMLX — Triton for Apple Silicon

Python 3.10+ | License: MIT | Platform: macOS (Apple Silicon)

The Triton-like toolkit for MLX — write custom Metal GPU kernels from Python with one-line ergonomics, automatic gradients, and built-in autotuning. No raw Metal, no manual threadgroups, no boilerplate.

+33% decode on Qwen3-32B (dense) and +51% prompt / +36% decode on Qwen3-30B-A3B (MoE) with fused kernel patches. See Benchmarks below.

pip install zmlx

Speed up your model in 3 lines:

import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/Qwen3-30B-A3B-4bit")
patch(model)  # +51% prompt, +36% decode — done

Or write custom GPU kernels in one line:

from zmlx.api import elementwise
import mlx.core as mx

# Math formula → compiled Metal kernel → runs on GPU
mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))

What's New in v0.4.0

  • 1.33x decode on 32B dense, 1.51x prompt / 1.36x decode on 30B MoE, from fused residual+RMSNorm and fused MoE gating patches (see Benchmarks)
  • MoE patch: fused top2_gating_softmax + moe_combine eliminates multiple memory round-trips in expert routing
  • High-level API: elementwise(), reduce(), map_reduce() for kernel authoring in one line
  • JIT compiler: @jit decorator compiles Python scalar expressions to Metal
  • Testing & benchmarking: zmlx.testing.assert_matches() and zmlx.bench.compare() for correctness verification and side-by-side timing
  • Profiling: zmlx.profile.time_kernel(), dump_msl(), kernel_stats() for introspection
  • Training pipeline: zmlx train CLI for LoRA fine-tuning with ZMLX patches
  • Smart patching: smart_patch() auto-benchmarks each pattern and keeps only what helps
  • Fused AdamW: single-kernel optimizer step that reduces memory traffic
  • Paged attention: zmlx.nn.PagedAttention for high-throughput serving
  • 70+ kernel catalog: activations, norms, RoPE, attention, MoE, quantization, loss, bit ops

Why ZMLX?

When you need a custom GPU op on Apple Silicon, your options today are:

  1. Write raw Metal source strings, manage caching, figure out threadgroups, wire up autodiff manually
  2. Use ZMLX

ZMLX wraps mx.fast.metal_kernel and mx.custom_function to provide Triton-like ergonomics:

  • One-line kernel authoring — define elementwise, reduction, and map-reduce ops from C-style Metal Shading Language (MSL) expressions
  • Automatic gradients — custom VJP backward passes (themselves Metal kernels) via mx.custom_function
  • Define-once caching — kernels compile once, reused by source hash + config
  • Autotuning — threadgroup size search with persistent caching
  • Testing & benchmarking — verify against reference ops, compare timings side-by-side
  • Model patching — swap MLX layers for fused ZMLX kernels with patch(model)
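The define-once caching idea from the list above can be modeled in a few lines of plain Python. This is a minimal sketch, assuming a hypothetical compile_fn that stands in for the expensive Metal compilation step; ZMLX's real cache and key format may differ.

```python
import hashlib
from typing import Any, Callable, Dict, Tuple

Key = Tuple[str, Tuple[Tuple[str, Any], ...]]

# Hypothetical in-process cache: (source hash, config) -> compiled kernel.
_CACHE: Dict[Key, Any] = {}

def cache_key(source: str, config: Dict[str, Any]) -> Key:
    """SHA-256 of the kernel source plus a canonical (sorted) view of its config."""
    digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
    return digest, tuple(sorted(config.items()))

def get_kernel(source: str, config: Dict[str, Any],
               compile_fn: Callable[[str, Dict[str, Any]], Any]) -> Any:
    """Compile only on the first request for a given (source, config) pair."""
    key = cache_key(source, config)
    if key not in _CACHE:
        _CACHE[key] = compile_fn(source, config)
    return _CACHE[key]

# Usage with a fake compiler that records how often it runs.
compile_calls = []

def fake_compile(src: str, cfg: Dict[str, Any]) -> str:
    compile_calls.append((src, cfg))
    return f"kernel#{len(compile_calls)}"

src = "out[i] = metal::exp(x[i]);"
k1 = get_kernel(src, {"threadgroup": 256}, fake_compile)
k2 = get_kernel(src, {"threadgroup": 256}, fake_compile)  # cache hit
k3 = get_kernel(src, {"threadgroup": 128}, fake_compile)  # new config -> recompile
print(k1, k2, k3, len(compile_calls))  # kernel#1 kernel#1 kernel#2 2
```

Including the config in the key matters because the same source compiled with a different threadgroup size is, for caching purposes, a different kernel.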

Install

Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

pip install zmlx

From source (development):

git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"

Quick Start

1. Custom elementwise kernel

from zmlx.api import elementwise
import mlx.core as mx

# Non-differentiable — just forward pass
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable — with custom VJP
from zmlx import msl

silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
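The grad_expr above encodes the derivative d/dx [x * sigmoid(x)] = s + x * s * (1 - s), with s = sigmoid(x) computed in grad_prelude. Before you wire an expression like this into a kernel, you can sanity-check it against a central finite difference in plain Python. The sketch below re-implements the expression by hand and does not call ZMLX.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def silu(x: float) -> float:
    return x * sigmoid(x)

def silu_vjp(x: float, g: float) -> float:
    # Mirrors grad_expr="g * (s + x * s * ((T)1 - s))" with s = sigmoid(x)
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))

def central_diff(f, x: float, h: float = 1e-5) -> float:
    """Numerical derivative (f(x+h) - f(x-h)) / 2h."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

xs = [-4.0, -1.5, -0.1, 0.0, 0.7, 3.0]
max_err = max(abs(silu_vjp(x, 1.0) - central_diff(silu, x)) for x in xs)
print(f"max |analytic - numeric| = {max_err:.2e}")
```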

2. Custom reduction

from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)
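Conceptually, init seeds an accumulator and update folds each element v of a row into acc, giving one value per row. The plain-Python model below assumes last-axis, row-wise semantics (suggested by the shape (8,) comment above) and ignores how the real kernel splits work across threads. A parallel implementation generally folds chunks of a row independently and then merges the partial results, so the update should be associative.

```python
from typing import Callable, List

def reduce_rows(rows: List[List[float]], init: float,
                update: Callable[[float, float], float]) -> List[float]:
    """Sequential model of a last-axis reduction: one accumulator per row."""
    out = []
    for row in rows:
        acc = init                 # plays the role of init="0.0f"
        for v in row:
            acc = update(acc, v)   # plays the role of update="acc + v"
        out.append(acc)
    return out

rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(reduce_rows(rows, 0.0, lambda acc, v: acc + v))  # [6.0, 15.0]
print(reduce_rows(rows, float("-inf"), max))            # [3.0, 6.0]
```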

3. Two-pass map-reduce (softmax pattern)

from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
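The three expressions map onto the standard numerically stable softmax. Pass 1 finds the row max s1, pass 2 sums exp(x - s1) into s2, and write normalizes. A sequential plain-Python model of that data flow (one row, no threading) shows why subtracting s1 matters: the inputs below would overflow a naive exp(x). The reduce entries, which merge per-thread partial results, have no counterpart in this single-threaded sketch.

```python
import math
from typing import List

def two_pass_softmax(row: List[float]) -> List[float]:
    # Pass 1: s1 = max over the row (init=-INFINITY, update=max(acc1, x))
    s1 = float("-inf")
    for x in row:
        s1 = max(s1, x)
    # Pass 2: s2 = sum of exp(x - s1) (init=0, update=acc2 + exp(x - s1))
    s2 = 0.0
    for x in row:
        s2 += math.exp(x - s1)
    # Write: exp(x - s1) / s2
    return [math.exp(x - s1) / s2 for x in row]

probs = two_pass_softmax([1000.0, 1001.0, 1002.0])  # math.exp(1000.0) alone would overflow
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
```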

4. Test and benchmark your kernel

import zmlx
import mlx.core as mx

# Verify correctness
zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

# Benchmark
zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)

5. Lower-level building blocks

from zmlx import autograd, elementwise, msl
import mlx.core as mx

# Unary kernel (no gradient). compute_dtype is omitted: it is deprecated,
# and kernels always compute in float32 (see Precision).
exp_kern = elementwise.unary(
    name="kk_exp", expr="metal::exp(x)",
    header=msl.DEFAULT_HEADER,
)

# Binary kernel with custom VJP: d(a*b)/da = b, d(a*b)/db = a
mul_op = autograd.binary_from_expr(
    name="safe_mul", fwd_expr="a * b",
    vjp_lhs_expr="g * b", vjp_rhs_expr="g * a",
)

Kernel Catalog

ZMLX includes 70+ kernels organized by domain. Some are genuinely useful for custom workloads (loss, GLU fusions, bit ops, MoE gating). Others are reference implementations showing codegen patterns — correct but not faster than MLX built-ins for standard transformer shapes.

Full reference: docs/KERNELS.md.

| Module | Highlights |
|---|---|
| loss | softmax_cross_entropy: memory-efficient fused loss |
| transformer | swiglu, geglu, rmsnorm_residual (with full weight gradients), dropout: genuine fusions |
| bits | pack_bits, unpack_bits: no MLX equivalent |
| moe | top2_gating_softmax, moe_dispatch, moe_combine: fused expert routing (+36% decode on 30B MoE) |
| quant | FP8 (E4M3/E5M2), NF4, int8, int4 dequantization: real bit-manipulation kernels |
| optimizers | adamw_step: fused AdamW parameter update in a single kernel |
| scan | cumsum_lastdim: differentiable prefix sum |
| norms | rmsnorm, layernorm: parallel reduction; all norms compute in float32 internally |
| softmax | softmax_lastdim: map-reduce codegen showcase |
| rope | apply_rope, apply_rope_interleaved, apply_gqa_rope |
| linear | Reference fused-linear patterns (naive matmul, not for production) |
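To illustrate the bit manipulation behind FP8 dequantization, here is a plain-Python decoder for the E4M3 format in its common "FN" variant: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and NaN when all exponent and mantissa bits are set. It is a reference for the format itself, assuming that variant; it is not ZMLX's kernel code.

```python
import math

def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 (FN variant) byte to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF   # 4 exponent bits
    man = byte & 0x7          # 3 mantissa bits
    if exp == 0xF and man == 0x7:
        return math.nan       # E4M3FN encodes NaN here and has no infinities
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias = -6
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

print(decode_e4m3(0x38))  # 1.0
print(decode_e4m3(0x7E))  # 448.0 (largest finite E4M3FN value)
print(decode_e4m3(0xC0))  # -2.0
print(decode_e4m3(0x01))  # 0.001953125 (smallest subnormal, 2**-9)
```

A GPU dequantization kernel applies the same shifts and masks per element, usually followed by multiplication with a per-tensor or per-block scale.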

Architecture

Three-layer design. Full details: docs/ARCHITECTURE.md.

  1. Metal kernel infrastructure: MetalKernel wrapper, in-process cache, stats tracking
  2. Code generation & helpers: MSL templates, elementwise/autograd/rowwise APIs, autotuning
  3. Kernel catalog: domain modules built on layers 1 and 2

Benchmarks

Op-level (B=16, S=1024, D=1024, float32, M4 Max)

Run python benchmarks/microbench.py to reproduce on your hardware.

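| Operation | MLX | ZMLX | Speedup (MLX time / ZMLX time) |
|---|---|---|---|
| SwiGLU | 0.85 ms | 0.40 ms | 2.1x |
| Dropout | 3.12 ms | 0.38 ms | 8.2x |
| Top-K | 1.82 ms | 0.49 ms | 3.7x |
| Gather-Add | 0.54 ms | 0.41 ms | 1.3x |
| Softmax | 0.36 ms | 0.41 ms | 0.90x |
| RMSNorm | 0.37 ms | 0.41 ms | 0.90x |
| Sum | 0.19 ms | 0.36 ms | 0.53x |
| CumSum | 0.30 ms | 0.59 ms | 0.51x |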

ZMLX wins big on fused operations that MLX doesn't provide as single ops (SwiGLU, fused-RNG dropout, fused gather-add). MLX's built-in operations (mx.fast.rms_norm, mx.softmax, reductions) are already highly optimized and should not be replaced.
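A rough way to see why fusion pays off for SwiGLU, silu(gate) * up, is to count memory traffic. The sketch below uses an idealized model: it assumes every intermediate is written to and re-read from device memory and ignores caches. That is an assumption, not a measurement of either library. Under that model the unfused version moves five tensors' worth of data and the fused kernel moves three.

```python
import math

def swiglu_reference(gate, up):
    """Elementwise SwiGLU: silu(gate) * up, with silu(x) = x * sigmoid(x)."""
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]

def swiglu_traffic_bytes(n_elems: int, bytes_per_elem: int = 4):
    """Idealized device-memory traffic, ignoring caches.

    Unfused: read gate, write silu(gate), read silu(gate) and up, write out -> 5 tensors.
    Fused:   read gate and up, write out                                    -> 3 tensors.
    """
    t = n_elems * bytes_per_elem
    return 5 * t, 3 * t

# Element count from the op-level benchmark shape: B=16, S=1024, D=1024, float32
n = 16 * 1024 * 1024
unfused, fused = swiglu_traffic_bytes(n)
print(unfused // 2**20, "MiB vs", fused // 2**20, "MiB")  # 320 MiB vs 192 MiB
```

The measured 2.1x need not match this 5:3 ratio, since the model ignores launch overhead, caching, and compute.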

Model-level inference

LLM inference is memory-bandwidth-bound, so fused kernels help most on large models, where each saved memory round-trip matters. In the runs below, the effect grows with model size. The 1B model sees no benefit (or a small slowdown), while the 32B dense and 30B MoE models see significant speedups.
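A back-of-the-envelope model makes the bandwidth argument concrete. Suppose decoding one token has to stream roughly all the weights from memory once. Then decode throughput is bounded by memory bandwidth divided by model size. The bandwidth figure below is a placeholder assumption, not a measured M4 Max number, so substitute your own machine's spec.

```python
def decode_upper_bound(bandwidth_gb_s: float, weights_gb: float) -> float:
    """Rough ceiling on decode tok/s when each token reads all weights once."""
    return bandwidth_gb_s / weights_gb

ASSUMED_BANDWIDTH_GB_S = 400.0  # placeholder assumption; check your hardware spec

# Model sizes taken from the benchmark headings below
for name, size_gb in [("Qwen3-32B-4bit", 18.0), ("Llama-3.2-1B-Instruct-4bit", 0.8)]:
    ceiling = decode_upper_bound(ASSUMED_BANDWIDTH_GB_S, size_gb)
    print(f"{name}: <= {ceiling:.0f} tok/s")
```

Comparing such a ceiling with measured tok/s gives a sense of how close a run is to the memory-bandwidth limit.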

Qwen3-32B-4bit (64 layers, ~18 GB) — M4 Max, 36 GB

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (MLX) | 107 | 13.5 | |
| ZMLX fused activations | 108 | 14.2 | 1.01x / 1.05x |
| ZMLX all patches | 127 | 18.0 | 1.19x / 1.33x |

+33% decode throughput on a 32B model — 64 layers of fused residual+RMSNorm, each saving a full memory round-trip.

Qwen3-30B-A3B-4bit (MoE, 48 layers, 3B active/30B total)

Reproduce with: python benchmarks/inference_benchmark.py --models qwen3-30b-a3b --selective

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (MLX) | 1,083 | 116 | |
| ZMLX fused activations | 1,635 | 158 | 1.51x / 1.36x |

+36% decode throughput on MoE models — fused gating (top2_gating_softmax) and combine (moe_combine) kernels eliminate multiple memory round-trips in the expert routing path.
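For readers unfamiliar with the routing path, the sketch below shows an unfused reference computation for a single token. It runs a softmax over router logits, selects the top-k experts, renormalizes the selected weights, and computes a weighted combine of expert outputs. It is a generic plain-Python sketch: the function names are hypothetical, and whether ZMLX's top2_gating_softmax normalizes in exactly this way is an assumption. In an unfused pipeline, each intermediate here (probabilities, selected indices, renormalized weights) would typically be materialized as a separate array, which is the round-trip cost the paragraph above refers to.

```python
import math
from typing import List, Tuple

def topk_gating_softmax(logits: List[float], k: int = 2) -> Tuple[List[int], List[float]]:
    """Softmax over router logits, keep the k largest, renormalize their weights."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    kept = sum(probs[i] for i in top)
    return top, [probs[i] / kept for i in top]

def moe_combine_one(expert_outputs: List[List[float]], idx: List[int],
                    weights: List[float]) -> List[float]:
    """Weighted sum of the selected experts' output vectors for one token."""
    dim = len(expert_outputs[0])
    out = [0.0] * dim
    for i, w in zip(idx, weights):
        for d in range(dim):
            out[d] += w * expert_outputs[i][d]
    return out

idx, w = topk_gating_softmax([0.1, 2.0, -1.0, 2.0], k=2)
experts = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0], [1.0, 1.0]]
print(idx, w)                            # [1, 3] [0.5, 0.5]
print(moe_combine_one(experts, idx, w))  # [0.5, 1.0]
```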

Qwen3-8B-4bit (32 layers, ~5 GB)

Reproduce with: python benchmarks/inference_benchmark.py --models qwen3-8b --selective

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (MLX) | 676 | 75 | |
| ZMLX fused activations | 675 | 76 | 1.00x / 1.01x |

Llama-3.2-1B-Instruct-4bit (16 layers, ~0.8 GB)

Reproduce with: python benchmarks/llama_benchmark.py

| Config | Prompt (tok/s) | Decode (tok/s) | vs Baseline (prompt / decode) |
|---|---|---|---|
| Baseline (MLX) | 3,913 | 377 | |
| ZMLX fused activations | 3,804 | 378 | 0.97x / 1.00x |
| ZMLX all patches | 3,705 | 366 | 0.95x / 0.97x |

When do patches help?

  • Large dense models (8B+): Use all patches. These models are bandwidth-bound, so fused residual+norm saves real throughput.
  • MoE models: Use fused activations (--selective). The moe_mlp patch provides the +36% decode gain measured on Qwen3-30B-A3B.
  • Small models (< 3B): Neutral to slightly negative (0.95x to 1.00x on Llama-3.2-1B). Overhead often outweighs fusion gains. Use smart_patch to be sure:

from zmlx.patch import smart_patch
import mlx.core as mx

# Auto-benchmark each pattern, keep only what helps
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)

Or use presets if you know your workload:

from zmlx.patch import patch, FUSED_ACTIVATIONS, TRAINING_RECOMMENDED
patch(model)                                   # large models (8B+): all patches
patch(model, patterns=FUSED_ACTIVATIONS)       # MoE/small: activations + MoE gating
patch(model, patterns=TRAINING_RECOMMENDED)    # training: activations + norms

Smart patching

smart_patch applies each candidate pattern, benchmarks the model's forward pass, and automatically reverts patterns that make things slower. It supports custom forward functions for realistic benchmarks:

import mlx_lm
from zmlx.patch import smart_patch

# Option 1 (basic): benchmark the raw forward pass
model = smart_patch(model, sample)

# Option 2 (advanced): benchmark with actual generation
# (gen_fn ignores `sample` and generates from its own prompt)
def gen_fn(model, sample):
    return mlx_lm.generate(model, tokenizer, prompt="Hello", max_tokens=20)

model = smart_patch(model, sample, forward_fn=gen_fn, threshold=0.99)

# Result includes per-pattern speedups
result = model._zmlx_patch_result
print(result.benchmarks)    # {'swiglu_mlp': 1.012, 'residual_norm': 0.971}
print(result.summary())     # what was kept and why
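The selection logic described above can be modeled as a greedy keep-or-revert loop. This is a hypothetical sketch, not ZMLX's implementation. Here, apply, revert, and bench are placeholder callables, and treating threshold as the minimum speedup a pattern needs in order to be kept is an assumption about its semantics. The toy cost numbers are invented for illustration, not measurements.

```python
from typing import Callable, Dict, List

def greedy_select(
    patterns: List[str],
    apply: Callable[[str], None],
    revert: Callable[[str], None],
    bench: Callable[[], float],   # seconds per forward pass (lower is better)
    threshold: float = 1.0,
) -> Dict[str, float]:
    """Apply each pattern in turn; keep it only if speedup >= threshold."""
    speedups: Dict[str, float] = {}
    baseline = bench()
    for name in patterns:
        apply(name)
        t = bench()
        speedups[name] = baseline / t
        if speedups[name] >= threshold:
            baseline = t      # keep: later patterns are measured on top of this one
        else:
            revert(name)      # too slow: undo it
    return speedups

# Toy usage with a fake cost model instead of a real model.
active = set()
cost = {"swiglu_mlp": -0.012, "residual_norm": +0.03}  # invented seconds saved (-) / added (+)

def fake_bench() -> float:
    return 1.0 + sum(cost[p] for p in active)

result = greedy_select(list(cost), active.add, active.discard, fake_bench, threshold=0.99)
print(result, sorted(active))
```

Measuring each pattern against the previous kept state, rather than the original baseline, accounts for interactions between patterns at the cost of making the outcome depend on their order.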

Autotuning

Replacement modules support threadgroup="auto" to search for the best threadgroup size on first invocation:

from zmlx.patch import patch
patch(model, threadgroup="auto")  # autotunes each kernel on first call

The map_reduce() API also supports autotuning:

from zmlx.api import map_reduce
my_softmax = map_reduce(..., threadgroup="auto")  # autotunes per-shape
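A threadgroup search with persistent caching can be sketched like this: time each candidate size, keep the fastest, and record the winner under a key built from kernel name and input shape so later runs skip the search. The candidate sizes, the JSON cache file, and the time_fn callable are illustrative assumptions, not ZMLX's actual autotuner or cache format.

```python
import json
import os
import tempfile
from typing import Callable, Sequence, Tuple

CANDIDATES = (32, 64, 128, 256, 512, 1024)  # assumed candidate threadgroup sizes

def autotune(name: str, shape: Tuple[int, ...], time_fn: Callable[[int], float],
             cache_path: str, candidates: Sequence[int] = CANDIDATES) -> int:
    """Return the best threadgroup size for (name, shape); search only on a cache miss."""
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    key = f"{name}:{'x'.join(map(str, shape))}"   # per-kernel, per-shape key
    if key in cache:
        return cache[key]
    best = min(candidates, key=time_fn)           # time_fn(tg) -> measured seconds
    cache[key] = best
    with open(cache_path, "w") as f:
        json.dump(cache, f)
    return best

# Toy usage: a fake timing curve that is fastest at 256.
calls = []
def fake_time(tg: int) -> float:
    calls.append(tg)
    return abs(tg - 256) + 1.0

path = os.path.join(tempfile.mkdtemp(), "tune_cache.json")
first = autotune("my_softmax", (8, 1024), fake_time, path)
second = autotune("my_softmax", (8, 1024), fake_time, path)  # served from the cache file
print(first, second, len(calls))  # 256 256 6
```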

Where ZMLX genuinely helps

  • Large model inference — 1.33x decode on 32B dense (fused residual+norm), 1.51x prompt / 1.36x decode on 30B MoE (fused gating+combine)
  • Custom ops that MLX doesn't have — SwiGLU, GeGLU, fused dropout, fused MoE gating, bit packing
  • Training — fused softmax_cross_entropy loss, correct weight gradients for rmsnorm_residual
  • Authoring new kernels — the elementwise(), reduce(), and map_reduce() APIs let you go from math formula to compiled Metal kernel in one line
  • Quantization — FP8 (E4M3/E5M2), NF4, int8, int4 dequantization with real bit-manipulation kernels

Precision

All ZMLX Metal kernels compute internally in float32 regardless of input dtype. The compute_dtype parameter accepted by many kernel functions is deprecated and will be removed in a future release. Passing a non-None value will emit a DeprecationWarning.
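To see why computing in float32 matters, consider accumulating many small values. The plain-Python sketch below simulates a float16 accumulator by rounding after every addition with the struct module's half-precision format, and compares it with float32 rounding. It illustrates the numerical motivation only and is not ZMLX code.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round a Python float to float16 ('e') or float32 ('f') and back."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

def accumulate(values, fmt: str) -> float:
    acc = 0.0
    for v in values:
        acc = round_to(fmt, acc + round_to(fmt, v))  # round after every add
    return acc

values = [0.1] * 4096                 # exact sum: 409.6
half_sum = accumulate(values, "e")
single_sum = accumulate(values, "f")
print(half_sum)    # 256.0: above 256, float16 spacing is 0.25, so adding ~0.1 rounds away
print(single_sum)  # close to 409.6
```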




Contributing

See CONTRIBUTING.md for setup, testing, and conventions.


License

MIT. See LICENSE.
