
KMeans-GPU: A PyTorch Module for KMeans.


kmeans-gpu

kmeans-gpu is a batched KMeans implementation in PyTorch. It is faster than sklearn.cluster.KMeans, and it is differentiable: gradients back-propagate through it to the preceding layers.

KMeans is an nn.Module, so you can embed it directly in your network.
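To illustrate what a batched, differentiable KMeans involves, here is a minimal pure-PyTorch sketch of Lloyd's algorithm. It is not kmeans-gpu's source. The function name `batched_kmeans`, the initialisation from the first K points, and reading `tolerance` as a centroid-shift threshold are all assumptions. The hard assignment (`argmin`) carries no gradient, but the centroid update (a mean of the assigned points) does, so gradients reach the input points.

```python
import torch
import torch.nn.functional as F


def batched_kmeans(points, n_clusters, max_iter=100, tolerance=1e-4):
    """Hypothetical batched Lloyd's KMeans (not the kmeans-gpu source).

    points: (B, N, D) tensor; returns centroids (B, K, D) and labels (B, N).
    """
    # Assumed initialisation: the first n_clusters points of each batch item.
    centroids = points[:, :n_clusters, :]
    for _ in range(max_iter):
        # Nearest centroid by Euclidean distance (cdist gives a (B, N, K) matrix).
        labels = torch.cdist(points, centroids).argmin(dim=-1)    # (B, N), no gradient
        assign = F.one_hot(labels, n_clusters).to(points.dtype)   # (B, N, K)
        counts = assign.sum(dim=1).unsqueeze(-1)                  # (B, K, 1)
        sums = assign.transpose(1, 2) @ points                    # (B, K, D), differentiable
        # Mean of the assigned points; empty clusters keep their old centroid.
        new_centroids = torch.where(counts > 0, sums / counts.clamp(min=1), centroids)
        shift = (new_centroids - centroids).norm(dim=-1).max().item()
        centroids = new_centroids
        # Assumed meaning of `tolerance`: stop once no centroid moves further than it.
        if shift < tolerance:
            break
    return centroids, labels


# Run on the GPU when one is available; the same code works on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
points = torch.randn(4, 256, 3, device=device, requires_grad=True)
centroids, labels = batched_kmeans(points, n_clusters=15)
print(centroids.shape, labels.shape)  # torch.Size([4, 15, 3]) torch.Size([4, 256])
centroids.sum().backward()
print(points.grad.shape)  # torch.Size([4, 256, 3])
```

Keeping every batch item in one tensor and expressing the update as a matrix product is what lets the whole batch run in parallel on the GPU.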

Install

  1. From Git:
git clone git@github.com:densechen/kmeans-gpu.git
cd kmeans-gpu
pip install -r requirements.txt
python setup.py install

# check installation
python -c "import kmeans_gpu; print(kmeans_gpu.__version__)"
  2. From PyPI:
pip install kmeans-gpu

# check installation
python -c "import kmeans_gpu; print(kmeans_gpu.__version__)"

Demo

from kmeans_gpu import KMeans
import torch

# Config
batch_size = 128
feature_dim = 1024
pts_dim = 3
num_pts = 256
num_cluster = 15

# Create data
features = torch.randn(batch_size, feature_dim, num_pts)
# Note the different dimension orders: features are (batch_size, feature_dim, num_pts),
# while points are (batch_size, num_pts, pts_dim).
points = torch.randn(batch_size, num_pts, pts_dim)

# Create KMeans Module
kmeans = KMeans(
    n_clusters=num_cluster,
    max_iter=100,
    tolerance=1e-4,
    distance='euclidean',
    sub_sampling=None,
    max_neighbors=15,
)

# Forward
centroids, features = kmeans(points, features)

print(centroids.shape, features.shape)
# Output:
# torch.Size([128, 15, 3]) torch.Size([128, 1024, 15])
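The demo returns features of shape (batch_size, feature_dim, num_cluster), i.e. one feature vector per cluster. How kmeans-gpu aggregates them (it also takes a max_neighbors argument) is not described here, so the following is only a hypothetical illustration of that shape contract: each point is assigned to its nearest centroid and the features within each cluster are averaged. The function `pool_by_nearest_centroid` and the mean-pooling choice are assumptions.

```python
import torch
import torch.nn.functional as F


def pool_by_nearest_centroid(points, features, centroids):
    """Hypothetical mean-pooling of per-point features into per-cluster features.

    points:    (B, N, D)  point coordinates
    features:  (B, C, N)  per-point features (channels before points)
    centroids: (B, K, D)  cluster centres, e.g. from a KMeans step
    returns:   (B, C, K)  mean feature of the points nearest to each centroid
    """
    n_clusters = centroids.shape[1]
    labels = torch.cdist(points, centroids).argmin(dim=-1)        # (B, N)
    assign = F.one_hot(labels, n_clusters).to(features.dtype)     # (B, N, K)
    counts = assign.sum(dim=1, keepdim=True).clamp(min=1)         # (B, 1, K)
    return features @ assign / counts                             # (B, C, K)


batch_size, feature_dim, num_pts, pts_dim, num_cluster = 8, 64, 256, 3, 15
points = torch.randn(batch_size, num_pts, pts_dim)
features = torch.randn(batch_size, feature_dim, num_pts, requires_grad=True)
# Placeholder centres; in practice these would come from a KMeans step.
centroids = points[:, :num_cluster, :]
pooled = pool_by_nearest_centroid(points, features, centroids)
print(pooled.shape)  # torch.Size([8, 64, 15])
pooled.sum().backward()
print(features.grad.shape)  # torch.Size([8, 64, 256])
```

Because the pooling is a matrix product, gradients flow back to the per-point features, which is what allows a layer like this to sit inside a larger network.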



Download files


Source Distribution

kmeans_gpu-0.0.5.tar.gz (5.1 kB)

Built Distribution

kmeans_gpu-0.0.5-py3-none-any.whl (6.2 kB)

File details

Details for the file kmeans_gpu-0.0.5.tar.gz.

File metadata

  • Download URL: kmeans_gpu-0.0.5.tar.gz
  • Upload date:
  • Size: 5.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.7.10

File hashes

Hashes for kmeans_gpu-0.0.5.tar.gz

  • SHA256: fff62d46e3f1167bf22b144cb168c1471c5433942319e4e7eaa47ee3f1f4db37
  • MD5: 9fbede0c9a0b736a453c53911b4a106d
  • BLAKE2b-256: 87c3816436cb7fd4e487cd190b53e85b42c3e2e4fdd8f0b29f7612e7feea97e1


File details

Details for the file kmeans_gpu-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: kmeans_gpu-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 6.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.7.10

File hashes

Hashes for kmeans_gpu-0.0.5-py3-none-any.whl

  • SHA256: 1e7a79be558d507f4c8e3c892527df3832817020c649dd26a2a0cadd6bc2cb46
  • MD5: 623a5fb9dd27b553de5d0c1153cf9597
  • BLAKE2b-256: 6c7534e4182b4f1512b7ec48d25b56368a29896db9c556b160088064836189df
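The SHA256 digests listed for the two distribution files can be used to check a download. A minimal sketch using Python's standard hashlib module; the helper names `sha256_of` and `verify` are hypothetical, and the downloaded file is assumed to keep its original file name.

```python
import hashlib
from pathlib import Path

# SHA256 digests as listed for the 0.0.5 release files.
EXPECTED_SHA256 = {
    "kmeans_gpu-0.0.5.tar.gz": "fff62d46e3f1167bf22b144cb168c1471c5433942319e4e7eaa47ee3f1f4db37",
    "kmeans_gpu-0.0.5-py3-none-any.whl": "1e7a79be558d507f4c8e3c892527df3832817020c649dd26a2a0cadd6bc2cb46",
}


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """True if the file's SHA256 matches the digest listed for its file name."""
    return sha256_of(path) == EXPECTED_SHA256[Path(path).name]
```

For example, `verify("kmeans_gpu-0.0.5.tar.gz")` run in the download directory should return True for an intact archive.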

