
PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans which can be run on GPU and work on (mini-)batches of data.



torch_kmeans

PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans

torch_kmeans provides implementations of the well-known k-means algorithm as well as its soft and constrained variants.
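The constrained variant differs from plain k-means mainly in its assignment step: instead of sending every sample to its nearest center, it has to respect a per-cluster limit such as a maximum number of samples. The snippet below is a minimal sketch of one simple way to do this, a greedy pass over all (sample, cluster) pairs sorted by distance. It is an illustrative heuristic and not necessarily the method torch_kmeans uses; the function name `greedy_capacity_assign` and its parameters are hypothetical.

```python
import torch


def greedy_capacity_assign(x: torch.Tensor, centers: torch.Tensor,
                           max_per_cluster: int) -> torch.Tensor:
    """Hypothetical helper: assign each point in x (N, D) to a center in
    centers (K, D) so that no cluster receives more than max_per_cluster points."""
    n, k = x.shape[0], centers.shape[0]
    if n > k * max_per_cluster:
        raise ValueError("total capacity K * max_per_cluster is smaller than N")
    dist = torch.cdist(x, centers)                   # (N, K) Euclidean distances
    order = torch.argsort(dist.flatten()).tolist()   # all (point, cluster) pairs, closest first
    labels = [-1] * n                                # -1 marks "not yet assigned"
    counts = [0] * k                                 # points assigned to each cluster so far
    for flat_idx in order:
        i, j = divmod(flat_idx, k)                   # row-major flattening: idx = i * K + j
        if labels[i] == -1 and counts[j] < max_per_cluster:
            labels[i] = j
            counts[j] += 1
    # Every point ends up labeled: a leftover point would mean all K clusters
    # are full, i.e. K * max_per_cluster < N, which the check above rules out.
    return torch.tensor(labels, dtype=torch.long)
```

In a full constrained k-means loop, this assignment would replace the nearest-center `argmin` in each iteration while the center update stays unchanged. A weight-based limit works the same way with each sample consuming its weight instead of 1, although a greedy pass is then no longer guaranteed to place every sample.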

All algorithms are implemented entirely as PyTorch modules and can easily be incorporated into a PyTorch pipeline or model. As a result, they support execution on GPU as well as working on (mini-)batches of data. They also provide a scikit-learn-style interface featuring the model.fit(), model.predict() and model.fit_predict() functions.
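To illustrate how a batched k-means can live inside an `nn.Module` and still expose a scikit-learn-style interface, here is a minimal sketch. `TinyKMeans` is a hypothetical class written for this example, not the torch_kmeans `KMeans` class; it uses plain Lloyd iterations and naively initializes the centers with the first K points of each instance.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyKMeans(nn.Module):
    """Hypothetical minimal batched k-means (not the torch_kmeans implementation)."""

    def __init__(self, n_clusters: int, n_iter: int = 20):
        super().__init__()
        self.n_clusters = n_clusters
        self.n_iter = n_iter
        self.centers = None  # (BS, K, D), set by fit()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (BS, N, D); naive init with the first K points of each instance
        centers = x[:, : self.n_clusters].clone()
        for _ in range(self.n_iter):
            labels = torch.cdist(x, centers).argmin(dim=-1)             # (BS, N)
            one_hot = F.one_hot(labels, self.n_clusters).to(x.dtype)    # (BS, N, K)
            counts = one_hot.sum(dim=1).unsqueeze(-1)                   # (BS, K, 1)
            sums = one_hot.transpose(1, 2) @ x                          # (BS, K, D)
            # Mean of assigned points; empty clusters keep their old center
            centers = torch.where(counts > 0, sums / counts.clamp(min=1), centers)
        return centers

    def fit(self, x: torch.Tensor) -> "TinyKMeans":
        with torch.no_grad():
            self.centers = self(x)
        return self

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        if self.centers is None:
            raise RuntimeError("call fit() first")
        return torch.cdist(x, self.centers).argmin(dim=-1)             # (BS, N)

    def fit_predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.fit(x).predict(x)
```

Because `fit` returns the module itself, calls can be chained as in `model.fit(x).predict(x)`, mirroring scikit-learn.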

-> view official documentation

Highlights

  • Fully implemented in PyTorch.

  • GPU support like native PyTorch.

  • TorchScript JIT compilation for the most performance-sensitive parts.

  • Works with mini-batches of samples:
    • each instance can have a different number of clusters.

  • Constrained KMeans works with cluster constraints such as:
    • a maximum number of samples per cluster, or

    • a maximum weight per cluster, where each sample has an associated weight.

  • SoftKMeans is a fully differentiable clustering procedure and can readily be used in a PyTorch neural network model which requires backpropagation.

  • Unit tested against the scikit-learn KMeans implementation.

  • GPU execution enables very fast computation even for large batch sizes or very high-dimensional feature spaces (see speed comparison).
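The differentiability of soft k-means comes from replacing the hard `argmin` assignment with a softmax over negative squared distances, so every operation in the loop has a gradient. The sketch below uses that common formulation; it is an assumption about the general technique, not a copy of the torch_kmeans `SoftKMeans` code, and `soft_kmeans` and its `temperature` parameter are hypothetical names.

```python
import torch


def soft_kmeans(x: torch.Tensor, init_centers: torch.Tensor,
                temperature: float = 1.0, n_iter: int = 10):
    """Hypothetical sketch. x: (BS, N, D), init_centers: (BS, K, D).
    Returns soft assignments (BS, N, K) from the last iteration and centers (BS, K, D)."""
    centers = init_centers
    for _ in range(n_iter):
        # Squared Euclidean distances (BS, N, K), computed explicitly so the
        # gradient stays finite even when a point coincides with a center.
        dist2 = ((x.unsqueeze(2) - centers.unsqueeze(1)) ** 2).sum(dim=-1)
        # Soft assignments: softmax over clusters, so each row sums to 1.
        resp = torch.softmax(-dist2 / temperature, dim=-1)
        # Centers become responsibility-weighted means of the points.
        weights = resp.sum(dim=1).unsqueeze(-1).clamp_min(1e-8)   # (BS, K, 1)
        centers = resp.transpose(1, 2) @ x / weights               # (BS, K, D)
    return resp, centers
```

Any loss computed from `resp` or `centers` can be backpropagated into `x`, and hence into whatever network produced `x`. As `temperature` approaches 0, the softmax approaches a one-hot `argmin`, recovering the hard k-means assignment.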

Installation

Install from PyPI:

pip install torch-kmeans

Usage

PyTorch-style usage

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)

x = torch.randn((4, 20, 2))   # (BS, N, D)
result = model(x)
print(result.labels)
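In the example above, every instance in the batch uses the same four clusters. Supporting a different number of clusters per instance, as listed in the highlights, can be done by padding all instances to a common maximum K and masking the padded centers so they never win the assignment. The sketch below shows only this masked assignment step; it is one possible approach and not necessarily how torch_kmeans handles it, and `assign_with_per_instance_k` is a hypothetical helper.

```python
import torch


def assign_with_per_instance_k(x: torch.Tensor, centers: torch.Tensor,
                               k: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch. x: (BS, N, D), centers: (BS, K_max, D) padded,
    k: (BS,) number of valid clusters per instance. Returns labels (BS, N)."""
    dist = torch.cdist(x, centers)                                   # (BS, N, K_max)
    k_max = centers.shape[1]
    # True where a cluster index is padding for that instance
    invalid = torch.arange(k_max, device=x.device)[None, :] >= k[:, None]   # (BS, K_max)
    # Broadcast the mask over the N points; padded clusters get infinite distance
    dist = dist.masked_fill(invalid[:, None, :], float("inf"))
    return dist.argmin(dim=-1)
```

Applying the same mask in the center update (for example, by ignoring padded centers) keeps each instance's clustering independent of the padding.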

Scikit-learn-style usage

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)

x = torch.randn((4, 20, 2))   # (BS, N, D)
model = model.fit(x)
labels = model.predict(x)
print(labels)

or

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)

x = torch.randn((4, 20, 2))   # (BS, N, D)
labels = model.fit_predict(x)
print(labels)

Examples

You can find more examples and usage in the detailed example notebooks.

Download files

Source Distribution

torch_kmeans-0.1.1.tar.gz (73.0 kB)

Uploaded Source

Built Distribution

torch_kmeans-0.1.1-py3-none-any.whl (23.0 kB)

Uploaded Python 3

File details

Details for the file torch_kmeans-0.1.1.tar.gz.

File metadata

  • Download URL: torch_kmeans-0.1.1.tar.gz
  • Upload date:
  • Size: 73.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.12

File hashes

Hashes for torch_kmeans-0.1.1.tar.gz
Algorithm Hash digest
SHA256 9c8207abaf9c056fa81fe4390aaa3aab2177b19d96d2f1183b4155d056346337
MD5 2a2fe5b4d3145682a7a66f278163679b
BLAKE2b-256 fdfde71ea26bb89ef16f979aee1b881c000cdf7e8aa1b15493fad67c58d8c0b4

File details

Details for the file torch_kmeans-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: torch_kmeans-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 23.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.12

File hashes

Hashes for torch_kmeans-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 adcbaede90b127575ecf7f806c1e1ebff122705314f0d99d6de30f9319521cc0
MD5 7ec07069772e7ef1da8f642c8ce2ac00
BLAKE2b-256 1b58c69fbf5e3996400928a1e2a80199703cf0972e30fab4cc1f332481c473ec
