
ZMLX: Numba-for-MLX. Ergonomic helpers for authoring and autotuning MLX custom Metal kernels from Python.

ZMLX

Python 3.10+ | License: MIT | Platform: macOS (Apple Silicon)

Numba-style helpers for authoring, autotuning, and differentiating custom Metal kernels on Apple silicon.

pip install zmlx

from zmlx import autograd, msl
import mlx.core as mx

silu = autograd.unary_from_expr(
    name="my_silu", fwd_expr="x * kk_sigmoid(x)",
    vjp_expr="g * (s + x * s * ((T)1 - s))",
    compute_dtype=mx.float32, use_output=False,
    vjp_prelude="T s = kk_sigmoid(x);",
    header=msl.DEFAULT_HEADER,
)
y = silu(mx.random.normal((1024,)))  # runs on GPU, supports mx.grad

Why ZMLX?

  • Define-once caching — kernels compile once and are reused across calls (keyed on source hash + config).
  • One-line ops — create elementwise kernels from a C expression: elementwise.unary(name=..., expr=...).
  • Differentiable kernels — attach custom VJP backward passes (themselves Metal kernels) via mx.custom_function.
  • Autotuning — search threadgroup sizes automatically and cache the winners.
  • 70+ ready-to-use catalog kernels — activations, softmax, norms, RoPE, transformer fused ops, reductions, quantization, and more.
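
The "compile once, keyed on source hash + config" idea from the list above can be illustrated with a small pure-Python sketch. This is not ZMLX's implementation: `KernelCache`, `fake_compile`, and the `threadgroup` config key are hypothetical names chosen for illustration, and the "compile" step is a stand-in that only records that it was called.

```python
import hashlib
import json


class KernelCache:
    """Toy in-process cache keyed on (source hash, config). Hypothetical, not ZMLX's class."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn  # stand-in for the expensive compile step
        self._entries = {}
        self.hits = 0
        self.misses = 0

    @staticmethod
    def make_key(source, config):
        # Hash the kernel source and serialize the config deterministically.
        src_hash = hashlib.sha256(source.encode("utf-8")).hexdigest()
        return (src_hash, json.dumps(config, sort_keys=True))

    def get(self, source, config):
        key = self.make_key(source, config)
        if key in self._entries:
            self.hits += 1
        else:
            self.misses += 1
            self._entries[key] = self._compile_fn(source, config)
        return self._entries[key]


compile_calls = []


def fake_compile(source, config):
    # Assumption: a real compile step would build a Metal kernel here.
    compile_calls.append(source)
    return lambda x: x


cache = KernelCache(fake_compile)
k1 = cache.get("out = metal::exp(x);", {"threadgroup": 256})
k2 = cache.get("out = metal::exp(x);", {"threadgroup": 256})  # same key: reused
k3 = cache.get("out = metal::exp(x);", {"threadgroup": 128})  # new config: compiled again
```

Keying on the config as well as the source means that changing a launch parameter produces a separate cache entry instead of silently reusing a kernel built for different settings.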

Install

Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

pip install zmlx

From source (development):

git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"

Quick Examples

1. Elementwise kernel from a C expression

from zmlx import elementwise, msl
import mlx.core as mx

exp_fast = elementwise.unary(
    name="kk_exp",
    expr="metal::exp(x)",
    compute_dtype=mx.float32,
    header=msl.DEFAULT_HEADER,
)

x = mx.random.normal((1024,)).astype(mx.float16)
y = exp_fast(x)
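
To make the "kernel from a C expression" idea concrete, here is a rough sketch of how an expression over `x` could be wrapped into a per-element Metal kernel body. It only builds a string; `render_unary_body` and the buffer names `inp` and `out` are hypothetical, and ZMLX's actual code generation may look quite different.

```python
def render_unary_body(expr: str, compute_type: str = "float") -> str:
    """Wrap a C expression over `x` into a per-element kernel body (illustrative only)."""
    return "\n".join(
        [
            # Each GPU thread handles one element, indexed by its grid position.
            "uint i = thread_position_in_grid.x;",
            # Load the input and promote it to the compute type.
            f"{compute_type} x = ({compute_type})inp[i];",
            # Evaluate the user expression and cast back to the output type T.
            f"out[i] = (T)({expr});",
        ]
    )


body = render_unary_body("metal::exp(x)")
print(body)
```

A body of this shape is the kind of source string that MLX's `mx.fast.metal_kernel` accepts, with `T` supplied as a template parameter at call time. Promoting to `float` before evaluating the expression mirrors the `compute_dtype=mx.float32` argument in the example above.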

2. Differentiable kernel with custom VJP

from zmlx import autograd, msl
import mlx.core as mx

exp_trainable = autograd.unary_from_expr(
    name="kk_exp_vjp",
    fwd_expr="metal::exp(x)",
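    vjp_expr="g * y",        # d/dx exp(x) = exp(x), so the gradient reuses the forward output y
    compute_dtype=mx.float32,
    use_output=True,         # lets vjp_expr refer to the forward output as y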
    header=msl.DEFAULT_HEADER,
)

def loss(z):
    return exp_trainable(z).sum()

x = mx.random.normal((1024,))
gx = mx.grad(loss)(x)
mx.eval(gx)
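
Before wiring a `vjp_expr` into a kernel, it can help to check the formula against a finite-difference estimate. The sketch below does this in plain Python for the two expressions used in this README: `g * y` for exp, and `g * (s + x * s * (1 - s))` for the `my_silu` example at the top. The helper names are illustrative and not part of ZMLX.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def silu(x):
    return x * sigmoid(x)


def silu_vjp(x, g):
    # Mirrors vjp_expr="g * (s + x * s * ((T)1 - s))" with s = sigmoid(x).
    s = sigmoid(x)
    return g * (s + x * s * (1.0 - s))


def exp_vjp(y, g):
    # Mirrors vjp_expr="g * y", where y is the forward output exp(x).
    return g * y


def central_diff(f, x, h=1e-5):
    # Second-order accurate numerical derivative.
    return (f(x + h) - f(x - h)) / (2.0 * h)


points = [-3.0, -0.5, 0.0, 0.7, 2.5]
max_err_silu = max(abs(silu_vjp(x, 1.0) - central_diff(silu, x)) for x in points)
max_err_exp = max(abs(exp_vjp(math.exp(x), 1.0) - central_diff(math.exp, x)) for x in points)
```

Both errors should be tiny in double precision. A large discrepancy would point to a mistake in the hand-written gradient expression rather than in the kernel machinery.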

3. Catalog kernel (ready-to-use)

from zmlx.kernels import softmax, norms, transformer
import mlx.core as mx

x = mx.random.normal((8, 1024)).astype(mx.float16)
w = mx.ones((1024,), dtype=mx.float16)

y = softmax.softmax_lastdim(x)           # rowwise softmax
z = norms.rmsnorm(x, w)                   # differentiable RMSNorm
s = transformer.swiglu(mx.random.normal((8, 2048)).astype(mx.float16))  # fused SwiGLU
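
When validating catalog kernels, a slow but obvious reference is useful for comparison. The following pure-Python sketch implements row-wise softmax and RMSNorm over nested lists. It assumes the textbook definitions; the `eps` value and the function names are assumptions, and ZMLX's kernels may differ in details such as epsilon or accumulation precision.

```python
import math


def softmax_lastdim_ref(rows):
    """Numerically stable softmax over the last dimension of a 2-D list."""
    out = []
    for row in rows:
        m = max(row)  # subtract the row max to avoid overflow in exp
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def rmsnorm_ref(rows, weight, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps) * weight. eps is an assumed default."""
    out = []
    for row in rows:
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


x = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
probs = softmax_lastdim_ref(x)
normed = rmsnorm_ref(x, [1.0, 1.0, 1.0])
```

Comparing a float16 kernel output against a reference like this needs a tolerance suited to half precision rather than exact equality.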

Kernel Catalog

17 modules, 70+ kernels (including gradient helpers) organized by domain. Full reference: docs/KERNELS.md.

| Module | Count | Highlights |
|---|---|---|
| activations | 19 | exp, sigmoid, relu, silu, gelu_tanh, softplus + grad variants |
| transformer | 10 | swiglu, geglu, rmsnorm_residual, layernorm_residual, dropout |
| vlsp | 4 | fused_recurrent_step, depth_gate_sigmoid, grpo_advantage_norm |
| softmax | 3 | softmax_lastdim, log_softmax_lastdim, softmax_grad |
| norms | 6 | rmsnorm, layernorm, rmsnorm_grad, layer_norm_dropout |
| attention | 4 | masked_softmax, scale_mask_softmax, logsumexp_lastdim |
| rope | 3 | apply_rope, apply_rope_interleaved, apply_gqa_rope |
| reductions | 7 | sum, mean, max, var, std, argmax, topk (all lastdim) |
| fused | 6 | add, mul, bias_gelu_tanh, bias_silu, silu_mul_grad, add_bias |
| linear | 4 | fused_linear_bias_silu, fused_linear_bias_gelu, fused_linear_rmsnorm |
| loss | 1 | softmax_cross_entropy |
| quant | 3 | dequantize_int8, dequantize_silu_int8, dequantize_int4 |
| bits | 2 | pack_bits, unpack_bits |
| moe | 1 | top2_gating_softmax |
| image | 2 | resize_bilinear, depthwise_conv_3x3 |
| indexing | 2 | fused_gather_add, fused_scatter_add |
| scan | 2 | cumsum_lastdim, cumsum_grad |

Architecture

Three-layer design. Full details: docs/ARCHITECTURE.md.

  1. Metal kernel infrastructure: MetalKernel wrapper, in-process cache, stats tracking
  2. Code generation & helpers: MSL templates, elementwise/autograd/rowwise APIs, autotuning
  3. Kernel catalog: 17 domain modules built on layers 1 and 2
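
The autotuning piece of layer 2 can be pictured as "measure each candidate, keep the fastest, remember the answer". Below is a minimal sketch under that assumption. `Autotuner`, `time_call`, and `fake_measure` are hypothetical names, and the cost model is made up so the example is deterministic; it does not reflect how ZMLX actually searches or stores results.

```python
import time


def time_call(fn, warmup=2, iters=10):
    """Average wall-clock seconds per call after a short warmup."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


class Autotuner:
    """Pick the fastest candidate config per key and cache the winner (illustrative)."""

    def __init__(self, measure):
        self._measure = measure  # measure(config) -> seconds
        self._winners = {}
        self.searches = 0

    def best(self, key, candidates):
        if key not in self._winners:
            self.searches += 1
            self._winners[key] = min(candidates, key=self._measure)
        return self._winners[key]


def fake_measure(threadgroup):
    # Made-up cost model: pretend a threadgroup size of 256 is fastest.
    return 1e-5 + abs(threadgroup - 256) * 1e-8


tuner = Autotuner(fake_measure)
candidates = (32, 64, 128, 256, 512, 1024)
key = ("my_silu", (1024,), "float16")  # hypothetical key: kernel name, shape, dtype
first = tuner.best(key, candidates)
second = tuner.best(key, candidates)  # served from the cache, no second search
elapsed = time_call(lambda: sum(range(1000)))
```

With real MLX workloads the measured function would need to force evaluation (for example with `mx.eval`), because MLX computes lazily and an unevaluated graph would otherwise time as nearly free.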

Benchmarks

Run python benchmarks/microbench.py to reproduce. Headline numbers on M4 (vs MLX reference ops):

  • Dropout: ~7x faster (fused Metal-side LCG RNG)
  • SwiGLU: ~2-3x faster (fused silu + gate multiply)

Results vary by shape, dtype, and chip. See benchmarks/ for the full harness.
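
For readers unfamiliar with the "LCG RNG" mentioned for dropout: a linear congruential generator is cheap enough to run inside each GPU thread, which avoids materializing a separate random mask. The sketch below shows the idea in plain Python. The constants are the well-known Numerical Recipes LCG parameters, and the per-element seeding scheme is an assumption; ZMLX's actual constants and seeding are not stated here.

```python
MASK32 = 0xFFFFFFFF


def lcg_next(state):
    # Numerical Recipes LCG constants (assumed; not taken from ZMLX).
    return (1664525 * state + 1013904223) & MASK32


def dropout_lcg(values, p, seed):
    """Inverted dropout with a per-element LCG stream (illustrative only)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    scale = 1.0 / (1.0 - p)
    out = []
    for i, v in enumerate(values):
        # Derive a per-element state from the seed and index, as a GPU thread
        # might do from its grid position (the mixing constant is an assumption).
        state = (seed & MASK32) ^ ((i * 2654435761) & MASK32)
        state = lcg_next(lcg_next(state))
        u = state / 2**32  # uniform value in [0, 1)
        out.append(v * scale if u >= p else 0.0)
    return out


values = [1.0] * 8
a = dropout_lcg(values, p=0.5, seed=1234)
b = dropout_lcg(values, p=0.5, seed=1234)
identity = dropout_lcg(values, p=0.0, seed=1234)
```

An LCG is not statistically strong, so this trades randomness quality for speed. Survivors are scaled by 1 / (1 - p) so the expected value of each element is unchanged.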


Roadmap

  • Flash Attention tiles (shared memory, 16x16 / 32x32)
  • Expanded quantization (int4 matmul, mixed-precision patterns)
  • Zig frontend via C++ shim (MLX-C once available)
  • JVP support for all catalog kernels
  • Community-contributed kernels

Contributing

See CONTRIBUTING.md for setup, testing, and conventions.


License

MIT. See LICENSE.
