A minimal implementation of Gaussian Mixture Models in Jax

GMMX: Gaussian Mixture Models in Jax

Release Build status codecov Commit activity License DOI

Installation

gmmx can be installed via pip:

pip install gmmx

Usage

from gmmx import GaussianMixtureModelJax, EMFitter

# Create a Gaussian Mixture Model with 16 components and 32 features
gmm = GaussianMixtureModelJax.create(n_components=16, n_features=32)

# Draw samples from the model
n_samples = 10_000
x = gmm.sample(n_samples)

# Fit the model to the data
em_fitter = EMFitter(tol=1e-3, max_iter=100)
gmm_fitted = em_fitter.fit(x=x, gmm=gmm)
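EMFitter fits the model with expectation-maximization (EM). To illustrate what `tol` and `max_iter` control, here is a minimal NumPy sketch of EM for a full-covariance GMM. It is not gmmx's implementation: the function names are hypothetical, and the convergence rule (stop once the mean log-likelihood changes by less than `tol`, as Scikit-Learn does) and the small diagonal regularization `reg` are assumptions.

```python
import numpy as np


def log_gaussian(x, mean, cov):
    """Row-wise log-density of a multivariate normal for x of shape (n, d)."""
    d = x.shape[1]
    chol = np.linalg.cholesky(cov)
    # Whitened residuals: ||L^-1 (x - mean)||^2 is the squared Mahalanobis distance
    z = np.linalg.solve(chol, (x - mean).T)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2 * np.pi) + log_det + np.sum(z**2, axis=0))


def em_fit(x, weights, means, covs, tol=1e-3, max_iter=100, reg=1e-6):
    """Hypothetical EM loop; returns (weights, means, covs, mean_log_lik, n_iter)."""
    n, d = x.shape
    k = len(weights)
    prev = -np.inf
    for it in range(max_iter):
        # E-step: log p(x_i, z_i = j) for every sample i and component j
        log_joint = np.stack(
            [np.log(weights[j]) + log_gaussian(x, means[j], covs[j]) for j in range(k)],
            axis=1,
        )
        log_norm = np.logaddexp.reduce(log_joint, axis=1)  # log p(x_i)
        resp = np.exp(log_joint - log_norm[:, None])  # responsibilities, rows sum to 1
        mean_ll = log_norm.mean()  # under the parameters *before* this M-step

        # M-step: closed-form updates of weights, means and covariances
        nk = resp.sum(axis=0) + 10 * np.finfo(float).eps
        weights = nk / n
        means = (resp.T @ x) / nk[:, None]
        covs = np.empty((k, d, d))
        for j in range(k):
            diff = x - means[j]
            covs[j] = (resp[:, j, None] * diff).T @ diff / nk[j] + reg * np.eye(d)

        if abs(mean_ll - prev) < tol:
            break
        prev = mean_ll
    return weights, means, covs, mean_ll, it + 1
```

Working in log space with `logaddexp` keeps the E-step stable even when individual component densities underflow, which matters for high-dimensional data such as the 32-feature example above.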

If you use the code in a scientific publication, please cite the Zenodo DOI from the badge above.

Why Gaussian Mixture Models?

What are Gaussian Mixture Models (GMMs) useful for in the age of deep learning? GMMs may have fallen out of fashion for classification tasks, but they still have a few properties that make them useful in certain scenarios:

  • They are universal approximators: given enough components, they can approximate any smooth density arbitrarily well.
  • Their likelihood can be evaluated in closed form, which makes them useful for generative modeling and density estimation.
  • They are relatively fast to train and evaluate.
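The closed-form likelihood follows from the model definition: p(x) is a weighted sum of Gaussian densities, so log p(x) is a log-sum-exp over components. Sampling is equally direct: pick a component according to the weights, then draw from that Gaussian. A NumPy sketch with full covariances; the function names are hypothetical and not part of gmmx.

```python
import numpy as np


def gmm_log_likelihood(x, weights, means, covs):
    """Per-sample log p(x) = log sum_j w_j N(x | mu_j, Sigma_j) for x of shape (n, d)."""
    d = x.shape[1]
    log_terms = []
    for w, mu, cov in zip(weights, means, covs):
        chol = np.linalg.cholesky(cov)
        z = np.linalg.solve(chol, (x - mu).T)  # whitened residuals, shape (d, n)
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        log_terms.append(np.log(w) - 0.5 * (d * np.log(2 * np.pi) + log_det + np.sum(z**2, axis=0)))
    # log-sum-exp over components for numerical stability
    return np.logaddexp.reduce(np.stack(log_terms, axis=1), axis=1)


def gmm_sample(rng, n_samples, weights, means, covs):
    """Ancestral sampling: choose a component index, then draw from its Gaussian."""
    idx = rng.choice(len(weights), size=n_samples, p=weights)
    chols = np.linalg.cholesky(covs)  # batched Cholesky factors, shape (k, d, d)
    eps = rng.standard_normal((n_samples, means.shape[1]))
    return means[idx] + np.einsum("nij,nj->ni", chols[idx], eps)
```

For example, drawing samples from a two-component model and scoring them under the same model:

```python
weights = np.array([0.3, 0.7])
means = np.array([[0.0, 0.0], [3.0, 3.0]])
covs = np.stack([np.eye(2), 0.5 * np.eye(2)])
x = gmm_sample(np.random.default_rng(0), 1000, weights, means, covs)
print(gmm_log_likelihood(x, weights, means, covs).mean())
```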

I strongly recommend reading In Depth: Gaussian Mixture Models from the Python Data Science Handbook for a more thorough introduction to GMMs and their application as density estimators.

One application in my research is image reconstruction, where GMMs can model the distribution and pixel correlations of local (patch-based) image features. This is useful for tasks like image denoising or inpainting. One method where I have used them is Jolideco. Speeding up training on O(10^6) patches was the main motivation for gmmx.
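To model patches with a GMM, each patch is flattened into one row of an (n_patches, n_features) data matrix. A minimal NumPy sketch under assumptions: `extract_patches`, the patch size, the stride and the per-patch mean subtraction (a common preprocessing choice for patch priors) are illustrative and not necessarily what Jolideco does. `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_patches(image, patch_size=8, stride=4):
    """Hypothetical helper: cut a 2D image into flattened patches, shape (n_patches, patch_size**2)."""
    # All overlapping windows, then keep every `stride`-th one along each axis
    windows = sliding_window_view(image, (patch_size, patch_size))[::stride, ::stride]
    patches = windows.reshape(-1, patch_size * patch_size)
    # Subtract each patch's mean so the model captures local structure, not brightness
    return patches - patches.mean(axis=1, keepdims=True)
```

The result has the same layout as `x` in the usage example above, so in principle it could be fitted with `n_features=patch_size**2`.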

Benchmarks

Here are some results from the benchmarks in the examples/benchmarks folder comparing against Scikit-Learn. The benchmarks were run on an "Intel(R) Xeon(R) Gold 6338" CPU and a single "NVIDIA L40S" GPU.

Prediction Time

Figures: time vs. number of components, time vs. number of samples, time vs. number of features.

For prediction, the speedup over Scikit-Learn is around 5-6x on the CPU when varying the number of components or features, and around 50x on the GPU. When varying the number of samples, the cross-over point between the two implementations is around O(10^3) samples.
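When reproducing timings like these for JAX code, two details matter: the first call of a jitted function includes compilation, and JAX dispatches work asynchronously, so the clock must wait for the result. A minimal, library-agnostic timing helper as a sketch (the name `benchmark` is hypothetical; this is not the script in examples/benchmarks). It calls `block_until_ready()` only if the returned object has it, as JAX arrays do, so it also works for NumPy or plain Python functions.

```python
import statistics
import time


def benchmark(fn, *args, n_repeat=5, n_warmup=1):
    """Median wall-clock time of fn(*args) in seconds, excluding warm-up calls."""

    def run():
        out = fn(*args)
        # JAX computes asynchronously; block until the result is actually ready
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    for _ in range(n_warmup):
        run()  # the first call to a jitted function also pays for compilation

    times = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        run()
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

The median is used rather than the mean so that a single slow run (e.g. from other load on the machine) does not dominate the result.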

Training Time

Figures: time vs. number of components, time vs. number of samples, time vs. number of features.

For training, the speedup is around 5-6x on the same architecture (CPU) and around 50x on the GPU. In the benchmark, I forced both fitters to run exactly the same number of iterations. In general, however, there is no guarantee that GMMX converges to the same solution as Scikit-Learn. The tests folder contains tests comparing the results of the two implementations, and these show good agreement.
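Comparing fitted parameters between two implementations has one subtlety: the component labels are arbitrary, so two equivalent solutions may list their components in a different order. A sketch of one way to align them (not necessarily what the gmmx tests do), using a hypothetical `match_components` helper that brute-forces the ordering minimizing the squared distance between means. This costs k! evaluations, so it only suits small k; for larger k, the Hungarian algorithm (e.g. `scipy.optimize.linear_sum_assignment`) scales better.

```python
import itertools

import numpy as np


def match_components(means_a, means_b):
    """Return perm such that means_b[perm] best matches means_a (brute force over k! orderings)."""
    k = len(means_a)
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(k)):
        cost = np.sum((means_a - means_b[list(perm)]) ** 2)
        if cost < best_cost:
            best_perm, best_cost = list(perm), cost
    return best_perm
```

After matching, the reordered weights and covariances (`weights_b[perm]`, `covs_b[perm]`) can be compared element-wise, e.g. with `np.allclose` and a tolerance suited to the data.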

Download files

Source Distribution

gmmx-0.4.tar.gz (6.3 kB)

Uploaded Source

Built Distribution

gmmx-0.4-py3-none-any.whl (4.0 kB)

Uploaded Python 3

File details

Details for the file gmmx-0.4.tar.gz.

File metadata

  • Download URL: gmmx-0.4.tar.gz
  • Upload date:
  • Size: 6.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for gmmx-0.4.tar.gz
Algorithm Hash digest
SHA256 774b91406a5e416c43ea2886b4510b1fdc1591264149f69b341e325c0c388dd5
MD5 ccf8f78d800e5a97a95c4acaaaff81b1
BLAKE2b-256 478da34870225aff87c6581fc83c895e630e05df3a5db46763b8f30df3e7b495

File details

Details for the file gmmx-0.4-py3-none-any.whl.

File metadata

  • Download URL: gmmx-0.4-py3-none-any.whl
  • Upload date:
  • Size: 4.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for gmmx-0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 d32c0d11f7f682e5035c402906ba7bf7fea856f86bb61f94a81265c3b8d6d8b7
MD5 9bac84bdb197ef2bca0ff4c43f43822b
BLAKE2b-256 4e85b42e1f46a22401ad560eb9169e3dab1ea947cd4f19bddc1b4b7d70955e90
