A minimal implementation of Gaussian Mixture Models in Jax

GMMX: Gaussian Mixture Models in Jax

Release Build status codecov Commit activity License DOI

Installation

gmmx can be installed via pip:

pip install gmmx

Usage

from gmmx import GaussianMixtureModelJax, EMFitter

# Create a Gaussian Mixture Model with 16 components and 32 features
gmm = GaussianMixtureModelJax.create(n_components=16, n_features=32)

# Draw samples from the model
n_samples = 10_000
x = gmm.sample(n_samples)

# Fit the model to the data
em_fitter = EMFitter(tol=1e-3, max_iter=100)
gmm_fitted = em_fitter.fit(x=x, gmm=gmm)

If you use the code in a scientific publication, please cite the Zenodo DOI from the badge above.

Why Gaussian Mixture Models?

What are Gaussian Mixture Models (GMMs) useful for in the age of deep learning? GMMs may have fallen out of fashion for classification tasks, but they still have a few properties that make them useful in certain scenarios:

  • They are universal approximators, meaning that given enough components they can approximate any continuous probability density arbitrarily well.
  • Their likelihood can be evaluated in closed form, which makes them useful for generative modeling.
  • They are rather fast to train and evaluate.
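
To make the closed-form likelihood concrete: for weights w_k, means mu_k and covariances Sigma_k, the log-density is log p(x) = log sum_k w_k N(x | mu_k, Sigma_k). The following is a minimal sketch in plain Python, not the gmmx implementation. It assumes diagonal covariances for brevity, and the function names and example parameters are hypothetical.

```python
import math


def log_normal_diag(x, mean, var):
    """Log-density of a Gaussian with diagonal covariance at point x."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mean, var)
    )


def gmm_log_density(x, weights, means, variances):
    """log p(x) = log sum_k w_k N(x | mu_k, diag(var_k)), via log-sum-exp."""
    terms = [
        math.log(w) + log_normal_diag(x, m, v)
        for w, m, v in zip(weights, means, variances)
    ]
    top = max(terms)  # subtract the max so the exponentials cannot underflow to zero
    return top + math.log(sum(math.exp(t - top) for t in terms))


# Hypothetical two-component mixture in two dimensions
weights = [0.3, 0.7]
means = [[0.0, 0.0], [3.0, 3.0]]
variances = [[1.0, 1.0], [0.5, 2.0]]

log_p = gmm_log_density([1.0, 1.0], weights, means, variances)
```

Working in log space with the log-sum-exp trick keeps the result finite even for points far from every component, which matters when summing log-likelihoods over many samples.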

I strongly recommend reading In Depth: Gaussian Mixture Models from the Python Data Science Handbook for a more in-depth introduction to GMMs and their application as density estimators.

One such application in my research is image reconstruction, where GMMs can be used to model the distribution and pixel correlations of local (patch-based) image features. This can be useful for tasks like image denoising or inpainting. One method in which I have used them is Jolideco. Speeding up training on O(10^6) patches was the main motivation for gmmx.

Benchmarks

Here are some results from the benchmarks in the examples/benchmarks folder comparing against Scikit-Learn. The benchmarks were run on an "Intel(R) Xeon(R) Gold 6338" CPU and a single "NVIDIA L40S" GPU.

Prediction Time

[Figures: prediction time vs. number of components, number of samples, and number of features]

For prediction, the speedup over Scikit-Learn is around 5-6x on the CPU for varying numbers of components and features, and around 50x on the GPU. When varying the number of samples, the cross-over point between the two implementations is around O(10^3) samples.

Training Time

[Figures: training time vs. number of components, number of samples, and number of features]

For training, the speedup is around 5-6x on the same architecture and around 50x on the GPU. In the benchmark I have forced both fitters to run exactly the same number of iterations. In general, however, there is no guarantee that GMMX converges to the same solution as Scikit-Learn. There are tests in the tests folder that compare the results of the two implementations, and they show good agreement.

Download files

Source Distribution

gmmx-0.6.tar.gz (16.7 kB)

Uploaded Source

Built Distribution

gmmx-0.6-py3-none-any.whl (12.9 kB)

Uploaded Python 3

File details

Details for the file gmmx-0.6.tar.gz.

File metadata

  • Download URL: gmmx-0.6.tar.gz
  • Upload date:
  • Size: 16.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for gmmx-0.6.tar.gz
Algorithm Hash digest
SHA256 d2141f82e98fc785b430d9888a7a237f6b7c60a09cc4d53d80c7028ba3ae017d
MD5 d4cf160ef8af3de13277f443f532a9f9
BLAKE2b-256 a2f66df7b51119577f5290ad4db30d6b0eb5b57d5efc9818c1f65f1ec1a54c35
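
A downloaded archive can be checked against the digests listed above. The following is a minimal sketch using only Python's standard hashlib module. The helper names are hypothetical, and the commented-out call assumes the archive sits in the current directory.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path, expected_hex):
    """Check a file against a published SHA256 digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# SHA256 digest listed above for gmmx-0.6.tar.gz
EXPECTED_SHA256 = "d2141f82e98fc785b430d9888a7a237f6b7c60a09cc4d53d80c7028ba3ae017d"

# Assumes the archive was downloaded to the current directory:
# matches_digest("gmmx-0.6.tar.gz", EXPECTED_SHA256)
```

Reading in chunks keeps memory use constant regardless of file size. Alternatively, pip's hash-checking mode performs the same comparison at install time.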

File details

Details for the file gmmx-0.6-py3-none-any.whl.

File metadata

  • Download URL: gmmx-0.6-py3-none-any.whl
  • Upload date:
  • Size: 12.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.8

File hashes

Hashes for gmmx-0.6-py3-none-any.whl
Algorithm Hash digest
SHA256 48b189fd3cb65f3af9ac569b3ad4765909a1e9237ee9386c6548dfeb6f72e3d3
MD5 00eb2e2f541a4e941cc6076de08263ba
BLAKE2b-256 5088e8f0777eb023d80d615a9af1e4a933109630e27002134628c0f2f01f22be
