A Gaussian Mixture Model (GMM) based on Expectation-Maximisation (EM) implemented in PyTorch
TorchGMM: A Gaussian Mixture Model Implementation with PyTorch
TorchGMM is a flexible implementation of Gaussian Mixture Models in PyTorch, supporting:
- EM Algorithm
- MAP Estimation with Priors
- Multiple Covariance Types
- Various Initialization Methods
- Comprehensive Clustering & Evaluation Metrics
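To make the EM item above concrete, here is a minimal sketch of one EM iteration for a one-dimensional mixture, written in plain Python so it runs without any dependencies. It is not TorchGMM's implementation; the function names (`normal_pdf`, `em_step`), the variance floor, and the toy data are all assumptions made for illustration.

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian with the given mean and variance."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def em_step(data, weights, means, variances, var_floor=1e-6):
    """One EM iteration for a 1-D mixture (illustrative, not TorchGMM code).

    Returns the updated (weights, means, variances) plus the log-likelihood
    of the data under the parameters that were passed in.
    """
    k = len(weights)

    # E-step: responsibility of each component for each point.
    resp = []
    log_lik = 0.0
    for x in data:
        joint = [weights[j] * normal_pdf(x, means[j], variances[j]) for j in range(k)]
        total = sum(joint)
        log_lik += math.log(total)
        resp.append([p / total for p in joint])

    # M-step: responsibility-weighted maximum-likelihood updates.
    new_weights, new_means, new_vars = [], [], []
    for j in range(k):
        n_j = sum(r[j] for r in resp)  # soft count of component j
        mean_j = sum(r[j] * x for r, x in zip(resp, data)) / n_j
        var_j = sum(r[j] * (x - mean_j) ** 2 for r, x in zip(resp, data)) / n_j
        new_weights.append(n_j / len(data))
        new_means.append(mean_j)
        new_vars.append(max(var_j, var_floor))  # floor guards against collapse
    return new_weights, new_means, new_vars, log_lik


# Toy data: two well-separated clusters (hypothetical values).
random.seed(0)
data = [random.gauss(-2.0, 0.5) for _ in range(200)]
data += [random.gauss(3.0, 1.0) for _ in range(200)]

weights, means, variances = [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]
history = []
for _ in range(30):
    weights, means, variances, log_lik = em_step(data, weights, means, variances)
    history.append(log_lik)

print("weights:", weights)
print("means:", means)
```

The log-likelihood values collected in `history` should never decrease from one iteration to the next, which is the usual sanity check for an EM loop.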
Features
- GaussianMixture
  - Full, diag, spherical, tied covariances
  - MLE or MAP estimation with weight, mean, or covariance priors
- GMMInitializer
  - kmeans, kpp (k-means++), random, points, maxdist
- ClusteringMetrics
  - Unsupervised metrics (Silhouette, Davies-Bouldin, etc.)
  - Supervised metrics (ARI, NMI, Purity, Confusion Matrix, etc.)
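The four covariance types listed under GaussianMixture trade flexibility against the number of parameters to estimate. The sketch below counts the free covariance parameters for each type, assuming the names carry their conventional meaning (as in scikit-learn's GaussianMixture): full is one symmetric matrix per component, diag is one variance per feature per component, spherical is one variance per component, and tied is a single symmetric matrix shared by all components. The helper `covariance_param_count` is hypothetical and not part of TorchGMM.

```python
def covariance_param_count(covariance_type: str, n_components: int, n_features: int) -> int:
    """Number of free covariance parameters (hypothetical helper, conventional definitions)."""
    k, d = n_components, n_features
    symmetric = d * (d + 1) // 2  # free entries of one symmetric d x d matrix
    if covariance_type == "full":
        return k * symmetric
    if covariance_type == "diag":
        return k * d
    if covariance_type == "spherical":
        return k
    if covariance_type == "tied":
        return symmetric
    raise ValueError(f"unknown covariance type: {covariance_type!r}")


for name in ("full", "diag", "spherical", "tied"):
    print(name, covariance_param_count(name, n_components=3, n_features=2))
```

With many features, the quadratic growth of the full type is usually the reason to fall back to diag or spherical.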
Installation
git clone https://github.com/YourUser/TorchGMM.git
cd TorchGMM
pip install -r requirements.txt
Make sure you have PyTorch installed. For GPU usage, install the CUDA-enabled version of PyTorch as per the official instructions.
Documentation
We use Sphinx to build the documentation. After a build, the generated HTML pages live under docs/_build/html/. They can also be read online if you host them (e.g., on GitHub Pages). To build the docs locally:
cd docs
make clean
make html
# Open _build/html/index.html in a browser
# Linux
xdg-open _build/html/index.html
The docs include:
- API Reference for all modules (see GaussianMixture, GMMInitializer, and ClusteringMetrics).
- Tutorials that walk through different usage scenarios (basic GMM, metrics, using priors).
Tutorials
We provide three Jupyter notebooks in the tutorials/ folder:
- GMM Tutorial: Basic usage of the GaussianMixture class.
- Metrics Tutorial: Demonstrates ClusteringMetrics and how to compare models.
- Priors Tutorial: Shows how to use weight/mean/covariance priors (MAP).
To view or run them locally, open them in Jupyter or VS Code.
Basic Usage Example
Here’s a short snippet:
import torch
from utils.gmm import GaussianMixture
# Generate random 2D data
X = torch.randn(500, 2)
# Create and fit the GMM
gmm = GaussianMixture(
    n_features=2,
    n_components=3,
    covariance_type='full',
    max_iter=200
)
gmm.fit(X)
print("Converged?:", gmm.converged_)
print("Cluster Weights:", gmm.weights_)
print("Cluster Means:", gmm.means_)
You can also run on GPU by specifying device='cuda' in the GaussianMixture constructor (assuming a CUDA-capable device).
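A common way to keep a script portable is to pick the device string at run time and fall back to the CPU when no GPU is visible. The sketch below only selects the string. Passing it as `device=` follows the sentence above; whether input tensors must also be moved to that device is not stated here and should be checked against the API reference.

```python
try:
    import torch

    # torch.cuda.is_available() is True only when a CUDA build sees a usable GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
    # PyTorch is missing, so there is nothing to run on a GPU.
    device = "cpu"

print("Selected device:", device)

# The string can then be passed to the constructor, for example:
# gmm = GaussianMixture(n_features=2, n_components=3, device=device)
```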
Contributing
- Fork the repository and create your feature branch.
- Make changes and add tests or notebooks as appropriate.
- Submit a pull request (PR) for review.
We welcome improvements to both the code and the documentation.
License
Released under the MIT License. © 2025, Adrián A. Sousa-Poza