Fast batched K-Means clustering with Triton GPU kernels

Flash-KMeans

Fast batched K-Means clustering implemented with Triton GPU kernels. This repository provides the official K-Means implementation of Sparse VideoGen2.

Installation

Install flash-kmeans with pip:

pip install flash-kmeans

From source:

git clone https://github.com/svg-project/flash-kmeans.git
cd flash-kmeans
pip install -e .

Usage

import torch
from flash_kmeans import batch_kmeans_Euclid

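# A batch of 32 point sets, each with 75,600 points of dimension 128, in FP16 on the GPU
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
# Cluster each point set into n_clusters=1000 clusters; tol is the convergence tolerance
cluster_ids, centers, _ = batch_kmeans_Euclid(x, n_clusters=1000, tol=1e-4, verbose=True)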

We also provide an API similar to faiss/sklearn; see the API docs for details.
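To make concrete what batched clustering computes, here is a minimal NumPy reference of Lloyd's algorithm run independently on each of the B point sets. It is a sketch under our own assumptions, not the flash-kmeans Triton kernels: the function name `batched_kmeans_reference`, the random initialisation, the "maximum center shift below `tol`" stopping rule, and the output shapes (B, N) and (B, K, D) are illustrative choices. The assignment step uses the expansion ||x - c||² = ||x||² - 2x·c + ||c||², which turns the bulk of the work into one batched matrix multiply.

```python
import numpy as np


def batched_kmeans_reference(x, n_clusters, max_iters=100, tol=1e-4, seed=0):
    """Lloyd's K-Means run independently on each of the B point sets in x.

    x has shape (B, N, D). Returns cluster_ids of shape (B, N) and centers of
    shape (B, K, D). Hypothetical reference code, not the flash-kmeans kernels.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=np.float32)
    B, N, D = x.shape
    # Initialise each batch's centers with K distinct points drawn from that batch.
    init_idx = np.stack([rng.choice(N, n_clusters, replace=False) for _ in range(B)])
    centers = x[np.arange(B)[:, None], init_idx]  # (B, K, D)

    for _ in range(max_iters):
        # Assignment: ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, so the heavy work
        # is one batched matrix multiply of shape (B, N, D) @ (B, D, K).
        x_sq = (x ** 2).sum(-1, keepdims=True)                       # (B, N, 1)
        c_sq = (centers ** 2).sum(-1)[:, None, :]                    # (B, 1, K)
        dist = x_sq - 2.0 * (x @ centers.transpose(0, 2, 1)) + c_sq  # (B, N, K)
        cluster_ids = dist.argmin(-1)                                # (B, N)

        # Update: each center moves to the mean of its assigned points;
        # a cluster that received no points keeps its previous center.
        new_centers = centers.copy()
        for b in range(B):
            sums = np.zeros((n_clusters, D), dtype=np.float32)
            np.add.at(sums, cluster_ids[b], x[b])
            counts = np.bincount(cluster_ids[b], minlength=n_clusters)
            nonempty = counts > 0
            new_centers[b, nonempty] = sums[nonempty] / counts[nonempty, None]

        # Stop when no center moved by more than tol (Euclidean distance).
        shift = np.linalg.norm(new_centers - centers, axis=-1).max()
        centers = new_centers
        if shift < tol:
            break
    return cluster_ids, centers
```

This is only a readable reference for small inputs; it makes no claim about how batch_kmeans_Euclid initialises centers or interprets tol.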

Benchmark

We compare the performance of the following implementations:

  • fast_pytorch_kmeans: a PyTorch implementation of K-Means clustering.
  • fastkmeans(triton) / fastkmeans(torch): another Triton implementation of K-Means clustering and its PyTorch fallback.
  • flash-kmeans(triton) / flash-kmeans(torch): our Triton implementation and its PyTorch fallback.
  • batched torch kmeans: a naive batched implementation that does not account for out-of-memory (OOM) limits.

Tested on an NVIDIA H200 GPU with FP16 precision and 128-dimensional data, varying the number of clusters (k), data points (n), and batch size (b). Across these settings, our Triton implementation delivers significant speedups over the baselines.

[Figures: Benchmark result 1, Benchmark result 2]

Note: fastkmeans(triton) fails with an error at k=100 and k=1000 in Figure 1.
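A rough calculation suggests why the naive batched torch kmeans baseline can run out of memory: materialising every point-to-center distance needs a dense B × N × K tensor. The sketch below counts only that tensor and the input, assuming both are stored densely (in FP16 unless stated); real peak memory also includes centers and temporaries. The helper name is our own.

```python
def dense_tensor_bytes(*shape, bytes_per_elem=2):
    """Bytes needed to store a dense tensor of the given shape."""
    total = bytes_per_elem
    for dim in shape:
        total *= dim
    return total


B, N, D, K = 32, 75600, 128, 1000          # shapes from the Usage example
input_fp16 = dense_tensor_bytes(B, N, D)   # the points themselves
dist_fp16 = dense_tensor_bytes(B, N, K)    # one distance per (point, center) pair
dist_fp32 = dense_tensor_bytes(B, N, K, bytes_per_elem=4)

print(f"input (FP16):     {input_fp16 / 1e9:.2f} GB")  # 0.62 GB
print(f"distances (FP16): {dist_fp16 / 1e9:.2f} GB")   # 4.84 GB
print(f"distances (FP32): {dist_fp32 / 1e9:.2f} GB")   # 9.68 GB
```

Here the distance tensor is K/D ≈ 7.8 times larger than the input, and it grows linearly in each of B, N, and K. A common remedy is to process points in tiles so that only a slice of the distance matrix exists at any time.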

Large Tensor Benchmark

For inputs too large to fit in GPU memory, we compare against fastkmeans(triton) using FP32 precision and 128-dimensional data, with the number of data points scaling from 256K to 268M (N = 2^18, 2^20, 2^22, 2^24, 2^26, 2^28) and the number of clusters set to K = √N (512, 1024, 2048, 4096, 8192, 16384).
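The sweep parameters can be checked with a few lines of arithmetic. This sketch assumes the input is a dense FP32 (N, 128) tensor and ignores centers and working buffers; the function name is our own.

```python
import math


def large_n_sweep(dim=128, bytes_per_elem=4):
    """(N, K, input size in GiB) for N = 2^18, 2^20, ..., 2^28 and K = sqrt(N).

    The size counts only a dense FP32 (N, dim) input, not centers or buffers.
    """
    rows = []
    for exp in range(18, 29, 2):
        n = 2 ** exp
        k = math.isqrt(n)  # exact, because every N in the sweep is a perfect square
        rows.append((n, k, n * dim * bytes_per_elem / 2 ** 30))
    return rows


for n, k, gib in large_n_sweep():
    print(f"N = {n:>11,}   K = {k:>6,}   input = {gib:7.3f} GiB")
```

The input grows from 0.125 GiB to 128 GiB (about 137 GB) across the sweep, so at the top end the input alone approaches the roughly 141 GB of memory on an H200 before any working buffers are counted.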

The input tensor is generated randomly in CPU pinned memory. Both flash-kmeans and fastkmeans transfer the data from CPU to GPU in chunks and compute on each chunk.
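Chunked processing does not change the result of a Lloyd iteration, because the update only needs per-cluster sums and counts, which can be accumulated chunk by chunk. The NumPy sketch below illustrates this; it is not the flash-kmeans or fastkmeans code, the function name and `chunk_size` parameter are hypothetical, and the host-to-device copy appears only as a comment. In PyTorch that copy would typically be `chunk.to("cuda", non_blocking=True)`, which can run asynchronously when the source tensor is in pinned memory.

```python
import numpy as np


def chunked_lloyd_step(x_host, centers, chunk_size):
    """One Lloyd iteration over x_host (N, D), processed chunk_size rows at a time.

    Only per-cluster running sums and counts survive between chunks, so at most a
    (chunk_size, K) block of distances exists at any moment.
    """
    k, d = centers.shape
    sums = np.zeros((k, d), dtype=np.float64)
    counts = np.zeros(k, dtype=np.int64)
    c_sq = (centers ** 2).sum(-1)  # (K,)
    for start in range(0, x_host.shape[0], chunk_size):
        # A GPU version would copy this slice host-to-device here (hypothetical step).
        chunk = x_host[start:start + chunk_size]
        dist = (chunk ** 2).sum(-1, keepdims=True) - 2.0 * (chunk @ centers.T) + c_sq
        ids = dist.argmin(-1)
        np.add.at(sums, ids, chunk)              # accumulate per-cluster sums
        counts += np.bincount(ids, minlength=k)  # and per-cluster counts
    new_centers = centers.astype(np.float64)     # copy; empty clusters keep old center
    nonempty = counts > 0
    new_centers[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new_centers
```

Calling it with a small `chunk_size` and with `chunk_size` equal to N should give the same centers, up to floating-point rounding.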

[Figure: benchmark for large N]

Citation

If you use this codebase, or otherwise found our work valuable, please cite:

@article{yang2025sparse,
  title={Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation},
  author={Yang, Shuo and Xi, Haocheng and Zhao, Yilong and Li, Muyang and Zhang, Jintao and Cai, Han and Lin, Yujun and Li, Xiuyu and Xu, Chenfeng and Peng, Kelly and others},
  journal={arXiv preprint arXiv:2505.18875},
  year={2025}
}

Download files

Source Distribution

flash_kmeans-0.2.0.tar.gz (29.1 kB)

Built Distribution

flash_kmeans-0.2.0-py3-none-any.whl (30.5 kB)

File details

Details for the file flash_kmeans-0.2.0.tar.gz.

File metadata

  • Download URL: flash_kmeans-0.2.0.tar.gz
  • Upload date:
  • Size: 29.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.11

File hashes

Hashes for flash_kmeans-0.2.0.tar.gz
  • SHA256: bce83c7a57cfb7e6af95cd169940b5aa8b67db7dfa89094f6bcc1e997a1ad59d
  • MD5: d0faecf3454a4470159adef6042b7ecb
  • BLAKE2b-256: 90a91fd408faa36c6ce5b6e8e98fb7f56c4e2131754dc1f304751671fe1dc175
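To check a downloaded file against the SHA256 digests listed on this page, Python's standard hashlib module is enough. The helper names below are our own; the only assumption is that the file has already been downloaded to a local path.

```python
import hashlib


def sha256_of_file(path, block_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in 1 MiB blocks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """True if the file's SHA-256 equals the published digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, `matches_published_hash("flash_kmeans-0.2.0.tar.gz", "bce83c7a57cfb7e6af95cd169940b5aa8b67db7dfa89094f6bcc1e997a1ad59d")` should return True for an intact source distribution. pip can enforce the same check itself in hash-checking mode, by adding `--hash=sha256:<digest>` to the requirement line and installing with `pip install --require-hashes -r requirements.txt`.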

File details

Details for the file flash_kmeans-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: flash_kmeans-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 30.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.11

File hashes

Hashes for flash_kmeans-0.2.0-py3-none-any.whl
  • SHA256: 3b12be4120c04dbec252ce01d4dc663c6943af12d8878dbb2a57cae6c57f4557
  • MD5: 2e2fdd34c9825bbd49d5805d9daf33f5
  • BLAKE2b-256: 9feea6095088ef5c1f0db62c46ec26bb04e42411423b5ac2adbc74c148178afb
