A Gaussian Mixture Model (GMM) based on Expectation-Maximisation (EM) implemented in PyTorch

TorchGMM: A Gaussian Mixture Model Implementation with PyTorch

TorchGMM is a flexible implementation of Gaussian Mixture Models in PyTorch, supporting:

  • EM Algorithm
  • MAP Estimation with Priors
  • Multiple Covariance Types
  • Various Initialization Methods
  • Comprehensive Clustering & Evaluation Metrics
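
To make the EM bullet concrete, here is a minimal sketch of one EM iteration for a diagonal-covariance mixture, written in plain PyTorch. It is not TorchGMM's implementation: the function name em_step_diag, the eps regularizer, and the toy data are all assumptions chosen for illustration.

```python
import math
import torch

def em_step_diag(X, weights, means, variances, eps=1e-6):
    """One EM iteration for a diagonal-covariance GMM (illustrative only)."""
    n, d = X.shape
    # E-step: log N(x | mu_k, diag(var_k)) for every sample/component pair.
    diff = X.unsqueeze(1) - means.unsqueeze(0)              # (n, k, d)
    log_prob = -0.5 * (
        (diff ** 2 / variances.unsqueeze(0)).sum(dim=2)
        + torch.log(variances).sum(dim=1).unsqueeze(0)
        + d * math.log(2 * math.pi)
    )                                                       # (n, k)
    log_joint = log_prob + torch.log(weights).unsqueeze(0)
    log_norm = torch.logsumexp(log_joint, dim=1, keepdim=True)
    resp = torch.exp(log_joint - log_norm)                  # responsibilities

    # M-step: re-estimate weights, means, and variances from responsibilities.
    nk = resp.sum(dim=0) + eps                              # (k,)
    new_weights = nk / nk.sum()
    new_means = (resp.T @ X) / nk.unsqueeze(1)              # (k, d)
    diff = X.unsqueeze(1) - new_means.unsqueeze(0)
    new_vars = (resp.unsqueeze(2) * diff ** 2).sum(dim=0) / nk.unsqueeze(1) + eps
    # Average log-likelihood under the parameters used in the E-step.
    return new_weights, new_means, new_vars, log_norm.mean().item()

# Toy data: two well-separated 2D blobs (hypothetical example data).
torch.manual_seed(0)
X = torch.cat([torch.randn(200, 2) - 3.0, torch.randn(200, 2) + 3.0])
weights = torch.full((2,), 0.5)
means = X[torch.randperm(X.shape[0])[:2]].clone()           # random-point init
variances = torch.ones(2, 2)

history = []
for _ in range(20):
    weights, means, variances, avg_ll = em_step_diag(X, weights, means, variances)
    history.append(avg_ll)

print("final weights:", weights)
print("avg log-likelihood, first vs last:", history[0], history[-1])
```

Working in log space with logsumexp keeps the E-step numerically stable when a sample is far from every component.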

Features

  1. GaussianMixture
    • Full, diag, spherical, tied covariances
    • MLE or MAP estimation with weight, mean, or covariance priors
  2. GMMInitializer
    • kmeans, kpp (k-means++), random, points, maxdist
  3. ClusteringMetrics
    • Unsupervised metrics (Silhouette, Davies-Bouldin, etc.)
    • Supervised metrics (ARI, NMI, Purity, Confusion Matrix, etc.)
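
The four covariance types listed above trade flexibility against parameter count. The sketch below shows the parameter shapes commonly used for each type (the scikit-learn convention) and the number of free parameters they imply. TorchGMM's internal layout may differ; the helper names covariance_shape and free_parameters are hypothetical.

```python
import torch

def covariance_shape(covariance_type, n_components, n_features):
    """Shape of the covariance parameters under the assumed convention."""
    shapes = {
        "full": (n_components, n_features, n_features),  # one matrix per component
        "tied": (n_features, n_features),                # one matrix shared by all
        "diag": (n_components, n_features),              # per-feature variances
        "spherical": (n_components,),                    # one variance per component
    }
    return shapes[covariance_type]

def free_parameters(covariance_type, n_components, n_features):
    """Free covariance parameters, counting each symmetric entry once."""
    k, d = n_components, n_features
    sym = d * (d + 1) // 2
    counts = {"full": k * sym, "tied": sym, "diag": k * d, "spherical": k}
    return counts[covariance_type]

for cov_type in ("full", "tied", "diag", "spherical"):
    shape = covariance_shape(cov_type, n_components=3, n_features=2)
    placeholder = torch.zeros(shape)  # stand-in tensor with that shape
    print(cov_type, tuple(placeholder.shape), free_parameters(cov_type, 3, 2))
```

With few samples or many features, the diag and spherical types are usually less prone to overfitting than full, at the cost of ignoring correlations between features.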

Installation

git clone https://github.com/YourUser/TorchGMM.git
cd TorchGMM
pip install -r requirements.txt

Make sure you have PyTorch installed. For GPU usage, install the CUDA-enabled version of PyTorch as per the official instructions.

Documentation

We use Sphinx to build the documentation. The generated HTML pages are written to docs/_build/html/ and can also be hosted online (e.g., on GitHub Pages).

cd docs
make clean
make html
# Open _build/html/index.html in a browser, e.g. on Linux:
xdg-open _build/html/index.html

The docs include:

  • API Reference for all modules (see GaussianMixture, GMMInitializer, and ClusteringMetrics).
  • Tutorials that walk through different usage scenarios (basic GMM, metrics, using priors).

Tutorials

We provide three Jupyter notebooks in the tutorials/ folder:

  • GMM Tutorial: Basic usage of the GaussianMixture class.
  • Metrics Tutorial: Demonstrates ClusteringMetrics and how to compare models.
  • Priors Tutorial: Shows how to use weight/mean/covariance priors (MAP).

To view or run them locally, open them in Jupyter or VS Code.
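
As a rough idea of the kind of supervised metric the Metrics Tutorial covers, the sketch below computes cluster purity in plain PyTorch. It does not use the ClusteringMetrics API, whose method names are not shown here; purity_score and the toy labels are hypothetical.

```python
import torch

def purity_score(labels_true, labels_pred):
    """Fraction of samples that belong to the majority class of their cluster."""
    n_classes = int(labels_true.max().item()) + 1
    n_clusters = int(labels_pred.max().item()) + 1
    # Contingency table: rows are predicted clusters, columns are true classes.
    flat = labels_pred * n_classes + labels_true
    table = torch.bincount(flat, minlength=n_clusters * n_classes)
    table = table.reshape(n_clusters, n_classes)
    # Sum the majority-class count of every cluster, then normalize.
    return table.max(dim=1).values.sum().item() / labels_true.numel()

# Toy labels (non-negative integer tensors are assumed).
labels_true = torch.tensor([0, 0, 0, 1, 1, 1])
labels_pred = torch.tensor([1, 1, 0, 0, 0, 0])
score = purity_score(labels_true, labels_pred)
print(f"purity = {score:.4f}")
```

Purity rises trivially as the number of clusters grows, so it is usually read alongside a chance-corrected score such as ARI.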

Basic Usage Example

Here’s a short snippet that fits a three-component GMM to random 2D data:

import torch
from utils.gmm import GaussianMixture

# Generate random 2D data
X = torch.randn(500, 2)

# Create and fit the GMM
gmm = GaussianMixture(
    n_features=2,
    n_components=3,
    covariance_type='full',
    max_iter=200
)
gmm.fit(X)

print("Converged?:", gmm.converged_)
print("Cluster Weights:", gmm.weights_)
print("Cluster Means:", gmm.means_)

You can also run on GPU by specifying device='cuda' in the GaussianMixture constructor (assuming a CUDA-capable device).
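
A small sketch of choosing the device at runtime, so the same script runs with or without a GPU. Only torch calls are executed; the GaussianMixture lines are left as comments, and keeping the data on the same device as the model is an assumption rather than something the note above states.

```python
import torch

# Fall back to the CPU when no CUDA device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create the data directly on the chosen device (assumed to be required).
X = torch.randn(500, 2, device=device)
print("Running on:", X.device)

# Passing the device to the constructor, as described above:
# gmm = GaussianMixture(n_features=2, n_components=3, device=device)
# gmm.fit(X)
```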

Contributing

  1. Fork the repository and create your feature branch.
  2. Make changes and add tests or notebooks as appropriate.
  3. Submit a pull request (PR) for review.

We welcome improvements to both the code and the documentation.

License

Released under the MIT License. © 2025, Adrián A. Sousa-Poza
