Balanced k-means with CUDA support in PyTorch.

Balanced K-Means clustering in PyTorch

Balanced K-Means clustering in PyTorch with strong GPU acceleration.

Disclaimer: This project is heavily inspired by kmeans_pytorch. Every part taken from the original implementation carries the appropriate attribution.

Installation

As easy as:

pip install balanced_kmeans

Getting started

First things first: the classical k-means algorithm is as easy as:

import torch
from balanced_kmeans import kmeans

# experiment constants
N = 10000            # points per batch element
dim = 5              # feature dimension (example value)
batch_size = 10
num_clusters = 100
device = 'cuda'

cluster_size = N // num_clusters
# random input of shape (batch_size, N, dim)
X = torch.rand(batch_size, N, dim, device=device)
choices, centers = kmeans(X, num_clusters=num_clusters)
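For readers who want to see what a classical (Lloyd's) k-means iteration does on batched input, here is a minimal sketch in plain PyTorch. It is not the package's implementation: the function name lloyd_step and the tensor shapes in the comments are assumptions made for illustration.

```python
import torch

def lloyd_step(X, centers):
    """One batched Lloyd iteration (illustrative sketch, not the package's code).

    X:       (batch_size, N, dim) points
    centers: (batch_size, num_clusters, dim) current centers
    """
    num_clusters = centers.shape[1]
    # Assignment step: each point goes to its nearest center
    dists = torch.cdist(X, centers)                     # (batch_size, N, num_clusters)
    choices = dists.argmin(dim=-1)                      # (batch_size, N)
    # Update step: each center becomes the mean of its assigned points
    one_hot = torch.nn.functional.one_hot(choices, num_clusters).to(X.dtype)
    counts = one_hot.sum(dim=1).unsqueeze(-1)           # (batch_size, num_clusters, 1)
    sums = one_hot.transpose(1, 2) @ X                  # (batch_size, num_clusters, dim)
    # Clusters that received no points keep their previous center
    new_centers = torch.where(counts > 0, sums / counts.clamp(min=1), centers)
    return choices, new_centers
```

Calling lloyd_step repeatedly until the centers stop moving gives the usual algorithm. Nothing here enforces equal cluster sizes.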

Now, if you want balanced k-means, you can run:

import torch
from balanced_kmeans import kmeans_equal

N = 10000            # points per batch element
dim = 5              # feature dimension (example value)
batch_size = 10
num_clusters = 100
device = 'cuda'

# target number of points per cluster
cluster_size = N // num_clusters
X = torch.rand(batch_size, N, dim, device=device)
choices, centers = kmeans_equal(X, num_clusters=num_clusters)
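To confirm that an assignment is balanced, you can count the points per cluster. The sketch below assumes that choices holds integer cluster indices with shape (batch_size, N). That shape is an assumption and not something this description states. The helper name cluster_sizes is hypothetical.

```python
import torch

def cluster_sizes(choices, num_clusters):
    """Count points per cluster for each batch element.

    Assumes choices is an integer tensor of shape (batch_size, N)
    holding cluster indices in [0, num_clusters).
    """
    choices = choices.long()
    counts = torch.zeros(choices.shape[0], num_clusters,
                         dtype=torch.long, device=choices.device)
    # Add 1 to the counter of the cluster each point was assigned to
    counts.scatter_add_(1, choices, torch.ones_like(choices))
    return counts  # (batch_size, num_clusters)
```

If the assumed shape holds, a balanced result satisfies torch.all(cluster_sizes(choices, num_clusters) == cluster_size) whenever N is divisible by num_clusters.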

By default, the Forgy initialization scheme is used for the initial cluster centers. You can override them by passing the keyword argument initial_state to either kmeans or kmeans_equal.
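As a rough illustration of Forgy-style initialization, which picks randomly chosen data points as the starting centers, here is a minimal sketch. The helper forgy_init is hypothetical, and the output shape (batch_size, num_clusters, dim) is an assumption about what initial_state expects, so check the package source before relying on it.

```python
import torch

def forgy_init(X, num_clusters):
    """Pick num_clusters distinct points per batch element as initial centers.

    X: (batch_size, N, dim). Returns (batch_size, num_clusters, dim).
    Illustrative sketch only; not taken from the package.
    """
    batch_size, N, dim = X.shape
    # Random scores plus top-k give a sample without replacement per batch element
    scores = torch.rand(batch_size, N, device=X.device)
    idx = scores.topk(num_clusters, dim=1).indices       # (batch_size, num_clusters)
    # Gather the selected points along the N axis
    return torch.gather(X, 1, idx.unsqueeze(-1).expand(-1, -1, dim))
```

Under the shape assumption above, the result could be passed as initial_state=forgy_init(X, num_clusters).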

Contributing

This is a pet project, so feel free to contribute if you want to add any extra feature. For any bugs, please open a detailed issue.

Credits

This implementation extends the kmeans_pytorch package, which implements the original Lloyd's k-means algorithm in PyTorch. You can check (and star!) the original package here.

For licensing of this project, please refer to this repo as well as the kmeans_pytorch repo.
