A Gaussian Mixture Model (GMM) based on Expectation-Maximisation (EM) implemented in PyTorch

TorchGMM: A Gaussian Mixture Model Implementation with PyTorch

TorchGMM is a flexible implementation of Gaussian Mixture Models in PyTorch, supporting:

  • EM Algorithm
  • MAP Estimation with Priors
  • Multiple Covariance Types
  • Various Initialization Methods
  • Comprehensive Clustering & Evaluation Metrics

Features

  1. GaussianMixture
    • Full, diag, spherical, tied covariances
    • MLE or MAP estimation with weight, mean, or covariance priors

  2. GMMInitializer
    • kmeans, kpp (k-means++), random, points, maxdist

  3. ClusteringMetrics
    • Unsupervised metrics (Silhouette, Davies-Bouldin, etc.)
    • Supervised metrics (ARI, NMI, Purity, Confusion Matrix, etc.)
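The four covariance options listed under GaussianMixture trade flexibility against the number of parameters to estimate. The sketch below counts free covariance parameters under the conventional meaning of these names (the same one scikit-learn uses). It is an assumption that TorchGMM follows this convention; the helper `covariance_param_count` is hypothetical and not part of the library.

```python
def covariance_param_count(covariance_type: str, n_components: int, n_features: int) -> int:
    """Count free covariance parameters for a GMM (conventional definitions, assumed)."""
    k, d = n_components, n_features
    sym = d * (d + 1) // 2  # free entries of one symmetric d x d matrix
    counts = {
        "full": k * sym,   # one unconstrained covariance matrix per component
        "tied": sym,       # a single covariance matrix shared by all components
        "diag": k * d,     # one variance per feature, per component
        "spherical": k,    # one variance per component
    }
    if covariance_type not in counts:
        raise ValueError(f"unknown covariance_type: {covariance_type!r}")
    return counts[covariance_type]


for cov_type in ("full", "tied", "diag", "spherical"):
    print(cov_type, covariance_param_count(cov_type, n_components=3, n_features=2))
```

With few samples or many features, the more constrained types (diag, spherical) are usually less prone to overfitting, at the cost of not modelling correlations between features.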

Installation

git clone https://github.com/YourUser/TorchGMM.git
cd TorchGMM
pip install -r requirements.txt

Make sure you have PyTorch installed. For GPU usage, install the CUDA-enabled version of PyTorch as per the official instructions.

Documentation

The documentation is built with Sphinx. The generated HTML pages are written to docs/_build/html/ and can also be served online, for example via GitHub Pages.

cd docs
make clean
make html
# Open the built docs in a browser (Linux)
xdg-open _build/html/index.html

The docs include:

  • API Reference for all modules (see GaussianMixture, GMMInitializer, and ClusteringMetrics).
  • Tutorials that walk through different usage scenarios (basic GMM, metrics, using priors).

Tutorials

We provide three Jupyter notebooks in the tutorials/ folder:

  • GMM Tutorial: Basic usage of the GaussianMixture class.
  • Metrics Tutorial: Demonstrates ClusteringMetrics and how to compare models.
  • Priors Tutorial: Shows how to use weight/mean/covariance priors (MAP).

To view or run them locally, open them in Jupyter or VS Code.

Basic Usage Example

The following snippet fits a three-component GMM with full covariances to random 2D data:

import torch
from utils.gmm import GaussianMixture

# Generate random 2D data
X = torch.randn(500, 2)

# Create and fit the GMM
gmm = GaussianMixture(
    n_features=2,
    n_components=3,
    covariance_type='full',
    max_iter=200
)
gmm.fit(X)

print("Converged?:", gmm.converged_)
print("Cluster Weights:", gmm.weights_)
print("Cluster Means:", gmm.means_)

You can also run on GPU by specifying device='cuda' in the GaussianMixture constructor (assuming a CUDA-capable device).
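A minimal sketch of picking the device at runtime, so the same script also runs on machines without a GPU. It uses only standard PyTorch calls. Whether fit() expects the input tensor to already be on the model's device is an assumption here, so the data is created on the chosen device to be safe.

```python
import torch

# Fall back to the CPU when PyTorch cannot see a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create the data directly on the chosen device (assumption: the model
# and its input should live on the same device).
X = torch.randn(500, 2, device=device)
print("Using device:", X.device.type)

# The device string can then be passed to the constructor described above:
# gmm = GaussianMixture(n_features=2, n_components=3, device=device)
# gmm.fit(X)
```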

Contributing

  1. Fork the repository and create your feature branch.
  2. Make changes and add tests or notebooks as appropriate.
  3. Submit a pull request (PR) for review.

We welcome improvements to both the code and the documentation.

License

Released under the MIT License. © 2025, Adrián A. Sousa-Poza
