Skip to main content

Neighborhood Attention for Apple Silicon — PyTorch MPS backend

Project description

natten-mps

GPU-accelerated Neighborhood Attention for Apple Silicon — built on PyTorch MPS.

Disclaimer (unofficial): This is an independent, unofficial implementation/port for Apple Silicon.
Not affiliated with SHI-Labs or the upstream NATTEN project.

This is a focused, Apple-Silicon-first implementation intended to be useful, correct, and easy to install — not a replacement for upstream NATTEN on CUDA.

Neighborhood Attention was introduced by the NATTEN authors. If you use Neighborhood Attention in research, please cite the original papers (see Acknowledgments).

v0.x — API may change between minor versions. Pin your dependency for production use.


Why this exists

Upstream NATTEN is CUDA-focused and targets NVIDIA GPUs. On Apple Silicon, PyTorch users often want a GPU-accelerated neighborhood attention option without requiring CUDA.

natten-mps provides:

  • Metal-backed kernels for PyTorch MPS using torch.mps.compile_shader
  • 1D / 2D / 3D neighborhood attention with full autograd support
  • A deployment story that is intentionally simple: no native extension build step — install from PyPI and go. Metal shaders are compiled at runtime via torch.mps.compile_shader and cached by PyTorch for the process (best effort).

For MLX-based workflows, see the sibling project: natten-mlx.

Jump to: Installation | Quick start | Features | Backends | Performance | Limitations | Acknowledgments


Use natten-mps if…

  • You’re using PyTorch
  • You run on Apple Silicon and want MPS (Metal) acceleration
  • You want a drop-in-ish API (plus optional compatibility shims for historical NATTEN versions)

Installation

pip install natten-mps

Requirements:

  • Python 3.10+
  • PyTorch 2.8+ with MPS support
  • macOS 12.3+ for MPS (CPU fallback works anywhere PyTorch runs)

Quick start

Functional API

import torch
from natten_mps import na1d, na2d, na3d

# 1D: [B, L, heads, head_dim]
q = torch.randn(2, 128, 4, 32, device="mps")
k = torch.randn(2, 128, 4, 32, device="mps")
v = torch.randn(2, 128, 4, 32, device="mps")
out = na1d(q, k, v, kernel_size=7)

# 2D: [B, H, W, heads, head_dim]
q2d = torch.randn(2, 32, 32, 4, 32, device="mps")
k2d = torch.randn(2, 32, 32, 4, 32, device="mps")
v2d = torch.randn(2, 32, 32, 4, 32, device="mps")
out2d = na2d(q2d, k2d, v2d, kernel_size=7)

# 3D: [B, D, H, W, heads, head_dim]
q3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
k3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
v3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
out3d = na3d(q3d, k3d, v3d, kernel_size=3)

Module API

import torch
from natten_mps import NeighborhoodAttention2D

layer = NeighborhoodAttention2D(embed_dim=128, num_heads=4, kernel_size=(7, 7)).to("mps")
x = torch.randn(2, 32, 32, 128, device="mps")  # [B, H, W, C]
y = layer(x)

Split QK / AV (access attention weights)

import torch
from natten_mps import na1d_qk, na1d_av

B, L, H, D = 2, 128, 4, 32
q = torch.randn(B, L, H, D, device="mps")
k = torch.randn(B, L, H, D, device="mps")
v = torch.randn(B, L, H, D, device="mps")

logits = na1d_qk(q, k, kernel_size=7, scale=D ** -0.5)  # [B, L, H, K]
attn = torch.softmax(logits, dim=-1)
out = na1d_av(attn, v, kernel_size=7)                   # [B, L, H, D]

Features

Core:

  • 1D / 2D / 3D neighborhood attention (fused and split QK/AV ops)
  • Causal masking, including per-axis control (e.g. is_causal=(True, False) for 2D)
  • Strided output for downsampling (e.g. stride=2)
  • Combined causal + stride in one kernel
  • Non-uniform kernels for 2D/3D (per-axis kernel sizes and dilations)

