
flash-kmeans-mlx

IO-aware batched K-Means for Apple Silicon, ported from Flash-KMeans (Triton/CUDA) to pure MLX.

500K points, 128 dimensions, K=1000 clustered in 0.76s on M3 Ultra -- 160x faster than sklearn. Uses custom Metal kernels for argmax, fused addmm assignment, and multi-iteration compiled execution.

Full Fashion-MNIST (70K samples, 784 dimensions, K=10) clustered in 0.12s on M3 Ultra, 6x faster than sklearn (0.74s). Left: K-Means cluster assignments. Right: ground truth labels. Visualization via mlx-vis UMAP.

Fashion-MNIST K-Means vs Ground Truth
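The "fused addmm assignment" mentioned above rests on the identity ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2: because ||x||^2 is constant per point, the nearest-center argmin needs only one matrix multiply plus a broadcast add. A NumPy sketch of the idea (illustrative only, not the Metal kernel):

```python
import numpy as np

def assign_addmm(x, centers):
    """Nearest-center assignment via ||c||^2 - 2 x @ c.T.
    The per-point ||x||^2 term is dropped because it does not
    affect the argmin over centers."""
    c_sq = (centers ** 2).sum(axis=1)        # (K,)
    scores = c_sq - 2.0 * x @ centers.T      # (N, K), one matmul + broadcast add
    return scores.argmin(axis=1)

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 16))
centers = rng.normal(size=(8, 16))

# Matches the brute-force pairwise-distance assignment.
brute = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
assert (assign_addmm(x, centers) == brute).all()
```

Skipping the ||x||^2 term also halves the memory traffic of the assignment step, which is where the IO-awareness pays off.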

Installation

uv pip install flash-kmeans-mlx

From source:

git clone https://github.com/hanxiao/flash-kmeans-mlx.git
cd flash-kmeans-mlx
uv pip install .

Usage

Functional API

import mlx.core as mx
from flash_kmeans_mlx import batch_kmeans_Euclid

x = mx.random.normal((32, 75600, 128))
cluster_ids, centers, n_iters = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)

Input shape is (B, N, D) where B is batch size, N is number of points, D is dimensionality. All batches are clustered independently in a single vectorized pass.
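As an illustration of the (B, N, D) semantics, here is a plain-NumPy sketch of one batched Lloyd iteration, with every batch updated independently (a toy reference, not the library's compiled implementation):

```python
import numpy as np

def lloyd_step_batched(x, centers):
    """One Lloyd iteration applied independently to every batch.
    x: (B, N, D), centers: (B, K, D) -> (new_centers, labels)."""
    # Pairwise squared distances per batch: (B, N, K).
    d = ((x[:, :, None, :] - centers[:, None, :, :]) ** 2).sum(-1)
    labels = d.argmin(axis=-1)                       # (B, N)
    K = centers.shape[1]
    one_hot = np.eye(K)[labels]                      # (B, N, K)
    counts = one_hot.sum(axis=1)                     # (B, K) points per cluster
    sums = one_hot.transpose(0, 2, 1) @ x            # (B, K, D) cluster sums
    # Leave empty clusters where they were.
    new_centers = np.where(counts[..., None] > 0,
                           sums / np.maximum(counts, 1)[..., None],
                           centers)
    return new_centers, labels

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 500, 8))       # B=4 independent clustering problems
centers = x[:, :10].copy()             # K=10, initialized from the first points
for _ in range(5):
    centers, labels = lloyd_step_batched(x, centers)
```

Because each batch's distances, assignments, and center updates never mix across the batch axis, running one batch alone produces the same result as running it inside the stack.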

Three distance metrics are available: batch_kmeans_Euclid, batch_kmeans_Cosine, and batch_kmeans_Dot.

Class API

from flash_kmeans_mlx import FlashKMeans

model = FlashKMeans(d=128, k=1000, niter=25, tol=1e-6)
model.fit(x)
labels = model.predict(x_new)

# or in one step
labels = model.fit_predict(x)

The FlashKMeans class accepts both (N, D) and (B, N, D) inputs. Set metric="cosine" or metric="dot" to switch distance functions.
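One plausible way to support both input shapes is to promote (N, D) to a singleton batch internally and squeeze the results on the way out; this sketch uses a hypothetical ensure_batched helper, not the package's actual internals:

```python
import numpy as np

def ensure_batched(x):
    """Promote (N, D) to (1, N, D); report whether promotion happened
    so results can be squeezed back for 2-D callers."""
    if x.ndim == 2:
        return x[None, ...], True
    if x.ndim == 3:
        return x, False
    raise ValueError(f"expected (N, D) or (B, N, D), got shape {x.shape}")

xb, promoted = ensure_batched(np.zeros((100, 16)))
assert xb.shape == (1, 100, 16) and promoted
```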

Benchmark

All timings on M3 Ultra, float32, single batch. MLX uses mx.compile; sklearn uses Lloyd's algorithm on CPU (n_init=1).

N     D    K     Iters  MLX    sklearn  Speedup
5K    64   50    10     2ms    34ms     17x
50K   128  256   20     7ms    1.28s    183x
100K  128  1000  20     32ms   9.8s     306x
500K  128  1000  10     77ms   39.8s    517x

Run the benchmark yourself:

uv pip install 'flash-kmeans-mlx[benchmark]'
python -m flash_kmeans_mlx.benchmark --n 100000 --d 128 --k 1000 --max-iters 20

vs H200 GPU

Comparison against the original Flash-KMeans and other PyTorch implementations on an NVIDIA H200 (FP16). All methods use D=128 and 100 iterations. MLX on M3 Ultra matches or beats the naive PyTorch methods on H200; the remaining gap to the Triton kernel tracks the 37x raw-compute difference (27 TFLOPS vs 990 TFLOPS).

H200 vs M3 Ultra benchmark

Correctness

Verified against sklearn with identical initial centroids over 20 iterations. Cluster-assignment agreement is 92-99.8% depending on configuration, with an inertia difference below 0.01%. The remaining discrepancy comes from float32 accumulation versus sklearn's float64: boundary points near equidistant cluster borders are assigned differently due to rounding.
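The agreement and inertia checks can be reproduced with a few lines of NumPy. This generic sketch (not the project's test harness) assumes both runs share identical initial centroids, so cluster indices line up between the two labelings:

```python
import numpy as np

def inertia(x, centers, labels):
    """Sum of squared distances from each point to its assigned center."""
    return ((x - centers[labels]) ** 2).sum()

def compare_runs(x, labels_a, centers_a, labels_b, centers_b):
    """Fraction of identical assignments and relative inertia gap."""
    agreement = (labels_a == labels_b).mean()
    ia = inertia(x, centers_a, labels_a)
    ib = inertia(x, centers_b, labels_b)
    rel_diff = abs(ia - ib) / max(ia, ib)
    return agreement, rel_diff

rng = np.random.default_rng(2)
x = rng.normal(size=(2000, 8))
centers = x[:5].copy()
labels = ((x[:, None] - centers[None]) ** 2).sum(-1).argmin(-1)

# Comparing a run against itself: full agreement, zero inertia gap.
agree, gap = compare_runs(x, labels, centers, labels, centers)
assert agree == 1.0 and gap == 0.0
```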

Distance metrics

Euclidean (squared L2), Cosine (dot product on L2-normalized vectors), and Dot-product (raw inner product).
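For unit-norm vectors, ||x - c||^2 = 2 - 2 (x.c), so Euclidean and cosine assignments coincide on L2-normalized data. A NumPy sketch of the three scoring rules (an illustration of the metric definitions above, not the library's kernels):

```python
import numpy as np

def assign(x, c, metric="euclid"):
    """Nearest-center assignment under the three supported metrics."""
    if metric == "euclid":   # minimize squared L2 distance
        return ((x[:, None] - c[None]) ** 2).sum(-1).argmin(-1)
    if metric == "cosine":   # maximize dot product of L2-normalized vectors
        xn = x / np.linalg.norm(x, axis=1, keepdims=True)
        cn = c / np.linalg.norm(c, axis=1, keepdims=True)
        return (xn @ cn.T).argmax(-1)
    if metric == "dot":      # maximize raw inner product
        return (x @ c.T).argmax(-1)
    raise ValueError(metric)

rng = np.random.default_rng(3)
x = rng.normal(size=(500, 16))
c = rng.normal(size=(6, 16))

# On pre-normalized inputs, Euclidean and cosine agree.
xn = x / np.linalg.norm(x, axis=1, keepdims=True)
cn = c / np.linalg.norm(c, axis=1, keepdims=True)
assert (assign(xn, cn, "euclid") == assign(xn, cn, "cosine")).all()
```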

Credits

This is an independent MLX port of Flash-KMeans and is not affiliated with the original authors.


License

Apache 2.0
