Run Gaussian Mixture Models on single or multiple CPUs/GPUs

Project description

TorchGMM

TorchGMM runs Gaussian Mixture Models on single or multiple CPUs/GPUs. The repository is a fork of PyCave and LightKit, two excellent packages developed by Oliver Borchert that are no longer maintained. While PyCave implements additional models such as Markov chains, TorchGMM focuses solely on Gaussian Mixture Models.

The models are implemented in PyTorch and PyTorch Lightning, and provide an Estimator API that is fully compatible with scikit-learn.

For Gaussian mixture models, TorchGMM enables roughly 100x speed-ups when using a GPU and makes it possible to train on markedly larger datasets via mini-batch training. The full suite of benchmarks comparing TorchGMM models against scikit-learn models is available on the documentation website.

Features

  • Support for GPU and multi-node training by implementing models in PyTorch and relying on PyTorch Lightning
  • Mini-batch training for all models such that they can be used on huge datasets
  • Well-structured implementation of models
    • High-level Estimator API allows for easy usage such that models feel and behave like in scikit-learn
    • Medium-level LightningModule implements the training algorithm
    • Low-level PyTorch Module manages the model parameters
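The mini-batch idea in the second bullet can be illustrated with a toy example. The following is a sketch of mini-batch k-means in plain NumPy, not TorchGMM's implementation; it only shows why a model can be fitted by streaming chunks of data, without ever holding the full dataset in a single training step:

```python
import numpy as np

def minibatch_kmeans_step(centroids, batch, counts):
    """One mini-batch k-means update: assign the batch to the nearest
    centroid, then nudge each centroid toward its assigned points."""
    dists = ((batch[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    assign = dists.argmin(axis=1)
    for j in range(len(centroids)):
        members = batch[assign == j]
        if len(members):
            counts[j] += len(members)
            lr = len(members) / counts[j]  # per-centroid learning rate
            centroids[j] += lr * (members.mean(axis=0) - centroids[j])
    return centroids, counts

rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-5, 1, (1000, 2)),
                       rng.normal(5, 1, (1000, 2))])
rng.shuffle(data)

centroids = np.array([data.min(axis=0), data.max(axis=0)])  # crude init
counts = np.zeros(2)
for batch in np.split(data, 20):  # stream the data 100 points at a time
    centroids, counts = minibatch_kmeans_step(centroids, batch, counts)
```

Each step only touches one batch, so memory stays constant regardless of dataset size; TorchGMM applies the same principle with PyTorch Lightning data loaders.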

Getting started

Please refer to the documentation, in particular the API documentation.

Requirements

TorchGMM requires PyTorch to be installed. The installation instructions can be found on the PyTorch website.

TorchGMM is available via pip:

pip install torchgmm

Usage

If you've ever used scikit-learn, you'll feel right at home when using TorchGMM. First, let's create some artificial data to work with:

import torch

X = torch.cat([
    torch.randn(10000, 8) - 5,
    torch.randn(10000, 8),
    torch.randn(10000, 8) + 5,
])

This dataset consists of three clusters of 8-dimensional datapoints. Fitting a K-Means model to find the clusters' centroids is as easy as:

from torchgmm.clustering import KMeans

estimator = KMeans(3)
estimator.fit(X)

# Once the estimator is fitted, it provides various properties. One of them is
# the `model_` property which yields the PyTorch module with the fitted parameters.
print("Centroids are:")
print(estimator.model_.centroids)

Due to the high-level estimator API, the usage for all machine learning models is similar. The API documentation provides more detailed information about parameters that can be passed to estimators and which methods are available.
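The scikit-learn-style contract these estimators follow — fit returns self, fitted state lives in trailing-underscore attributes such as model_, and predict maps data to labels — can be sketched with a minimal stand-in class (illustrative only; this is not TorchGMM code, and the class name is invented for the example):

```python
import numpy as np

class NearestCentroidEstimator:
    """Minimal illustration of the scikit-learn estimator contract
    that TorchGMM's high-level Estimator API follows."""

    def __init__(self, n_clusters):
        self.n_clusters = n_clusters

    def fit(self, X):
        # A real estimator would run k-means or EM here; for illustration
        # we simply take the first n_clusters points as centroids.
        self.centroids_ = np.asarray(X)[: self.n_clusters].copy()
        return self  # returning self enables estimator.fit(X).predict(X)

    def predict(self, X):
        dists = ((np.asarray(X)[:, None] - self.centroids_[None]) ** 2).sum(-1)
        return dists.argmin(axis=1)

est = NearestCentroidEstimator(2).fit(np.array([[0.0], [10.0], [1.0], [9.0]]))
labels = est.predict(np.array([[0.5], [9.5]]))
```

Because TorchGMM estimators honor this same contract, they can be dropped into code written against scikit-learn-style interfaces.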

GPU and Multi-Node training

For GPU and multi-node training, TorchGMM leverages PyTorch Lightning. The hardware that training runs on is determined by the Trainer class. Its init method provides various configuration options.

If you want to run K-Means with a GPU, you can pass the options accelerator='gpu' and devices=1 to the estimator's initializer:

estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1))

Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available, you can specify this as follows:

estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1))

In fact, you do not need to change anything else in your code.
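The reason nothing else changes is that Lightning moves the model and each batch to the configured accelerator for you. The device-agnostic pattern this relies on can be sketched in plain PyTorch (TorchGMM and Lightning handle this internally; the snippet below is only an illustration):

```python
import torch

# Pick whatever accelerator is available; fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

X = torch.randn(1000, 8)
centroids = torch.randn(3, 8)

# Moving the tensors is the only device-specific step; the math is unchanged.
X, centroids = X.to(device), centroids.to(device)
distances = torch.cdist(X, centroids)  # (1000, 3) pairwise distances
labels = distances.argmin(dim=1)
```

Since every operation is expressed on tensors rather than a specific device, the same code runs on CPU, a single GPU, or multiple nodes.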

Implemented Models

Currently, TorchGMM implements two models:

  • GaussianMixture: Gaussian Mixture Models trained via expectation-maximization
  • KMeans: K-Means clustering, as shown in the usage example above

Contribution

If you found a bug or you want to propose a new feature, please use the issue tracker.

License

TorchGMM is licensed under the MIT License.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchgmm-0.1.4.tar.gz (48.7 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torchgmm-0.1.4-py3-none-any.whl (46.1 kB)

Uploaded Python 3

File details

Details for the file torchgmm-0.1.4.tar.gz.

File metadata

  • Download URL: torchgmm-0.1.4.tar.gz
  • Upload date:
  • Size: 48.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchgmm-0.1.4.tar.gz:

  • SHA256: c874880fe653f17ac75db0eb06052793757b00fa37109b05cf1d331f98bcc031
  • MD5: e8f0f832945efcadc3a3a746dfb01334
  • BLAKE2b-256: d92e07f4db21a7d0712c8a48d5b1bc4489e791480a935471a0ef99318e38f3c4

Provenance

The following attestation bundles were made for torchgmm-0.1.4.tar.gz:

Publisher: release.yaml on CSOgroup/torchgmm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torchgmm-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: torchgmm-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 46.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchgmm-0.1.4-py3-none-any.whl:

  • SHA256: b4e6c77cb1ec0ef3000edc40f217fa6d0e3b9d37cb7266694b13480b3a9e8901
  • MD5: 7468816f410f8994b336f10b0f2ae99e
  • BLAKE2b-256: 6f2103a0086804c3921f24404a3865ec67c12b4f5cfbdac477ae0f42d7ec5af3

Provenance

The following attestation bundles were made for torchgmm-0.1.4-py3-none-any.whl:

Publisher: release.yaml on CSOgroup/torchgmm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