Batching / advanced:

  • Variable-length (varlen) attention — padded batches with per-sample spatial sizes, Metal-accelerated for all ranks
  • GQA / MQA (num_kv_heads) for grouped-query attention patterns
  • additional_keys / additional_values — prepend extra global tokens that every query attends to
  • merge_attentions — numerically stable sigmoid-based merge of multiple attention outputs
  • FMHA fast path — when the kernel covers the full spatial extent, can dispatch to efficient full attention

Extras:

  • extras/ namespace for model-specific fused kernels (e.g., DiNAT-style fused QK+RPB paths)

Compatibility:


Backends

Backend dispatch is controlled at runtime and does not require a native extension.

Backend Status Description
pure Complete Pure PyTorch fallback (CPU/MPS)
metal Complete Metal compute shaders via torch.mps.compile_shader
auto Default Select best available backend for the configuration
import natten_mps

natten_mps.set_backend("metal")  # "auto" (default), "metal", or "pure"
print(natten_mps.get_backend())

Or via environment variable:

NATTEN_BACKEND=metal python my_script.py   # "auto" (default), "metal", or "pure"

Performance

Metal kernels vs pure-PyTorch backend on Apple Silicon (M-series), forward pass:

Benchmark Metal Pure Speedup
1D, L=256, K=7 0.9 ms 9.8 ms 11×
1D, L=1024, K=7 1.1 ms 37 ms 34×
2D, 32×32, K=7 1.3 ms 20 ms 15–17×
2D, 64×64, K=7 2.9 ms 84 ms 29×
2D, 32×32, K=7, causal 1.1 ms 21 ms 19×
3D, 16³, K=3 1.7 ms 12 ms

Run the full suite:

python benchmarks/bench.py
# add --backward to time backward pass

Cross-framework: natten-mps vs natten-mlx

Apple Silicon (M-series), fp32, B=1 H=4 D=32, Metal-accelerated:

Config natten-mps fwd natten-mlx fwd natten-mps bwd natten-mlx bwd
1D L=256 K=7 0.25 ms 0.21 ms 0.39 ms 0.14 ms
1D L=1024 K=7 0.40 ms 0.27 ms 0.63 ms 0.26 ms
2D 32×32 K=7 0.88 ms 0.65 ms 1.62 ms 1.02 ms
2D 64×64 K=7 1.32 ms 1.13 ms 1.55 ms 0.97 ms
2D 32×32 K=7 causal 0.37 ms 0.29 ms 0.49 ms 0.31 ms
3D 16³ K=3 0.55 ms 0.43 ms 0.89 ms 0.50 ms

MLX’s compiled primitives tend to have lower dispatch overhead than PyTorch MPS, so natten-mlx is often faster for the same shapes. Both are dramatically faster than pure-framework baselines.

Variable-length (varlen) attention

Metal-accelerated varlen forward, fp32:

Config natten-mps natten-mlx MLX speedup
varlen 1D B=4 L=128 K=7 1.74 ms 0.53 ms 3.3×
varlen 1D B=4 L=256 K=7 1.74 ms 0.51 ms 3.4×
varlen 2D B=2 16×16 K=3 2.39 ms 0.82 ms 2.9×
varlen 2D B=2 32×32 K=7 3.79 ms 1.23 ms 3.1×
varlen 3D B=2 8³ K=3 3.82 ms 1.55 ms 2.5×

Backward pass uses per-sample autograd re-differentiation through the standard Metal-accelerated na*d kernels.

Methodology

All timings on Apple M4 Max, macOS 26.3, Python 3.11, PyTorch 2.10, float32. Each kernel is warmed up for 5 iterations, then timed for 20 repetitions with torch.mps.synchronize() gating; the reported value is the median. Reproduce with python benchmarks/bench.py.


Compatibility shims

If you have downstream code written against historical upstream APIs, natten-mps includes optional shims:

import natten_mps.compat.v014 as natten014
import natten_mps.compat.v017 as natten017
import natten_mps.compat.v020 as natten020

These are best-effort drop-in replacements for common upstream natten entry points.


Extras: model-specific fused kernels

