ZMLX: Numba-for-MLX. Ergonomic helpers for authoring and autotuning MLX custom Metal kernels from Python.

ZMLX

Python 3.10+ | License: MIT | Platform: macOS Apple Silicon

Numba-style helpers for authoring, autotuning, and differentiating custom Metal kernels on Apple silicon.

pip install zmlx

from zmlx import autograd, msl
import mlx.core as mx

silu = autograd.unary_from_expr(
    name="my_silu", fwd_expr="x * kk_sigmoid(x)",
    vjp_expr="g * (s + x * s * ((T)1 - s))",
    compute_dtype=mx.float32, use_output=False,
    vjp_prelude="T s = kk_sigmoid(x);",
    header=msl.DEFAULT_HEADER,
)
y = silu(mx.random.normal((1024,)))  # runs on GPU, supports mx.grad
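
The vjp_expr in this example follows from the product rule. A short derivation, reading s as the sigmoid value computed in vjp_prelude and g as the incoming cotangent (our interpretation of the variable names, not something the README states):

```latex
\begin{aligned}
y &= x\,\sigma(x), \qquad \sigma'(x) = \sigma(x)\bigl(1 - \sigma(x)\bigr) \\
\frac{dy}{dx} &= \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) = s + x\,s\,(1 - s) \\
\bar{x} &= g \cdot \bigl(s + x\,s\,(1 - s)\bigr)
\end{aligned}
```

The last line matches vjp_expr="g * (s + x * s * ((T)1 - s))" term by term.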

Why ZMLX?

  • Define-once caching — kernels compile once and are reused across calls (keyed on source hash + config).
  • One-line ops — create elementwise kernels from a C expression: elementwise.unary(name=..., expr=...).
  • Differentiable kernels — attach custom VJP backward passes (themselves Metal kernels) via mx.custom_function.
  • Autotuning — search threadgroup sizes automatically and cache the winners.
  • 70+ ready-to-use catalog kernels — activations, softmax, norms, RoPE, transformer fused ops, reductions, quantization, and more.
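
To make the "define-once caching" point concrete, here is a minimal sketch of a cache keyed on a hash of the kernel source plus its config. It is plain Python with no MLX dependency; KernelCache, make_key, and fake_compile are hypothetical names for illustration and are not ZMLX's API.

```python
import hashlib
import json


class KernelCache:
    """Toy in-process cache: one build per unique (source, config) pair."""

    def __init__(self, build_fn):
        self._build_fn = build_fn  # the expensive step, e.g. compiling a kernel
        self._entries = {}
        self.hits = 0
        self.misses = 0

    @staticmethod
    def make_key(source, config):
        # Hash the source text together with a canonical (sorted-key) config dump
        payload = source + "\0" + json.dumps(config, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def get(self, source, config):
        key = self.make_key(source, config)
        if key in self._entries:
            self.hits += 1
        else:
            self.misses += 1
            self._entries[key] = self._build_fn(source, config)
        return self._entries[key]


def fake_compile(source, config):
    # Stand-in for real kernel compilation (assumption: any callable works here)
    return {"source": source, "config": dict(config)}


cache = KernelCache(fake_compile)
k1 = cache.get("out[i] = metal::exp(x);", {"threadgroup": 256})
k2 = cache.get("out[i] = metal::exp(x);", {"threadgroup": 256})  # reused
k3 = cache.get("out[i] = metal::exp(x);", {"threadgroup": 128})  # new config, new build
```

Including the config in the key means the same source built with a different threadgroup size gets its own entry rather than silently reusing the first build.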

Install

Requirements: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

pip install zmlx

From source (development):

git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"

Quick Examples

1. Elementwise kernel from a C expression

from zmlx import elementwise, msl
import mlx.core as mx

exp_fast = elementwise.unary(
    name="kk_exp",
    expr="metal::exp(x)",
    compute_dtype=mx.float32,
    header=msl.DEFAULT_HEADER,
)

x = mx.random.normal((1024,)).astype(mx.float16)
y = exp_fast(x)
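
As a rough picture of what "a kernel from a C expression" can mean, the sketch below splices an expression into a Metal Shading Language template using plain string formatting. The template layout, the buffer bindings, and render_unary are assumptions made for illustration; ZMLX's real generated source may look quite different.

```python
KERNEL_TEMPLATE = """\
{header}
// Hypothetical layout for illustration; not ZMLX's real template.
kernel void {name}(device const {io_t}* inp [[buffer(0)]],
                   device {io_t}* out [[buffer(1)]],
                   constant uint& n [[buffer(2)]],
                   uint gid [[thread_position_in_grid]]) {{
    if (gid >= n) return;            // guard the ragged final threadgroup
    typedef {compute_t} T;           // compute type, e.g. float for half inputs
    T x = (T)inp[gid];               // upcast once on load
    out[gid] = ({io_t})({expr});     // evaluate the expression, downcast on store
}}
"""


def render_unary(name, expr, io_t="half", compute_t="float",
                 header="#include <metal_stdlib>"):
    """Return MSL source text for an elementwise kernel computing `expr` of x."""
    return KERNEL_TEMPLATE.format(
        name=name, expr=expr, io_t=io_t, compute_t=compute_t, header=header
    )


source = render_unary("kk_exp", "metal::exp(x)")
print(source)
```

The io_t / compute_t split mirrors the example above, where a float16 input is evaluated with compute_dtype=mx.float32.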

2. Differentiable kernel with custom VJP

from zmlx import autograd, msl
import mlx.core as mx

exp_trainable = autograd.unary_from_expr(
    name="kk_exp_vjp",
    fwd_expr="metal::exp(x)",
    vjp_expr="g * y",
    compute_dtype=mx.float32,
    use_output=True,
    header=msl.DEFAULT_HEADER,
)

def loss(z):
    return exp_trainable(z).sum()

x = mx.random.normal((1024,))
gx = mx.grad(loss)(x)
mx.eval(gx)
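
Why vjp_expr="g * y" is enough here: the derivative of exp(x) is exp(x) itself, so the backward pass can reuse the forward output y (which is presumably what use_output=True exposes). A dependency-free sanity check of that identity against a central finite difference, with exp_fwd, exp_vjp, and central_difference as illustrative helper names:

```python
import math


def exp_fwd(x):
    return math.exp(x)


def exp_vjp(g, y):
    # Mirrors vjp_expr="g * y": d/dx exp(x) = exp(x) = y
    return g * y


def central_difference(f, x, h=1e-5):
    return (f(x + h) - f(x - h)) / (2.0 * h)


points = [-2.0, -0.5, 0.0, 0.7, 1.5]
max_rel_err = 0.0
for x in points:
    y = exp_fwd(x)
    analytic = exp_vjp(1.0, y)  # g = 1 because the loss is a plain sum
    numeric = central_difference(exp_fwd, x)
    max_rel_err = max(max_rel_err, abs(analytic - numeric) / abs(numeric))

print(f"max relative error: {max_rel_err:.2e}")
```

The same pattern (analytic VJP versus finite difference at a handful of points) is a cheap way to validate any hand-written vjp_expr before running it on the GPU.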

3. Catalog kernel (ready-to-use)

from zmlx.kernels import softmax, norms, transformer
import mlx.core as mx

x = mx.random.normal((8, 1024)).astype(mx.float16)
w = mx.ones((1024,), dtype=mx.float16)

y = softmax.softmax_lastdim(x)           # rowwise softmax
z = norms.rmsnorm(x, w)                   # differentiable RMSNorm
s = transformer.swiglu(mx.random.normal((8, 2048)).astype(mx.float16))  # fused SwiGLU
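
When testing catalog kernels it helps to have slow, obviously-correct references to compare against. The sketch below gives pure-Python versions of the three ops used above. Two things are assumptions rather than facts from this README: the RMSNorm epsilon value, and that swiglu splits the last dimension into a gate half and a value half (which would explain the 2048-wide input).

```python
import math


def softmax_lastdim_ref(rows):
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def rmsnorm_ref(rows, weight, eps=1e-6):
    # ASSUMPTION: eps=1e-6; the library default may differ
    out = []
    for row in rows:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


def swiglu_ref(rows):
    # ASSUMPTION: first half of the last dim is the gate, second half the value
    out = []
    for row in rows:
        half = len(row) // 2
        gate, up = row[:half], row[half:]
        out.append([(a / (1.0 + math.exp(-a))) * b for a, b in zip(gate, up)])
    return out


x = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]]
probs = softmax_lastdim_ref(x)
normed = rmsnorm_ref(x, [1.0] * 4)
gated = swiglu_ref(x)
```

Comparing GPU results against references like these should use a tolerance suited to float16 rather than exact equality.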

Kernel Catalog

17 modules, 70+ kernels (including gradient helpers) organized by domain. Full reference: docs/KERNELS.md.

Module | Count | Highlights
--- | --- | ---
activations | 19 | exp, sigmoid, relu, silu, gelu_tanh, softplus + grad variants
transformer | 10 | swiglu, geglu, rmsnorm_residual, layernorm_residual, dropout
softmax | 3 | softmax_lastdim, log_softmax_lastdim, softmax_grad
norms | 6 | rmsnorm, layernorm, rmsnorm_grad, layer_norm_dropout
attention | 4 | masked_softmax, scale_mask_softmax, logsumexp_lastdim
rope | 3 | apply_rope, apply_rope_interleaved, apply_gqa_rope
reductions | 7 | sum, mean, max, var, std, argmax, topk (all lastdim)
fused | 6 | add, mul, bias_gelu_tanh, bias_silu, silu_mul_grad, add_bias
linear | 4 | fused_linear_bias_silu, fused_linear_bias_gelu, fused_linear_rmsnorm
loss | 1 | softmax_cross_entropy
quant | 3 | dequantize_int8, dequantize_silu_int8, dequantize_int4
bits | 2 | pack_bits, unpack_bits
moe | 1 | top2_gating_softmax
image | 2 | resize_bilinear, depthwise_conv_3x3
indexing | 2 | fused_gather_add, fused_scatter_add
scan | 2 | cumsum_lastdim, cumsum_grad

Architecture

Three-layer design. Full details: docs/ARCHITECTURE.md.

  1. Metal kernel infrastructure — MetalKernel wrapper, in-process cache, stats tracking
  2. Code generation & helpers — MSL templates, elementwise/autograd/rowwise APIs, autotuning
  3. Kernel catalog — 17 domain modules built on layers 1 and 2
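
The autotuning piece of layer 2 can be pictured as a small search loop: time each candidate threadgroup size, keep the fastest, and remember the winner per kernel and shape. The sketch below is a generic illustration under that assumption; Autotuner, time_once, and the candidate list are hypothetical and not taken from ZMLX. A synthetic cost function stands in for real kernel launches so the example is deterministic.

```python
import time


def time_once(run, size):
    # Default measurement: wall-clock time of one call
    start = time.perf_counter()
    run(size)
    return time.perf_counter() - start


class Autotuner:
    def __init__(self, candidates, measure=time_once, repeats=3):
        self.candidates = list(candidates)
        self.measure = measure
        self.repeats = repeats
        self.winners = {}  # cache: key -> best candidate
        self.trials = 0

    def best(self, key, run):
        if key not in self.winners:
            scores = {}
            for size in self.candidates:
                # Take the minimum over repeats to reduce timing noise
                scores[size] = min(
                    self.measure(run, size) for _ in range(self.repeats)
                )
                self.trials += self.repeats
            self.winners[key] = min(scores, key=scores.get)
        return self.winners[key]


def synthetic_cost(run, size):
    # Stand-in for a real launch: pretend 256 threads is the sweet spot
    return abs(size - 256) + 1.0


tuner = Autotuner([32, 64, 128, 256, 512, 1024], measure=synthetic_cost)
first = tuner.best(("my_silu", (1024,)), run=lambda size: None)
second = tuner.best(("my_silu", (1024,)), run=lambda size: None)  # served from cache
```

Keying the winner on both kernel name and shape reflects the caveat in the Benchmarks section that the best configuration depends on shape, dtype, and chip.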

Benchmarks

Run python benchmarks/microbench.py to reproduce. Headline numbers on M4 (vs MLX reference ops):

  • Dropout: ~7x faster (fused Metal-side LCG RNG)
  • SwiGLU: ~2-3x faster (fused silu + gate multiply)

Results vary by shape, dtype, and chip. See benchmarks/ for the full harness.
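
For intuition on the dropout result, a fused kernel can derive each element's keep/drop decision from a seed and the element index, so no separate random mask array has to be generated and read back. The sketch below shows that idea in plain Python with a linear congruential generator (LCG). The constants are the well-known Numerical Recipes LCG parameters, the index mixing step and the 1/(1-p) rescaling are assumptions, and none of it is taken from ZMLX's source.

```python
LCG_A = 1664525       # Numerical Recipes LCG multiplier (assumed, not ZMLX's)
LCG_C = 1013904223    # Numerical Recipes LCG increment
MASK32 = 0xFFFFFFFF


def lcg_uniform(seed, index):
    # Mix the element index into the seed, then take one 32-bit LCG step
    state = (seed ^ (index * 2654435761)) & MASK32
    state = (LCG_A * state + LCG_C) & MASK32
    return state / 2**32  # uniform-ish value in [0, 1)


def dropout_ref(values, p, seed):
    # Inverted dropout: survivors are scaled by 1 / (1 - p)
    keep_scale = 1.0 / (1.0 - p)
    return [
        v * keep_scale if lcg_uniform(seed, i) >= p else 0.0
        for i, v in enumerate(values)
    ]


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
kept = dropout_ref(values, 0.5, seed=42)
kept_again = dropout_ref(values, 0.5, seed=42)  # same seed, same mask
```

A single LCG step is a low-quality generator; it is cheap, which suits a fused kernel, but it should not be used where statistical quality matters.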


Roadmap

  • Flash Attention tiles (shared memory, 16x16 / 32x32)
  • Expanded quantization (int4 matmul, mixed-precision patterns)
  • Zig frontend via MLX-C (multi-language kernel generation)
  • JVP support for all catalog kernels
  • Community-contributed kernels

Contributing

See CONTRIBUTING.md for setup, testing, and conventions.


License

MIT. See LICENSE.
