Fast batched K-Means clustering with Triton GPU kernels
Flash-KMeans
Fast batched K-Means clustering implemented with Triton GPU kernels. This repository provides the official K-Means implementation of Sparse VideoGen2.
Installation
Install flash-kmeans with pip:
```bash
pip install flash-kmeans
```
From source:
```bash
git clone https://github.com/svg-project/flash-kmeans.git
cd flash-kmeans
pip install -e .
```
Usage
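```python
import torch
from flash_kmeans import batch_kmeans_Euclid

# Input layout: (batch, points, dim) = 32 problems of 75,600 FP16 points in 128 dimensions, on the GPU.
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)

# Run Euclidean K-Means with 1000 clusters; the third return value is not used here.
cluster_ids, centers, _ = batch_kmeans_Euclid(x, n_clusters=1000, tol=1e-4, verbose=True)
```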
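To make the computation concrete, here is a minimal plain-PyTorch sketch of batched Lloyd's K-Means with Euclidean distance. It is not the library's Triton kernel and says nothing about its internals; the function name `batched_kmeans_reference`, the random-point initialisation and the fixed iteration count are assumptions chosen for illustration. It runs on CPU or GPU.

```python
import torch


def batched_kmeans_reference(x: torch.Tensor, n_clusters: int, n_iters: int = 20, seed: int = 0):
    """Naive batched Lloyd's K-Means (Euclidean) in plain PyTorch (hypothetical helper).

    x has shape (b, n, d). Returns (ids, centers) with shapes (b, n) and
    (b, n_clusters, d). Illustration only: it materialises a (b, n, k)
    distance tensor on every iteration.
    """
    b, n, d = x.shape
    x = x.float()  # accumulate in FP32 even if the input is FP16
    gen = torch.Generator().manual_seed(seed)
    # Initialise each problem's centers with n_clusters distinct random points.
    init_idx = torch.stack(
        [torch.randperm(n, generator=gen)[:n_clusters] for _ in range(b)]
    ).to(x.device)
    centers = torch.gather(x, 1, init_idx.unsqueeze(-1).expand(-1, -1, d))
    for _ in range(n_iters):
        # Assignment step: index of the nearest center for every point, shape (b, n).
        ids = torch.cdist(x, centers).argmin(dim=-1)
        # Update step: per-cluster sums and counts via scatter_add_.
        sums = torch.zeros_like(centers).scatter_add_(
            1, ids.unsqueeze(-1).expand(-1, -1, d), x
        )
        counts = torch.zeros(b, n_clusters, device=x.device).scatter_add_(
            1, ids, torch.ones_like(ids, dtype=x.dtype)
        )
        new_centers = sums / counts.clamp(min=1).unsqueeze(-1)
        # Empty clusters keep their previous center.
        centers = torch.where(counts.unsqueeze(-1) > 0, new_centers, centers)
    # Final assignment against the final centers.
    ids = torch.cdist(x, centers).argmin(dim=-1)
    return ids, centers
```

Unlike the `tol` argument in the usage example, this sketch simply runs a fixed number of iterations. Because the distance tensor grows as b·n·k, it is only practical for small inputs.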
We also provide an API similar to those of faiss and scikit-learn; see the API docs for details.
Benchmark
We compare the performance of our Triton implementation with the following baselines:
- fast_pytorch_kmeans: a PyTorch implementation of K-Means clustering.
- fastkmeans(triton) / fastkmeans(torch): another Triton implementation of K-Means clustering, and its PyTorch fallback.
- flash-kmeans(triton) / flash-kmeans(torch): our implementation in Triton, and its PyTorch fallback.
- batched torch kmeans: a naive batched implementation that does not guard against out-of-memory (OOM) errors.
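As a rough illustration of why a naive batched implementation can run out of memory: if every point-to-center distance is materialised at once (an assumption about the naive baseline, which is not described in more detail here), the distance tensor alone needs b·n·k elements. The helper `distance_matrix_bytes` is hypothetical.

```python
def distance_matrix_bytes(b: int, n: int, k: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to hold a full (b, n, k) distance tensor.

    Assumes all point-to-center distances are stored at once;
    bytes_per_elem=2 corresponds to FP16, 4 to FP32.
    """
    return b * n * k * bytes_per_elem


# Sizes from the usage example: b=32, n=75600, k=1000, stored in FP16.
gb = distance_matrix_bytes(32, 75600, 1000) / 1e9
print(f"{gb:.2f} GB")  # 4.84 GB
```

This counts only the distance tensor; any intermediate buffers a real implementation allocates would come on top of it.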
Tested on an NVIDIA H200 GPU in FP16 precision with 128-dimensional data, varying the number of clusters (k), data points (n) and batch size (b). Our Triton implementation delivers significant performance improvements over these baselines.
Note: fastkmeans(triton) fails with an error when k=100 or k=1000 in Figure 1.
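The timing methodology behind these figures (warm-up runs, number of repetitions, timer) is not described here. The sketch below shows one common way to time GPU code from Python; the helper `time_fn` and its defaults are assumptions, not the harness used for these benchmarks.

```python
import time

import torch


def time_fn(fn, *args, warmup: int = 3, iters: int = 10, **kwargs) -> float:
    """Return the mean wall-clock seconds per call of fn(*args, **kwargs).

    Warm-up calls absorb one-off costs such as kernel compilation, and
    torch.cuda.synchronize() (when CUDA is available) waits for queued
    GPU work to finish before the clock is read.
    """
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

For example, `time_fn(batch_kmeans_Euclid, x, n_clusters=1000, tol=1e-4)` would time the call from the usage example. Note that K-Means run time also depends on how many iterations are needed to reach `tol`, so comparisons are fairest on identical inputs.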
Large tensor Benchmark
For inputs too large to fit in GPU memory, we compare against fastkmeans(triton) in FP32 precision with 128-dimensional data, scaling the number of data points from 256K to 268M (N = 2^18, 2^20, 2^22, 2^24, 2^26, 2^28) with the number of clusters set to K = √N (512, 1024, 2048, 4096, 8192, 16384).
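The N/K grid above can be reproduced with a little arithmetic. The third value is an estimate of the raw size of a dense N × 128 FP32 input, ignoring centers, distances and any other buffers; the helper name `large_benchmark_configs` is hypothetical.

```python
import math


def large_benchmark_configs(d: int = 128, bytes_per_elem: int = 4):
    """Return (N, K, input size in GiB) for N = 2^18, 2^20, ..., 2^28 and K = sqrt(N)."""
    configs = []
    for exp in range(18, 30, 2):
        n = 2 ** exp
        k = math.isqrt(n)  # exact square root because exp is even
        gib = n * d * bytes_per_elem / 2**30  # dense (N, d) input only
        configs.append((n, k, gib))
    return configs


for n, k, gib in large_benchmark_configs():
    print(f"N={n:>11,d}  K={k:>5d}  input={gib:8.3f} GiB")
```

The input alone grows from 0.125 GiB at N = 2^18 to 128 GiB at N = 2^28.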
The input tensor is randomly generated in pinned CPU memory. Both flash-kmeans and fastkmeans transfer the data from CPU to GPU in chunks and compute on each chunk.
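A chunked CPU-to-GPU scheme of this kind can be sketched for a single, unbatched Lloyd iteration as follows. This is an assumption-laden illustration, not how flash-kmeans or fastkmeans actually schedule their transfers; the function name `chunked_lloyd_step` and the `chunk_size` default are hypothetical.

```python
import torch


def chunked_lloyd_step(x_cpu: torch.Tensor, centers: torch.Tensor, chunk_size: int = 1 << 20):
    """One Lloyd iteration over a CPU-resident (n, d) tensor, streamed in chunks.

    Each chunk is copied to the centers' device, assigned to its nearest
    center, and folded into per-cluster sums and counts, so only one chunk
    of points is resident on the device at a time.
    """
    device = centers.device
    k, d = centers.shape
    sums = torch.zeros(k, d, device=device, dtype=torch.float32)
    counts = torch.zeros(k, device=device, dtype=torch.float32)
    for start in range(0, x_cpu.shape[0], chunk_size):
        # non_blocking=True makes the host-to-device copy asynchronous only if x_cpu is pinned.
        chunk = x_cpu[start:start + chunk_size].to(device, non_blocking=True).float()
        ids = torch.cdist(chunk, centers.float()).argmin(dim=1)
        sums.index_add_(0, ids, chunk)
        counts.index_add_(0, ids, torch.ones_like(ids, dtype=torch.float32))
    new_centers = sums / counts.clamp(min=1).unsqueeze(1)
    # Empty clusters keep their previous center.
    return torch.where(counts.unsqueeze(1) > 0, new_centers, centers.float())
```

To mirror the benchmark setup, the input would be allocated with `torch.randn(n, 128).pin_memory()` (pinning requires a CUDA-capable setup) and the centers placed on the GPU. Because sums and counts are accumulated across chunks, the result matches a single-pass update up to floating-point rounding.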
Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```bibtex
@article{yang2025sparse,
  title={Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation},
  author={Yang, Shuo and Xi, Haocheng and Zhao, Yilong and Li, Muyang and Zhang, Jintao and Cai, Han and Lin, Yujun and Li, Xiuyu and Xu, Chenfeng and Peng, Kelly and others},
  journal={arXiv preprint arXiv:2505.18875},
  year={2025}
}
```
Download files
File details
Details for the file flash_kmeans-0.2.0.tar.gz.
File metadata
- Download URL: flash_kmeans-0.2.0.tar.gz
- Upload date:
- Size: 29.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | bce83c7a57cfb7e6af95cd169940b5aa8b67db7dfa89094f6bcc1e997a1ad59d |
| MD5 | d0faecf3454a4470159adef6042b7ecb |
| BLAKE2b-256 | 90a91fd408faa36c6ce5b6e8e98fb7f56c4e2131754dc1f304751671fe1dc175 |
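To check a downloaded archive against the SHA256 digest listed above, the digest can be recomputed with Python's standard `hashlib`. The helper name `file_sha256` is hypothetical; `EXPECTED_SHA256` is the source distribution's SHA256 from the table.

```python
import hashlib
from pathlib import Path


def file_sha256(path: str, block_size: int = 1 << 16) -> str:
    """Stream a file through SHA-256 in fixed-size blocks and return the hex digest."""
    h = hashlib.sha256()
    with Path(path).open("rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            h.update(block)
    return h.hexdigest()


# SHA256 of flash_kmeans-0.2.0.tar.gz, copied from the table above.
EXPECTED_SHA256 = "bce83c7a57cfb7e6af95cd169940b5aa8b67db7dfa89094f6bcc1e997a1ad59d"
```

For an untampered download, `file_sha256("flash_kmeans-0.2.0.tar.gz") == EXPECTED_SHA256` should hold. The command `pip hash flash_kmeans-0.2.0.tar.gz` prints the same SHA256 digest.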
File details
Details for the file flash_kmeans-0.2.0-py3-none-any.whl.
File metadata
- Download URL: flash_kmeans-0.2.0-py3-none-any.whl
- Upload date:
- Size: 30.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 3b12be4120c04dbec252ce01d4dc663c6943af12d8878dbb2a57cae6c57f4557 |
| MD5 | 2e2fdd34c9825bbd49d5805d9daf33f5 |
| BLAKE2b-256 | 9feea6095088ef5c1f0db62c46ec26bb04e42411423b5ac2adbc74c148178afb |