Example: fused DiNAT-style ops with relative position bias:

from natten_mps.extras.allin1 import (
    na1d_qk_rpb, na1d_av_fused,
    na2d_qk_rpb, na2d_av_fused,
)

Limitations

  • Odd kernel sizes only for accelerated Neighborhood Attention (this matches upstream NATTEN’s neighborhood half-width formulation).
  • Metal kernel acceleration has size caps tuned for performance:
    • 1D: K ≤ 63
    • 2D: K ≤ 13
    • 3D: K ≤ 7
  • Unsupported kernel sizes or configurations automatically fall back to pure.
  • Supported dtypes: Metal kernels run in float32 and float16. Bfloat16 inputs are accepted but upcast to float32 internally. Other dtypes fall back to pure.
  • MPS acceleration is macOS-only (CPU fallback works anywhere PyTorch runs).

Differences from upstream NATTEN (high level)

  • Targets Apple Silicon (PyTorch MPS + CPU fallback); no CUDA backend
  • Uses Metal compute shaders instead of CUDA kernels
  • Includes Apple-Silicon-focused extras (and optional compatibility shims)

Acknowledgments

This project implements Neighborhood Attention as introduced by the upstream NATTEN project (SHI-Labs). The original NATTEN library and research are by Ali Hassani, Steven Walton, Humphrey Shi, and collaborators.

If you use Neighborhood Attention in research, please cite the original papers:

  • Hassani et al., Neighborhood Attention Transformer (CVPR 2023)
  • Hassani & Shi, Dilated Neighborhood Attention Transformer (2022)
  • Hassani et al., Faster Neighborhood Attention (NeurIPS 2024)
BibTeX
@inproceedings{hassani2023neighborhood,
  title   = {Neighborhood Attention Transformer},
  author  = {Hassani, Ali and Walton, Steven and Li, Jiachen and Li, Shen and Shi, Humphrey},
  booktitle = {CVPR},
  year    = {2023}
}

@article{hassani2022dilated,
  title   = {Dilated Neighborhood Attention Transformer},
  author  = {Hassani, Ali and Shi, Humphrey},
  journal = {arXiv preprint arXiv:2209.15001},
  year    = {2022}
}

@inproceedings{hassani2024faster,
  title   = {Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level},
  author  = {Hassani, Ali and Ke, Wen-Mei and Gong, Jiaming and Walton, Steven and Shi, Humphrey},
  booktitle = {NeurIPS},
  year    = {2024}
}

License

MIT — see LICENSE for details.
Upstream NATTEN is also MIT-licensed.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

natten_mps-0.3.0.tar.gz (93.1 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

natten_mps-0.3.0-py3-none-any.whl (76.3 kB view details)

Uploaded Python 3

File details

Details for the file natten_mps-0.3.0.tar.gz.

File metadata

  • Download URL: natten_mps-0.3.0.tar.gz
  • Upload date:
  • Size: 93.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for natten_mps-0.3.0.tar.gz
Algorithm Hash digest
SHA256 1885e2df8e91e09e1e46144833bda191618d43fd99dd80a609e312a30a6bcec7
MD5 e9c308e3ef718418091b9cb6c4f65b20
BLAKE2b-256 64b43b5b78b6721045a3cf82f4f2011a86fb45a7dff669d354e8f5c9a74cabac

See more details on using hashes here.

Provenance

The following attestation bundles were made for natten_mps-0.3.0.tar.gz:

Publisher: release-pypi.yml on ssmall256/natten-mps

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file natten_mps-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: natten_mps-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 76.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for natten_mps-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 74a3bb97ecc394e2ef87ccc46cbca9a754a3c08de8c14f873c8b79930795d15c
MD5 67ea8128a7fc2cb9eadb9b9722793f3e
BLAKE2b-256 fb496cfe980a72a55cd683462b41c98e7b5e79faac7c5578986cda17c6d73a98

See more details on using hashes here.

Provenance

The following attestation bundles were made for natten_mps-0.3.0-py3-none-any.whl:

Publisher: release-pypi.yml on ssmall256/natten-mps

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page