ZMLX: Numba-for-MLX. Ergonomic helpers for authoring and autotuning MLX custom Metal kernels from Python.
ZMLX
Numba-style helpers for authoring, autotuning, and differentiating custom Metal kernels on Apple silicon.
pip install zmlx
from zmlx import autograd, msl
import mlx.core as mx
silu = autograd.unary_from_expr(
    name="my_silu",
    fwd_expr="x * kk_sigmoid(x)",
    vjp_expr="g * (s + x * s * ((T)1 - s))",
    compute_dtype=mx.float32,
    use_output=False,
    vjp_prelude="T s = kk_sigmoid(x);",
    header=msl.DEFAULT_HEADER,
)
y = silu(mx.random.normal((1024,))) # runs on GPU, supports mx.grad
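The vjp_expr above is the derivative of SiLU scaled by the upstream gradient g: with s = sigmoid(x), d/dx [x * s] = s + x * s * (1 - s). The following is a minimal pure-Python sketch (no MLX or Metal involved; the helper names are illustrative and not part of ZMLX) that checks the expression against a central finite difference:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    return x * sigmoid(x)


def silu_vjp(x: float, g: float) -> float:
    # Mirrors the vjp_expr string: g * (s + x * s * (1 - s)), with s = sigmoid(x).
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))


def central_diff(f, x: float, h: float = 1e-5) -> float:
    # Second-order accurate numerical derivative of a scalar function.
    return (f(x + h) - f(x - h)) / (2.0 * h)


# The analytic VJP (with g = 1) should match the numerical derivative.
for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(silu_vjp(x, g=1.0) - central_diff(silu, x)) < 1e-6
```

The same kind of check can be run against a real kernel by comparing mx.grad of the custom op with mx.grad of a plain MLX implementation.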
Why ZMLX?
- Define-once caching: kernels compile once and are reused across calls (keyed on source hash + config).
- One-line ops: create elementwise kernels from a C expression with elementwise.unary(name=..., expr=...).
- Differentiable kernels: attach custom VJP backward passes (themselves Metal kernels) via mx.custom_function.
- Autotuning: search threadgroup sizes automatically and cache the winners.
- 70+ ready-to-use catalog kernels: activations, softmax, norms, RoPE, transformer fused ops, reductions, quantization, and more.
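To make the "keyed on source hash + config" idea concrete, here is a small sketch of a define-once cache. It is not ZMLX's implementation: the KernelCache class, make_key, and fake_compile are hypothetical names, and the build function merely stands in for an expensive Metal compile.

```python
import hashlib
import json


class KernelCache:
    """Toy in-process cache: one build per unique (source, config) pair."""

    def __init__(self, build_fn):
        self._build_fn = build_fn  # stands in for an expensive compile step
        self._entries = {}
        self.hits = 0
        self.misses = 0

    @staticmethod
    def make_key(source: str, config: dict) -> str:
        # Hash the source together with a canonical (sorted-key) dump of the
        # config, so that dict ordering does not change the key.
        payload = source + "\0" + json.dumps(config, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def get(self, source: str, config: dict):
        key = self.make_key(source, config)
        if key in self._entries:
            self.hits += 1
        else:
            self.misses += 1
            self._entries[key] = self._build_fn(source, config)
        return self._entries[key]


builds = []  # records every call to the (fake) compiler


def fake_compile(source, config):
    builds.append((source, dict(config)))
    return ("compiled", source, tuple(sorted(config.items())))


cache = KernelCache(fake_compile)
src = "out[i] = metal::exp(inp[i]);"
k1 = cache.get(src, {"threadgroup": 256, "dtype": "float32"})
k2 = cache.get(src, {"dtype": "float32", "threadgroup": 256})  # same config, reordered
k3 = cache.get(src, {"threadgroup": 128, "dtype": "float32"})  # new config, new build
```

Here the second lookup reuses the first build, while changing the threadgroup size produces a separate entry.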
Install
Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0
pip install zmlx
From source (development):
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
Quick Examples
1. Elementwise kernel from a C expression
from zmlx import elementwise, msl
import mlx.core as mx
exp_fast = elementwise.unary(
    name="kk_exp",
    expr="metal::exp(x)",
    compute_dtype=mx.float32,
    header=msl.DEFAULT_HEADER,
)
x = mx.random.normal((1024,)).astype(mx.float16)
y = exp_fast(x)
2. Differentiable kernel with custom VJP
from zmlx import autograd, msl
import mlx.core as mx
exp_trainable = autograd.unary_from_expr(
    name="kk_exp_vjp",
    fwd_expr="metal::exp(x)",
    vjp_expr="g * y",
    compute_dtype=mx.float32,
    use_output=True,
    header=msl.DEFAULT_HEADER,
)
def loss(z):
    return exp_trainable(z).sum()
x = mx.random.normal((1024,))
gx = mx.grad(loss)(x)
mx.eval(gx)
3. Catalog kernel (ready-to-use)
from zmlx.kernels import softmax, norms, transformer
import mlx.core as mx
x = mx.random.normal((8, 1024)).astype(mx.float16)
w = mx.ones((1024,), dtype=mx.float16)
y = softmax.softmax_lastdim(x) # rowwise softmax
z = norms.rmsnorm(x, w) # differentiable RMSNorm
s = transformer.swiglu(mx.random.normal((8, 2048)).astype(mx.float16)) # fused SwiGLU
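When validating catalog kernels, it can help to have a slow reference to compare against. The sketch below gives textbook definitions of a last-dimension softmax and RMSNorm in pure Python on nested lists. The function names and the eps default are assumptions for illustration; the exact epsilon and numerics used by the ZMLX kernels are not stated here.

```python
import math


def softmax_lastdim_ref(rows):
    """Row-wise softmax over a list of rows (reference, not optimized)."""
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def rmsnorm_ref(rows, weight, eps=1e-6):
    """Row-wise RMSNorm: x / sqrt(mean(x^2) + eps) * weight (eps is assumed)."""
    out = []
    for row in rows:
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


x = [[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]]
probs = softmax_lastdim_ref(x)  # the second row stays finite thanks to the max shift
normed = rmsnorm_ref([[3.0, 4.0]], weight=[1.0, 1.0], eps=0.0)
```

A float16 kernel output would typically be compared against such a reference with a loose tolerance rather than exact equality.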
Kernel Catalog
17 modules, 70+ kernels (including gradient helpers) organized by domain. Full reference: docs/KERNELS.md.
| Module | Count | Highlights |
|---|---|---|
| activations | 19 | exp, sigmoid, relu, silu, gelu_tanh, softplus + grad variants |
| transformer | 10 | swiglu, geglu, rmsnorm_residual, layernorm_residual, dropout |
| softmax | 3 | softmax_lastdim, log_softmax_lastdim, softmax_grad |
| norms | 6 | rmsnorm, layernorm, rmsnorm_grad, layer_norm_dropout |
| attention | 4 | masked_softmax, scale_mask_softmax, logsumexp_lastdim |
| rope | 3 | apply_rope, apply_rope_interleaved, apply_gqa_rope |
| reductions | 7 | sum, mean, max, var, std, argmax, topk (all lastdim) |
| fused | 6 | add, mul, bias_gelu_tanh, bias_silu, silu_mul_grad, add_bias |
| linear | 4 | fused_linear_bias_silu, fused_linear_bias_gelu, fused_linear_rmsnorm |
| loss | 1 | softmax_cross_entropy |
| quant | 3 | dequantize_int8, dequantize_silu_int8, dequantize_int4 |
| bits | 2 | pack_bits, unpack_bits |
| moe | 1 | top2_gating_softmax |
| image | 2 | resize_bilinear, depthwise_conv_3x3 |
| indexing | 2 | fused_gather_add, fused_scatter_add |
| scan | 2 | cumsum_lastdim, cumsum_grad |
Architecture
Three-layer design. Full details: docs/ARCHITECTURE.md.
1. Metal kernel infrastructure: MetalKernel wrapper, in-process cache, stats tracking
2. Code generation & helpers: MSL templates, elementwise/autograd/rowwise APIs, autotuning
3. Kernel catalog: 17 domain modules built on layers 1 and 2
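The autotuning piece of the helper layer can be pictured as a small search loop: try each candidate threadgroup size, keep the fastest, and remember the winner per kernel/shape key. The sketch below illustrates that idea only. The autotune function, its parameters, and the injectable measure hook are hypothetical and do not reflect ZMLX's actual API.

```python
import time
from typing import Callable, Dict, Hashable, Iterable, Optional


def wall_clock(fn: Callable[[], None], repeats: int = 5) -> float:
    # Best-of-N wall time. A real GPU harness would also force evaluation
    # and synchronize the device before stopping the clock.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


def autotune(
    run: Callable[[int], None],
    candidates: Iterable[int],
    key: Hashable,
    winners: Dict[Hashable, int],
    measure: Optional[Callable[[Callable[[], None]], float]] = None,
) -> int:
    """Return the fastest candidate for `key`, measuring only on a cache miss."""
    if key in winners:
        return winners[key]
    measure = measure or wall_clock
    best_size, best_time = None, float("inf")
    for size in candidates:
        elapsed = measure(lambda: run(size))
        if elapsed < best_time:
            best_size, best_time = size, elapsed
    if best_size is None:
        raise ValueError("no candidates given")
    winners[key] = best_size
    return best_size


# Deterministic stand-in for a timer so the example is reproducible.
fake_cost = {32: 4.0, 64: 2.5, 128: 1.0, 256: 1.8}
launched = []


def run(size: int) -> None:
    launched.append(size)  # a real version would launch the kernel here


def fake_measure(fn: Callable[[], None]) -> float:
    fn()
    return fake_cost[launched[-1]]


winners: Dict[Hashable, int] = {}
key = ("my_kernel", (1024,), "float16")
best = autotune(run, [32, 64, 128, 256], key=key, winners=winners, measure=fake_measure)
```

Keying the winners on kernel name, shape, and dtype is one plausible choice, since the best threadgroup size can differ across all three.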
Benchmarks
Run python benchmarks/microbench.py to reproduce. Headline numbers on M4 (vs MLX reference ops):
- Dropout: ~7x faster (fused Metal-side LCG RNG)
- SwiGLU: ~2-3x faster (fused silu + gate multiply)
Results vary by shape, dtype, and chip. See benchmarks/ for the full harness.
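The dropout speedup is attributed to generating random numbers inside the kernel with a linear congruential generator (LCG), so no separate mask array has to be produced first. A rough CPU-side sketch of that idea follows. The LCG constants (the commonly cited Numerical Recipes values), the per-index seeding scheme, and the function names are assumptions for illustration and are not taken from ZMLX's Metal source.

```python
LCG_A = 1664525        # multiplier (assumed; Numerical Recipes constants)
LCG_C = 1013904223     # increment (assumed)
MASK32 = 0xFFFFFFFF    # emulate 32-bit unsigned wraparound


def lcg_uniform(seed: int, index: int, rounds: int = 2) -> float:
    """Map (seed, element index) to a pseudo-random float in [0, 1)."""
    # Mix the index into the seed so each element gets its own stream.
    state = (seed ^ (index * 2654435761)) & MASK32
    for _ in range(rounds):
        state = (LCG_A * state + LCG_C) & MASK32
    return state / 2**32


def dropout_ref(values, p: float, seed: int):
    """Inverted dropout: zero with probability ~p, scale survivors by 1/(1-p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    scale = 1.0 / (1.0 - p)
    # The mask decision is computed inline per element, mirroring what a
    # fused kernel would do in a single pass over the data.
    return [v * scale if lcg_uniform(seed, i) >= p else 0.0
            for i, v in enumerate(values)]


out = dropout_ref([1.0] * 8, p=0.5, seed=1234)
```

An LCG is cheap but statistically weak, which is usually acceptable for dropout masks but not for applications that need high-quality randomness.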
Roadmap
- Flash Attention tiles (shared memory, 16x16 / 32x32)
- Expanded quantization (int4 matmul, mixed-precision patterns)
- Zig frontend via MLX-C (multi-language kernel generation)
- JVP support for all catalog kernels
- Community-contributed kernels
Contributing
See CONTRIBUTING.md for setup, testing, and conventions.
License
MIT. See LICENSE.