Fast batched K-Means clustering with Triton GPU kernels

Flash-KMeans

Fast batched K-Means clustering implemented with Triton GPU kernels. This repository provides the official K-Means implementation of Sparse VideoGen2.

Teaser

Installation

Clone the repository and install in editable mode:

git clone https://github.com/svg-project/flash-kmeans.git
cd flash-kmeans
pip install -e .

Usage

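import torch
from flash_kmeans import batch_kmeans_Euclid

# Input laid out as (batch, points, dim): 32 batches of 75,600 points in 128-D, FP16 on the GPU.
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
# Cluster every batch element into 1,000 clusters; the third return value is unused here.
cluster_ids, centers, _ = batch_kmeans_Euclid(x, n_clusters=1000, tol=1e-4, verbose=True)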
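
For intuition about what a batched Euclidean K-Means call computes, here is a minimal pure-PyTorch sketch of Lloyd's algorithm. It is not the Triton implementation and not part of the flash_kmeans API: the name `reference_batch_kmeans`, its parameters, the random-point initialization, and the empty-cluster handling are all assumptions made for illustration. The (batch, points, dim) layout is inferred from the benchmark description.

```python
import torch


def reference_batch_kmeans(x, n_clusters, max_iters=100, tol=1e-4, seed=0):
    """Hypothetical reference: plain Lloyd's algorithm on a (B, N, D) tensor."""
    B, N, D = x.shape
    x = x.float()  # accumulate in FP32 even when the input is FP16

    # Assumed initialization: K distinct random points per batch element.
    gen = torch.Generator().manual_seed(seed)
    init = torch.stack(
        [torch.randperm(N, generator=gen)[:n_clusters] for _ in range(B)]
    ).to(x.device)
    centers = torch.gather(x, 1, init.unsqueeze(-1).expand(B, n_clusters, D))

    ones = torch.ones(B, N, device=x.device, dtype=x.dtype)
    for _ in range(max_iters):
        # Assignment step: index of the nearest center for every point.
        cluster_ids = torch.cdist(x, centers).argmin(dim=-1)  # (B, N)

        # Update step: per-cluster sums and counts, then the mean.
        sums = torch.zeros_like(centers)
        sums.scatter_add_(1, cluster_ids.unsqueeze(-1).expand(B, N, D), x)
        counts = torch.zeros(B, n_clusters, device=x.device, dtype=x.dtype)
        counts.scatter_add_(1, cluster_ids, ones)
        new_centers = sums / counts.clamp(min=1).unsqueeze(-1)
        # Assumed policy: an empty cluster keeps its previous center.
        new_centers = torch.where(counts.unsqueeze(-1) > 0, new_centers, centers)

        # Stop once no center moved farther than tol.
        shift = (new_centers - centers).norm(dim=-1).max()
        centers = new_centers
        if shift < tol:
            break

    # Recompute assignments so they match the returned centers.
    cluster_ids = torch.cdist(x, centers).argmin(dim=-1)
    return cluster_ids, centers
```

The `torch.cdist` call materializes a full (B, N, K) distance matrix on every iteration, so this sketch only suits small inputs used for sanity checks.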

Benchmark

Compared with a standard PyTorch baseline, our Triton implementation achieves up to a 16× speedup on an NVIDIA H100 GPU (FP16, batch size 32, 16k points, 128-D, 1k clusters).
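
To measure a speedup like this on your own hardware, a small timing helper along the following lines may be useful. This is a sketch, not the benchmark script behind the number above: `time_fn`, its `warmup` and `iters` parameters, and the use of wall-clock time are assumptions. The helper synchronizes the GPU before reading the clock because CUDA kernels launch asynchronously.

```python
import time

import torch


def time_fn(fn, *args, warmup=3, iters=10, **kwargs):
    """Hypothetical helper: mean wall-clock milliseconds per call of fn."""
    # Warm-up runs absorb one-time costs such as kernel compilation.
    for _ in range(warmup):
        fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before timing

    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure the timed work has finished
    elapsed = time.perf_counter() - start

    return elapsed * 1e3 / iters
```

Timing `batch_kmeans_Euclid` and a PyTorch baseline with the same input tensor and the same `n_clusters` and `tol` values gives a like-for-like comparison; the ratio of the two results is the speedup.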

Benchmark result

Citation

If you use this codebase or otherwise find our work valuable, please cite:

@article{yang2025sparse,
  title={Sparse VideoGen2: Accelerate Video Generation with Sparse Attention via Semantic-Aware Permutation},
  author={Yang, Shuo and Xi, Haocheng and Zhao, Yilong and Li, Muyang and Zhang, Jintao and Cai, Han and Lin, Yujun and Li, Xiuyu and Xu, Chenfeng and Peng, Kelly and others},
  journal={arXiv preprint arXiv:2505.18875},
  year={2025}
}
