
Balanced k-means with CUDA support in PyTorch.

Project description

Balanced K-Means clustering in PyTorch

Balanced K-Means clustering in PyTorch with strong GPU acceleration.

Disclaimer: This project is heavily inspired by the project kmeans_pytorch. Each part taken from the original implementation carries the appropriate attribution.

Installation

As easy as:

pip install balanced_kmeans

Getting started

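First things first: running the classical k-means algorithm is as easy as:

import torch
from balanced_kmeans import kmeans

# experiment constants
N = 10000            # points per batch element
batch_size = 10
num_clusters = 100
dim = 2              # feature dimension; undefined in the original snippet, 2 is an assumed value
device = 'cuda'

cluster_size = N // num_clusters  # points per cluster if clusters were equal-sized
# X is a batch of point sets with shape (batch_size, N, dim)
X = torch.rand(batch_size, N, dim, device=device)
choices, centers = kmeans(X, num_clusters=num_clusters)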

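Now, if you want balanced k-means, you can run:

import torch
from balanced_kmeans import kmeans_equal

# experiment constants
N = 10000            # points per batch element
batch_size = 10
num_clusters = 100
dim = 2              # feature dimension; undefined in the original snippet, 2 is an assumed value
device = 'cuda'

cluster_size = N // num_clusters  # points per cluster
# X is a batch of point sets with shape (batch_size, N, dim)
X = torch.rand(batch_size, N, dim, device=device)
choices, centers = kmeans_equal(X, num_clusters=num_clusters)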
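
To check that a clustering is actually balanced, you can count how many points landed in each cluster. The sketch below is not part of the package: cluster_sizes is a hypothetical helper, and it assumes that choices holds one integer cluster label per point with shape (batch_size, N), which the description above does not state explicitly. It runs on synthetic labels so it does not need a GPU or the package itself.

```python
import torch

def cluster_sizes(choices, num_clusters):
    """Count points per cluster for each batch element.

    choices: integer labels, assumed shape (batch_size, N).
    Returns a (batch_size, num_clusters) tensor of counts.
    """
    batch_size = choices.shape[0]
    sizes = torch.zeros(batch_size, num_clusters, dtype=torch.long,
                        device=choices.device)
    # Add 1 to the counter of the cluster each point was assigned to.
    ones = torch.ones_like(choices, dtype=torch.long)
    sizes.scatter_add_(1, choices.long(), ones)
    return sizes

# Synthetic, perfectly balanced labels: 12 points, 3 clusters, batch of 2.
labels = torch.arange(12).remainder(3).repeat(2, 1)
sizes = cluster_sizes(labels, num_clusters=3)
print(sizes.tolist())  # [[4, 4, 4], [4, 4, 4]]
```

With the constants used above, a balanced result would presumably show N // num_clusters = 100 points in every cluster.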

By default, the Forgy initialization scheme is used for the initial cluster centers. However, you may supply your own initial cluster centers by passing the keyword argument initial_state to either kmeans or kmeans_equal.
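
As an illustration of building custom initial centers, here is a minimal sketch of Forgy-style initialization, which picks num_clusters distinct data points per batch element as starting centers. The function forgy_init is a hypothetical helper, not the package's own code, and the (batch_size, num_clusters, dim) shape it returns is an assumption about what initial_state expects.

```python
import torch

def forgy_init(X, num_clusters):
    """Pick num_clusters distinct points per batch element as centers.

    X: points with shape (batch_size, N, dim).
    Returns a tensor of shape (batch_size, num_clusters, dim).
    """
    batch_size, N, dim = X.shape
    # The top-k indices of random scores form a sample without replacement.
    scores = torch.rand(batch_size, N, device=X.device)
    idx = scores.topk(num_clusters, dim=1).indices  # (batch_size, num_clusters)
    # Gather the selected points along the point axis.
    idx = idx.unsqueeze(-1).expand(-1, -1, dim)
    return torch.gather(X, 1, idx)

torch.manual_seed(0)
X = torch.rand(2, 50, 3)  # small CPU batch for illustration
init = forgy_init(X, num_clusters=5)
print(init.shape)  # torch.Size([2, 5, 3])

# Assumed usage; the expected shape of initial_state is not documented here:
# choices, centers = kmeans_equal(X, num_clusters=5, initial_state=init)
```

Any other scheme that produces a tensor of the expected shape could be substituted in the same way.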

Contributing

This is a pet project, so feel free to contribute if you want to add extra features. For bugs, please open a detailed issue.

Credits

This implementation extends the package kmeans_pytorch, which contains an implementation of the original Lloyd's K-means algorithm in PyTorch. You can check (and star!) the original package here.

For licensing of this project, please refer to this repo as well as the kmeans_pytorch repo.


