PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans which can be run on GPU and work on (mini-)batches of data.
torch_kmeans
PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans
torch_kmeans features implementations of the well-known k-means algorithm as well as its soft and constrained variants.
All algorithms are implemented entirely as PyTorch modules and can easily be incorporated into a PyTorch pipeline or model. As a result, they support execution on GPU as well as working on (mini-)batches of data. They also provide a scikit-learn-style interface featuring the
model.fit(), model.predict() and model.fit_predict()
functions.
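To illustrate what clustering a whole (mini-)batch at once means, here is a minimal sketch of Lloyd's k-means applied to a tensor of shape (BS, N, D), where each of the BS instances is clustered independently. It is written in plain PyTorch and is not the torch_kmeans implementation; the function name `batched_kmeans`, the random-sample initialisation and the fixed iteration count are assumptions made for illustration only.

```python
import torch


def batched_kmeans(x: torch.Tensor, k: int, n_iter: int = 20, seed: int = 0):
    """Hypothetical batched Lloyd's k-means sketch (not the torch_kmeans API).

    x: (BS, N, D) tensor; every instance in the batch is clustered independently.
    Returns labels of shape (BS, N) and centers of shape (BS, k, D).
    """
    bs, n, d = x.shape
    gen = torch.Generator(device="cpu").manual_seed(seed)
    # Initialise each instance's centers with k distinct random samples of that instance.
    idx = torch.stack([torch.randperm(n, generator=gen)[:k] for _ in range(bs)]).to(x.device)
    centers = torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, d))  # (BS, k, D)
    for _ in range(n_iter):
        # Assignment step: nearest center by Euclidean distance, for all instances at once.
        labels = torch.cdist(x, centers).argmin(dim=-1)  # (BS, N)
        # Update step: mean of the samples assigned to each center.
        one_hot = torch.nn.functional.one_hot(labels, k).to(x.dtype)  # (BS, N, k)
        counts = one_hot.sum(dim=1)  # (BS, k)
        sums = one_hot.transpose(1, 2) @ x  # (BS, k, D)
        # Keep the previous center if a cluster received no samples.
        centers = torch.where(
            counts.unsqueeze(-1) > 0,
            sums / counts.clamp(min=1).unsqueeze(-1),
            centers,
        )
    labels = torch.cdist(x, centers).argmin(dim=-1)
    return labels, centers
```

Both steps are plain batched tensor operations (torch.cdist plus argmin, then a batched matrix product with one-hot labels), so the same sketch runs on GPU once `x` is moved there with `x.to("cuda")`.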
Highlights
- Fully implemented in PyTorch.
- GPU support, just like native PyTorch.
- TorchScript JIT compilation for the most performance-sensitive parts.
- Works with mini-batches of samples: each instance can have a different number of clusters.
- Constrained KMeans supports cluster constraints such as:
  - a maximum number of samples per cluster, or
  - a maximum weight per cluster, where each sample has an associated weight.
- SoftKMeans is a fully differentiable clustering procedure and can readily be used in a PyTorch neural network model that requires backpropagation.
- Unit tested against the scikit-learn KMeans implementation.
- GPU execution enables very fast computation, even for large batch sizes or very high-dimensional feature spaces (see speed comparison).
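Soft k-means becomes differentiable by replacing the hard nearest-center assignment with a softmax over negative squared distances. The following is a minimal batched sketch of that idea in plain PyTorch, not the torch_kmeans SoftKMeans code; the function name `soft_kmeans`, the `temperature` parameter and the fixed number of iterations are illustrative assumptions.

```python
import torch


def soft_kmeans(x: torch.Tensor, centers: torch.Tensor,
                temperature: float = 1.0, n_iter: int = 10):
    """Hypothetical batched soft k-means sketch (not the torch_kmeans API).

    x:       (BS, N, D) samples
    centers: (BS, K, D) initial centers
    Returns soft assignments (BS, N, K) and refined centers (BS, K, D).
    """
    def assign(c: torch.Tensor) -> torch.Tensor:
        # Squared Euclidean distances (BS, N, K) via broadcasting; the squared
        # distance is smooth even where a sample coincides with a center.
        sq_dist = ((x.unsqueeze(2) - c.unsqueeze(1)) ** 2).sum(dim=-1)
        # Softmax over clusters replaces the hard argmin of standard k-means.
        return torch.softmax(-sq_dist / temperature, dim=-1)

    for _ in range(n_iter):
        resp = assign(centers)
        # Each center becomes the responsibility-weighted mean of all samples.
        weights = resp / resp.sum(dim=1, keepdim=True).clamp(min=1e-8)
        centers = weights.transpose(1, 2) @ x
    return assign(centers), centers
```

Because every step is differentiable, a loss computed from the assignments or centers can be backpropagated to `x` and to any network that produced it. Lowering `temperature` makes the softmax approach a one-hot argmin, so the assignments approach those of hard k-means. The broadcast distance computation uses O(BS·N·K·D) memory, which is the price of this simple formulation.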
Installation
Install from PyPI:
pip install torch-kmeans
Usage
PyTorch-style usage
import torch
from torch_kmeans import KMeans
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2)) # (BS, N, D)
result = model(x)
print(result.labels)
Scikit-learn-style usage
import torch
from torch_kmeans import KMeans
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2)) # (BS, N, D)
model = model.fit(x)
labels = model.predict(x)
print(labels)
Alternatively, fit and predict in a single call:
import torch
from torch_kmeans import KMeans
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2)) # (BS, N, D)
labels = model.fit_predict(x)
print(labels)
Examples
More usage examples can be found in the detailed example notebooks.
File details
Details for the file torch_kmeans-0.1.0.tar.gz.
File metadata
- Download URL: torch_kmeans-0.1.0.tar.gz
- Upload date:
- Size: 73.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1642c0639b56934e621be1ae033a6921c8d521b39cafb4138a28e75eaecee79b
MD5 | 241a94eed532a277a8e3abed4dacdbbe
BLAKE2b-256 | 7cf80ebcb0ef0d13c3196e27e7e9d56b99c9e4228b326767166350f009e9b980
File details
Details for the file torch_kmeans-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_kmeans-0.1.0-py3-none-any.whl
- Upload date:
- Size: 23.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.0 CPython/3.9.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7969beb8884a60553761b0a1bba4e60e7f126a8e1f4dee5e17b827f8c3ab8f64
MD5 | b07ac61fd4684197e529620633a34c2e
BLAKE2b-256 | 9169b61b7f50227e5678a7fe06a3ff90d01e80dba27171bb41ac1f3fb3cfbdd0