
A minimal implementation of Gaussian Mixture Models in Jax

Project description

GMMX: Gaussian Mixture Models in Jax

Release Build status codecov Commit activity License DOI



Installation

gmmx can be installed via pip:

pip install gmmx

Usage

from gmmx import GaussianMixtureModelJax, EMFitter

# Create a Gaussian Mixture Model with 16 components and 32 features
gmm = GaussianMixtureModelJax.create(n_components=16, n_features=32)

# Draw samples from the model
n_samples = 10_000
x = gmm.sample(n_samples)

# Fit the model to the data
em_fitter = EMFitter(tol=1e-3, max_iter=100)
gmm_fitted = em_fitter.fit(x=x, gmm=gmm)

If you use the code in a scientific publication, please cite the Zenodo DOI from the badge above.
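The `EMFitter` in the usage example fits the model with expectation-maximization (EM). The following plain NumPy version of the textbook EM loop for a full-covariance GMM is meant only to show what knobs like `tol` and `max_iter` typically control. It is not gmmx's implementation. The function name `fit_gmm_em`, its stopping rule (stop once the mean log-likelihood improves by less than `tol`), the unit-covariance initialization and the `reg` jitter are all assumptions of this sketch.

```python
import numpy as np


def fit_gmm_em(x, means, tol=1e-3, max_iter=100, reg=1e-6):
    """Fit a full-covariance GMM with plain EM, starting from the given means.

    Assumptions of this sketch: equal weights and unit covariances at start;
    iteration stops once the mean per-sample log-likelihood improves by less
    than `tol`, or after `max_iter` iterations.
    """
    n_samples, n_features = x.shape
    n_components = means.shape[0]
    weights = np.full(n_components, 1.0 / n_components)
    covs = np.tile(np.eye(n_features), (n_components, 1, 1))
    prev_ll = -np.inf
    for n_iter in range(1, max_iter + 1):
        # E-step: log(weight_k * N(x_n | mean_k, cov_k)) for all k, n -> (K, N)
        chol = np.linalg.cholesky(covs)
        diff = x[None, :, :] - means[:, None, :]
        z = np.linalg.solve(chol[:, None], diff[..., None])[..., 0]
        log_det = 2.0 * np.log(np.diagonal(chol, axis1=-2, axis2=-1)).sum(axis=-1)
        log_p = (
            np.log(weights)[:, None]
            - 0.5 * (n_features * np.log(2.0 * np.pi) + log_det)[:, None]
            - 0.5 * np.sum(z**2, axis=-1)
        )
        # Stable log-sum-exp over components gives the per-sample log-likelihood
        top = log_p.max(axis=0)
        log_lik = top + np.log(np.exp(log_p - top).sum(axis=0))
        resp = np.exp(log_p - log_lik)  # responsibilities; each column sums to 1
        mean_ll = log_lik.mean()

        # M-step: update weights, means and covariances from the responsibilities
        nk = resp.sum(axis=1) + 1e-12
        weights = nk / nk.sum()
        means = (resp @ x) / nk[:, None]
        diff = x[None, :, :] - means[:, None, :]
        covs = np.einsum("kn,kni,knj->kij", resp, diff, diff) / nk[:, None, None]
        covs += reg * np.eye(n_features)  # keep covariances positive definite

        # Convergence check on the log-likelihood of the E-step parameters
        if mean_ll - prev_ll < tol:
            break
        prev_ll = mean_ll
    return weights, means, covs, n_iter
```

For example, for two well-separated 2D blobs centred at (-5, -5) and (5, 5), calling `fit_gmm_em(x, means=np.array([[-1.0, -1.0], [1.0, 1.0]]))` should recover both means within a few iterations.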

Why Gaussian Mixture Models?

What are Gaussian Mixture Models (GMMs) useful for in the age of deep learning? GMMs may have fallen out of fashion for classification tasks, but they still have a few properties that make them useful in certain scenarios:

  • They are universal approximators: given enough components, they can approximate any continuous density arbitrarily well.
  • Their likelihood can be evaluated in closed form, which makes them useful for generative modeling.
  • They are rather fast to train and evaluate.
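To illustrate the closed-form likelihood, this minimal NumPy sketch evaluates the per-sample log-likelihood log p(x) = log sum_k w_k N(x | mu_k, Sigma_k) of a full-covariance GMM. It is not gmmx code, which is written in JAX. The function name and argument layout are assumptions for this example, and the log-sum-exp over components keeps the evaluation numerically stable.

```python
import numpy as np


def gmm_log_likelihood(x, weights, means, covariances):
    """Per-sample log-likelihood of a full-covariance Gaussian mixture.

    x: (n_samples, n_features), weights: (n_components,),
    means: (n_components, n_features),
    covariances: (n_components, n_features, n_features).
    """
    n_features = x.shape[1]
    # Cholesky factors L_k with covariances[k] = L_k @ L_k.T
    chol = np.linalg.cholesky(covariances)
    # Residuals for every (component, sample) pair: (K, N, D)
    diff = x[None, :, :] - means[:, None, :]
    # Solving L_k z = diff gives ||z||^2 = squared Mahalanobis distance
    z = np.linalg.solve(chol[:, None], diff[..., None])[..., 0]
    mahalanobis_sq = np.sum(z**2, axis=-1)  # (K, N)
    # log det(Sigma_k) = 2 * sum(log(diag(L_k)))
    log_det = 2.0 * np.log(np.diagonal(chol, axis1=-2, axis2=-1)).sum(axis=-1)
    log_norm = -0.5 * (n_features * np.log(2.0 * np.pi) + log_det)  # (K,)
    # log(w_k) + log N(x_n | mu_k, Sigma_k) for all k, n
    log_comp = np.log(weights)[:, None] + log_norm[:, None] - 0.5 * mahalanobis_sq
    # Stable log-sum-exp over the component axis
    top = log_comp.max(axis=0)
    return top + np.log(np.exp(log_comp - top).sum(axis=0))
```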

I strongly recommend reading In Depth: Gaussian Mixture Models from the Python Data Science Handbook for a more thorough introduction to GMMs and their use as density estimators.

One application in my research is image reconstruction, where GMMs can model the distribution and pixel correlations of local (patch-based) image features. This is useful for tasks such as image denoising or inpainting. One method where I have used them is Jolideco. Speeding up training on O(10^6) patches was the main motivation for gmmx.
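For the patch-based use case, the training data is a matrix with one row per patch and one column per pixel, matching the (n_samples, n_features) layout of the usage example. This NumPy sketch builds such a matrix. The function name `extract_patches`, the patch size, the stride and the per-patch mean subtraction are illustrative assumptions, not a description of how Jolideco prepares its data.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_patches(image, patch_size=8, stride=4):
    """Cut a 2D image into flattened square patches.

    Returns an array of shape (n_patches, patch_size**2): one row per patch
    and one feature per pixel, the layout a GMM is fitted on.
    """
    # All patch_size x patch_size windows as a (H-p+1, W-p+1, p, p) view
    windows = sliding_window_view(image, (patch_size, patch_size))
    # Keep every `stride`-th window along both axes
    windows = windows[::stride, ::stride]
    patches = windows.reshape(-1, patch_size * patch_size)
    # Assumption: remove each patch's mean so the model sees structure, not brightness
    return patches - patches.mean(axis=1, keepdims=True)
```

With 8x8 patches each sample has 64 features, and an H x W image yields ceil((H - 7) / stride) * ceil((W - 7) / stride) patches, e.g. 49 for a 32 x 32 image with stride 4.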

Benchmarks

Here are some results from the benchmarks in the examples/benchmarks folder comparing against Scikit-Learn. The benchmarks were run on a 2021 MacBook Pro with an M1 Pro chip.

Prediction

Figures: Time vs. Number of Components, Time vs. Number of Samples, Time vs. Number of Features

For prediction, the speedup is around 2x when varying the number of components or features. When varying the number of samples, the crossover point lies at around O(10^4) samples.

Training Time

Figures: Time vs. Number of Components, Time vs. Number of Samples, Time vs. Number of Features

For training, the speedup is around 10x on the same machine. However, there is no guarantee that gmmx converges to the same solution as Scikit-Learn; the tests folder contains some tests that compare the results of the two implementations.
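Comparing two fitted mixtures, for instance a gmmx fit against a Scikit-Learn fit, is complicated by the fact that component labels are arbitrary: the same solution can come out with its components in any order. One way to compare them is to first find the relabelling that best aligns the means and then compare all parameters. This sketch does that by brute force over permutations, which is fine for a handful of components (larger models would need a linear assignment solver). The function names and the (weights, means, covariances) tuple layout are hypothetical, and this is not the code in the project's tests folder.

```python
import itertools

import numpy as np


def match_components(means_a, means_b):
    """Find the permutation of B's components that best matches A's.

    Brute force over all permutations, so only practical for a few
    components. Returns the permutation and the summed distance between means.
    """
    n_components = means_a.shape[0]
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(n_components)):
        cost = np.linalg.norm(means_a - means_b[list(perm)], axis=1).sum()
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return best_perm, best_cost


def gmms_close(params_a, params_b, atol=1e-2):
    """Compare two (weights, means, covariances) tuples up to relabelling."""
    w_a, mu_a, cov_a = params_a
    w_b, mu_b, cov_b = params_b
    perm, _ = match_components(mu_a, mu_b)
    idx = list(perm)  # reorder B's components to line up with A's
    return bool(
        np.allclose(w_a, w_b[idx], atol=atol)
        and np.allclose(mu_a, mu_b[idx], atol=atol)
        and np.allclose(cov_a, cov_b[idx], atol=atol)
    )
```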



Download files


Source Distribution

gmmx-0.2.tar.gz (6.0 kB)

Uploaded Source

Built Distribution


gmmx-0.2-py3-none-any.whl (3.9 kB)

Uploaded Python 3

File details

Details for the file gmmx-0.2.tar.gz.

File metadata

  • Download URL: gmmx-0.2.tar.gz
  • Upload date:
  • Size: 6.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for gmmx-0.2.tar.gz
Algorithm Hash digest
SHA256 6de5c86fbe2ceadcee3eaebcb28ba07061a305338a3f4f4cb2dc2f8cc35c2fa6
MD5 4f9f37d51bc881a583bd6de71310b8f3
BLAKE2b-256 a8fa4423406f56957874c3ab56873fb5bec0a8fde5de240b885313fac9f9c0f2


File details

Details for the file gmmx-0.2-py3-none-any.whl.

File metadata

  • Download URL: gmmx-0.2-py3-none-any.whl
  • Upload date:
  • Size: 3.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.12.8

File hashes

Hashes for gmmx-0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 86551b321ea90cb2816ea1241997e9e30d5cbd0ef27d6ca9f4f8e8ed69cc8a0f
MD5 1d0d06f52355eee430d35aceabfc177c
BLAKE2b-256 2c451b25ac95dbaab55bfb3c85e2bb4d18095e9d718e862fcd44610e90bdb987

