Balanced k-means with CUDA support in PyTorch.
Project description
Balanced K-Means clustering in PyTorch
Balanced K-Means clustering in PyTorch with strong GPU acceleration.
Disclaimer: This project is heavily inspired by the project kmeans_pytorch. Every part taken from the original implementation is accompanied by the appropriate attribution.
Installation
As easy as:
```shell
pip install balanced_kmeans
```
Getting started
First things first: running the classical k-means algorithm is as easy as:
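```python
import torch
from balanced_kmeans import kmeans

# experiment constants
N = 10000            # number of points in each batch element
batch_size = 10
num_clusters = 100
dim = 2              # feature dimension of each point (example value)
device = 'cuda'
cluster_size = N // num_clusters

# batch_size independent problems, each with N points of dimension dim
X = torch.rand(batch_size, N, dim, device=device)
choices, centers = kmeans(X, num_clusters=num_clusters)
```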
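For readers new to the algorithm, the classical (Lloyd's) k-means loop alternates two steps: assign every point to its nearest center, then move every center to the mean of the points assigned to it, until the centers stop moving. The sketch below illustrates that loop in plain Python; it is not the package's batched CUDA implementation, and the function name `lloyd_kmeans` and its `max_iters`/`tol` parameters are hypothetical.

```python
import math


def lloyd_kmeans(points, initial_centers, max_iters=100, tol=1e-5):
    """Plain Lloyd's k-means on a list of equal-length coordinate tuples.

    Returns (choices, centers): choices[i] is the index of the center
    assigned to points[i].
    """
    centers = [list(c) for c in initial_centers]
    choices = [0] * len(points)
    for _ in range(max_iters):
        # Assignment step: each point goes to its nearest center.
        choices = [
            min(range(len(centers)), key=lambda j: math.dist(p, centers[j]))
            for p in points
        ]
        # Update step: each center moves to the mean of its assigned points.
        shift = 0.0
        for j in range(len(centers)):
            members = [p for p, c in zip(points, choices) if c == j]
            if not members:
                continue  # leave the center of an empty cluster where it is
            new_center = [sum(coord) / len(members) for coord in zip(*members)]
            shift = max(shift, math.dist(new_center, centers[j]))
            centers[j] = new_center
        # Stop once no center moved by more than tol.
        if shift < tol:
            break
    return choices, centers


points = [(0.0, 0.0), (0.1, 0.0), (0.0, 0.1), (5.0, 5.0), (5.1, 5.0), (5.0, 5.1)]
choices, centers = lloyd_kmeans(points, initial_centers=[points[0], points[3]])
print(choices)  # [0, 0, 0, 1, 1, 1]
```

Nothing here limits how many points a cluster receives; that is exactly what the balanced variant changes.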
Now, if you want balanced k-means, you can run:
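```python
import torch
from balanced_kmeans import kmeans_equal

N = 10000            # number of points in each batch element
batch_size = 10
num_clusters = 100
dim = 2              # feature dimension of each point (example value)
device = 'cuda'
cluster_size = N // num_clusters  # 10000 // 100 = 100 points per cluster

X = torch.rand(batch_size, N, dim, device=device)
choices, centers = kmeans_equal(X, num_clusters=num_clusters)
```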
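What distinguishes the balanced variant is the assignment step: instead of sending every point to its nearest center, each cluster is capped at a fixed size (in the example above, presumably `cluster_size = N // num_clusters`, i.e. 100 points per cluster). As an illustration only, the plain-Python sketch below uses one simple greedy strategy that fills clusters from the closest (point, center) pairs first. This strategy is an assumption, not necessarily how `kmeans_equal` performs its assignment, and `balanced_assign` is a hypothetical name.

```python
import math


def balanced_assign(points, centers, cluster_size):
    """Greedy balanced assignment: no cluster receives more than
    cluster_size points; the closest (point, center) pairs are served first."""
    n, k = len(points), len(centers)
    if n > k * cluster_size:
        # e.g. N not divisible by num_clusters with cluster_size = N // num_clusters
        raise ValueError("not enough capacity: need n <= k * cluster_size")
    # All (distance, point index, center index) pairs, nearest first.
    pairs = sorted(
        (math.dist(p, c), i, j)
        for i, p in enumerate(points)
        for j, c in enumerate(centers)
    )
    choices = [None] * n
    load = [0] * k
    for _, i, j in pairs:
        # Take the pair only if the point is still free and the cluster has room.
        if choices[i] is None and load[j] < cluster_size:
            choices[i] = j
            load[j] += 1
    return choices


points = [(0.0, 0.0), (0.1, 0.0), (0.2, 0.0), (5.0, 0.0)]
centers = [(0.0, 0.0), (5.0, 0.0)]
choices = balanced_assign(points, centers, cluster_size=2)
print(choices)  # [0, 0, 1, 1]
```

Plain nearest-center assignment would give `[0, 0, 0, 1]` here; the capacity limit pushes the third point into the second cluster. In a full balanced k-means loop, a step like this would replace the nearest-center assignment while the center update stays the same.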
By default, the Forgy initialization scheme (randomly chosen data points) is used for the initial cluster centers. However, you can provide your own initial cluster centers through the keyword argument `initial_state` of either `kmeans` or `kmeans_equal`.
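Forgy initialization simply samples `num_clusters` distinct data points at random and uses them as the starting centers. A minimal plain-Python sketch follows; `forgy_init` is a hypothetical helper, not part of the package. The exact tensor shape that `initial_state` expects is not stated here, so check the package source before passing your own centers.

```python
import random


def forgy_init(points, num_clusters, seed=None):
    """Forgy initialization: sample num_clusters distinct data points
    uniformly at random and use them as the initial centers."""
    if num_clusters > len(points):
        raise ValueError("num_clusters cannot exceed the number of points")
    rng = random.Random(seed)  # seeded for reproducible starting centers
    return [tuple(p) for p in rng.sample(points, num_clusters)]


points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.5, 0.5)]
initial_centers = forgy_init(points, num_clusters=3, seed=0)
print(initial_centers)  # three distinct points taken from `points`
```

Because the result depends on the random draw, different seeds can lead k-means to different local optima; supplying fixed initial centers makes runs reproducible.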
Contributing
This is a pet project, so feel free to contribute if you want to add any extra feature. For any bugs, please open a detailed issue.
Credits
This implementation extends the package kmeans_pytorch, which contains an implementation of Lloyd's original K-means algorithm in PyTorch. You can check (and star!) the original package here.
For licensing of this project, please refer to this repo as well as the kmeans_pytorch repo.
Download files
Source Distribution
File details
Details for the file balanced_kmeans-0.1.0.tar.gz.
File metadata
- Download URL: balanced_kmeans-0.1.0.tar.gz
- Upload date:
- Size: 5.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/46.0.0.post20200309 requests-toolbelt/0.8.0 tqdm/4.45.0 CPython/3.7.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5bde59536ff7bc90af3fa12b973ba93546dfa9555baeafb5af0087ef997e45f0
MD5 | 519001c5c5cc20452c00c8ab133f73d2
BLAKE2b-256 | 5a51320acd74f5b5955ea9b64ad32342ee5c88fada89fd1f73c702a3a3a85d8f
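To check a downloaded copy of balanced_kmeans-0.1.0.tar.gz against the SHA256 digest listed above, a small standard-library script is enough. This is a generic sketch, not part of the package; `sha256_of` is a hypothetical helper name.

```python
import hashlib
import sys


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest published for balanced_kmeans-0.1.0.tar.gz
EXPECTED = "5bde59536ff7bc90af3fa12b973ba93546dfa9555baeafb5af0087ef997e45f0"

if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python check_hash.py balanced_kmeans-0.1.0.tar.gz
    actual = sha256_of(sys.argv[1])
    print("OK" if actual == EXPECTED else f"MISMATCH: {actual}")
```

Reading in chunks keeps memory use constant regardless of file size.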