flash-kmeans-mlx
IO-aware batched K-Means for Apple Silicon, ported from Flash-KMeans (Triton/CUDA) to pure MLX.
Clusters 500K points (128 dimensions, K=1000) in 0.76s on M3 Ultra, 160x faster than sklearn. Uses custom Metal kernels for argmax, fused addmm assignment, and multi-iteration compiled execution.
Full Fashion-MNIST (70K samples, 784 dimensions, K=10) clustered in 0.12s on M3 Ultra, 6x faster than sklearn (0.74s). Left: K-Means cluster assignments. Right: ground truth labels. Visualization via mlx-vis UMAP.
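The "fused addmm assignment" and argmax kernel mentioned above presumably rely on the standard expansion ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2. Because ||x||^2 is the same for every centroid, the nearest centroid is the one that maximizes 2 x·c - ||c||^2, which is a matrix multiply plus a per-centroid bias followed by an argmax. The pure-Python sketch below (no MLX; `assign_direct` and `assign_expanded` are hypothetical names, not part of flash-kmeans-mlx) checks that this reformulation produces the same labels as a direct distance argmin.

```python
def assign_direct(points, centers):
    """Label each point with the index of its nearest center (squared L2)."""
    labels = []
    for x in points:
        dists = [sum((xi - ci) ** 2 for xi, ci in zip(x, c)) for c in centers]
        labels.append(min(range(len(centers)), key=dists.__getitem__))
    return labels


def assign_expanded(points, centers):
    """Same labels via argmax of 2*x.c - ||c||^2 (the ||x||^2 term is dropped)."""
    # ||c||^2 is computed once per center and reused for every point.
    c_sq = [sum(ci * ci for ci in c) for c in centers]
    labels = []
    for x in points:
        # One row of the "addmm": 2 * (x @ C^T) - ||c||^2
        scores = [
            2 * sum(xi * ci for xi, ci in zip(x, c)) - c_sq[k]
            for k, c in enumerate(centers)
        ]
        labels.append(max(range(len(centers)), key=scores.__getitem__))
    return labels


points = [[0.0, 0.0], [1.0, 0.2], [4.8, 5.1], [5.0, 4.0], [0.3, 4.9]]
centers = [[0.5, 0.0], [5.0, 5.0], [0.0, 5.0]]
print(assign_direct(points, centers))    # [0, 0, 1, 1, 2]
print(assign_expanded(points, centers))  # [0, 0, 1, 1, 2]
```

Dropping ||x||^2 is what makes the assignment a single matmul-shaped operation, which is the part a GPU kernel can fuse.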
Installation
```bash
uv pip install flash-kmeans-mlx
```
From source:
```bash
git clone https://github.com/hanxiao/flash-kmeans-mlx.git
cd flash-kmeans-mlx
uv pip install .
```
Usage
Functional API
```python
import mlx.core as mx
from flash_kmeans_mlx import batch_kmeans_Euclid

x = mx.random.normal((32, 75600, 128))
cluster_ids, centers, n_iters = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)
```
Input shape is (B, N, D) where B is batch size, N is number of points, D is dimensionality. All batches are clustered independently in a single vectorized pass.
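To make "clustered independently" concrete, here is a minimal pure-Python reference of Lloyd's algorithm applied separately to each (N, D) slice of a (B, N, D) input. It is a sketch, not the library's implementation: `lloyd_single` and `batch_lloyd` are hypothetical names, the first-k-points initialization and the "max squared center shift below tol" stopping rule are assumptions, and it returns one iteration count per batch element rather than a single `n_iters`.

```python
def lloyd_single(points, k, n_iters=20, tol=1e-4):
    """Plain Lloyd's K-Means on one (N, D) set given as nested lists.

    Assumptions: centers start at the first k points, and convergence is
    declared when no center moves by tol or more (squared distance).
    """
    centers = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for it in range(n_iters):
        # Assignment step: nearest center by squared Euclidean distance.
        labels = [
            min(range(k), key=lambda j: sum((a - b) ** 2 for a, b in zip(p, centers[j])))
            for p in points
        ]
        # Update step: move each center to the mean of its members.
        new_centers = []
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:
                new_centers.append([sum(col) / len(members) for col in zip(*members)])
            else:
                new_centers.append(centers[j])  # leave empty clusters in place
        shift = max(
            sum((a - b) ** 2 for a, b in zip(old, new))
            for old, new in zip(centers, new_centers)
        )
        centers = new_centers
        if shift < tol:
            break
    return labels, centers, it + 1


def batch_lloyd(x, k, n_iters=20, tol=1e-4):
    """Cluster each (N, D) slice of a (B, N, D) input independently."""
    results = [lloyd_single(points, k, n_iters, tol) for points in x]
    cluster_ids = [r[0] for r in results]
    centers = [r[1] for r in results]
    iters = [r[2] for r in results]  # one count per batch element
    return cluster_ids, centers, iters


x = [
    [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]],    # batch 0
    [[0.0, 0.0], [10.0, 0.0], [0.0, 0.2], [10.0, 0.2]],  # batch 1
]
cluster_ids, centers, iters = batch_lloyd(x, k=2)
print(cluster_ids)  # [[0, 0, 1, 1], [0, 1, 0, 1]]
print(iters)        # [3, 2]
```

The two batch elements never share centers or labels; a vectorized implementation runs the same per-batch computation for all B slices at once instead of looping.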
Three distance metrics are available: batch_kmeans_Euclid, batch_kmeans_Cosine, and batch_kmeans_Dot.
Class API
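```python
import mlx.core as mx
from flash_kmeans_mlx import FlashKMeans

x = mx.random.normal((100000, 128))    # (N, D) training points
x_new = mx.random.normal((1000, 128))  # new points to label with the fitted centroids

model = FlashKMeans(d=128, k=1000, niter=25, tol=1e-6)
model.fit(x)
labels = model.predict(x_new)
# or in one step
labels = model.fit_predict(x)
```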
The FlashKMeans class accepts both (N, D) and (B, N, D) inputs. Set metric="cosine" or metric="dot" to switch distance functions.
Benchmark
All timings on M3 Ultra, float32, single batch. MLX uses mx.compile; sklearn uses Lloyd's algorithm on CPU (n_init=1).
| N | D | K | Iters | MLX time | sklearn time | Speedup |
|---|---|---|---|---|---|---|
| 5K | 64 | 50 | 10 | 2ms | 34ms | 17x |
| 50K | 128 | 256 | 20 | 7ms | 1.28s | 183x |
| 100K | 128 | 1000 | 20 | 32ms | 9.8s | 306x |
| 500K | 128 | 1000 | 10 | 77ms | 39.8s | 517x |
Run the benchmark yourself:
```bash
uv pip install 'flash-kmeans-mlx[benchmark]'
python -m flash_kmeans_mlx.benchmark --n 100000 --d 128 --k 1000 --max-iters 20
```
vs H200 GPU
Comparison against the original Flash-KMeans and other PyTorch implementations on NVIDIA H200 (FP16). All methods use D=128 and 100 iterations. MLX on M3 Ultra matches or beats the naive PyTorch implementations running on H200; the remaining gap to the Triton kernel is in line with the roughly 37x difference in raw compute (27 TFLOPS vs 990 TFLOPS).
Correctness
Verified against sklearn using identical initial centroids over 20 iterations. Cluster assignment agreement ranges from 92% to 99.8% depending on configuration, and inertia differs by less than 0.01%. The remaining discrepancy comes from float32 arithmetic in MLX versus sklearn's float64 accumulation: points lying near the border between two nearly equidistant clusters can be assigned differently because of rounding.
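The two numbers quoted above can be computed along the lines of the sketch below. The exact definitions used in the project's verification are not stated here, so treat these as assumptions: agreement is taken as the fraction of points with the same cluster index in both runs (meaningful only because both runs share initial centroids), and the inertia difference as a relative difference. All function names are hypothetical.

```python
def assignment_agreement(labels_a, labels_b):
    """Fraction of points that both runs put in the same cluster index.

    Only meaningful when both runs started from identical initial centroids,
    so that cluster j denotes the same cluster in each run.
    """
    if len(labels_a) != len(labels_b):
        raise ValueError("label lists must have the same length")
    same = sum(1 for a, b in zip(labels_a, labels_b) if a == b)
    return same / len(labels_a)


def inertia(points, labels, centers):
    """Sum of squared Euclidean distances from each point to its assigned center."""
    return sum(
        sum((p_i - c_i) ** 2 for p_i, c_i in zip(p, centers[lab]))
        for p, lab in zip(points, labels)
    )


def relative_diff(value, reference):
    """|value - reference| / |reference|; 1e-4 corresponds to a 0.01% difference."""
    return abs(value - reference) / abs(reference)


points = [[0.0, 0.0], [2.0, 0.0], [10.0, 0.0], [12.0, 0.0]]
centers = [[1.0, 0.0], [11.0, 0.0]]
print(assignment_agreement([0, 0, 1, 1], [0, 1, 1, 1]))  # 0.75
print(inertia(points, [0, 0, 1, 1], centers))            # 4.0
```

Because boundary points contribute almost the same squared distance to either of two nearly equidistant centers, a run can disagree on several percent of labels while its inertia stays within a tiny relative difference.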
Distance metrics
Euclidean (squared L2), Cosine (dot product on L2-normalized vectors), and Dot-product (raw inner product).
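A small pure-Python sketch of how the three metrics rank centroids, written so that every score is "higher is better" and assignment is always an argmax. The `"euclid"` key and the helper names are hypothetical; only `"cosine"` and `"dot"` mirror the `metric=` values mentioned for the class API. The example points are chosen so that each metric picks differently.

```python
import math


def sq_l2(x, c):
    return sum((a - b) ** 2 for a, b in zip(x, c))


def dot(x, c):
    return sum(a * b for a, b in zip(x, c))


def l2_normalize(v):
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v] if norm > 0 else list(v)


# Each score is "higher is better", so assignment is always an argmax.
SCORES = {
    "euclid": lambda x, c: -sq_l2(x, c),                           # squared L2, negated
    "cosine": lambda x, c: dot(l2_normalize(x), l2_normalize(c)),  # dot of unit vectors
    "dot": lambda x, c: dot(x, c),                                 # raw inner product
}


def assign(points, centers, metric="euclid"):
    score = SCORES[metric]
    return [max(range(len(centers)), key=lambda j: score(x, centers[j])) for x in points]


points = [[3.0, 3.0], [1.0, 0.0]]
centers = [[0.2, 0.2], [3.0, 2.0]]
for metric in ("euclid", "cosine", "dot"):
    print(metric, assign(points, centers, metric))
# euclid [1, 0]
# cosine [0, 1]
# dot [1, 1]
```

Cosine ignores vector length entirely, while raw dot product rewards centers with a large norm, which is why `[3.0, 2.0]` wins every point under `"dot"`.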
Credits
This is an independent MLX port of Flash-KMeans and is not affiliated with the original authors.
Papers:
- Flash-KMeans: IO-Aware Batched K-Means (Shuo Yang et al.)
- Sparse VideoGen2 (Haocheng Xi, Shuo Yang et al.)
License
Apache 2.0
File details
Details for the file flash_kmeans_mlx-0.1.1.tar.gz.
File metadata
- Download URL: flash_kmeans_mlx-0.1.1.tar.gz
- Size: 17.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2cff8b4a71f9aae4dcf139fc8512f4c962659a8a99bba145333c7569218b850c |
| MD5 | 513e228b44c713f3f0f0d4b5f48f36e0 |
| BLAKE2b-256 | 6d2fa00b65b09517f487681e67e78ba20806bf5112ce391529a32deff0900a52 |
File details
Details for the file flash_kmeans_mlx-0.1.1-py3-none-any.whl.
File metadata
- Download URL: flash_kmeans_mlx-0.1.1-py3-none-any.whl
- Size: 17.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.8
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c7cd63e3599d1bf2918c6366d2b2ff217e16c1b95111d2ce84fb7af7fac415ba |
| MD5 | 95239c9c945f6e61e04c4d74d6404fce |
| BLAKE2b-256 | 178fd8357bb350ce5e41d46bc850724eb30917e1d958e8b6604851f5e298bee9 |