
Estimate the distribution of mutual information with Gaussian mixture models


GMM-MI


Welcome to GMM-MI (pronounced Jimmie)! This package allows you to calculate mutual information (MI) and its associated uncertainty by combining Gaussian mixture models (GMMs) with bootstrapping. GMM-MI is accurate, computationally efficient and written entirely in Python; you can read more about it in our paper. Please cite it if you use it in your work!

Installation

To install GMM-MI, follow these steps:

  1. (optional) conda create -n gmm_mi python=3.9 jupyter (we recommend creating a custom conda environment)

  2. (optional) conda activate gmm_mi (activate it)

  3. Install GMM-MI:

     pip install gmm-mi
     python3 -c 'from gmm_mi.mi import EstimateMI'
    

    or, alternatively, clone the repository, install it and run the test suite:

     git clone https://github.com/dpiras/GMM-MI.git
     cd GMM-MI
     pip install . 
     pytest 
    

The latter option will also give you access to Jupyter notebooks to get started with GMM-MI.

Usage

To use GMM-MI, you only need to import the EstimateMI class, choose the hyperparameters and fit your data. You can find an example application of GMM-MI in the next section, and a more complete walkthrough, covering common scenarios and possible pitfalls, in this notebook. A description of the hyperparameters that you can tune can be found here, and we discuss a few of them below.

Example

Once you have installed GMM-MI, calculating the distribution of mutual information on your data is as easy as:

import numpy as np
from gmm_mi.mi import EstimateMI

# create simple bivariate Gaussian data
mean, cov = np.array([0, 0]), np.array([[1, 0.6], [0.6, 1]])
rng = np.random.default_rng(0)
X = rng.multivariate_normal(mean, cov, 200) # has shape (200, 2)
# calculate MI
mi_estimator = EstimateMI()
MI_mean, MI_std = mi_estimator.fit(X)

This yields (0.21 ± 0.04) nat, in good agreement with the theoretical value of 0.22 nat. To visualize the fitted model over your input data, you can run:

import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1, figsize=(11, 11))
# X is the array with the input data
ax.scatter(X[:, 0], X[:, 1], label='Input data')
fig, ax = mi_estimator.plot_fitted_model(ax=ax)
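The theoretical value of 0.22 nat quoted above can be checked by hand: for a bivariate Gaussian with correlation coefficient ρ, the MI has the closed form I = -½ ln(1 - ρ²). The covariance used in the example has unit variances and ρ = 0.6, giving -½ ln(0.64) ≈ 0.223 nat. A minimal NumPy sketch of this check follows; the helper `gaussian_mi` is hypothetical and not part of GMM-MI.

```python
import numpy as np


def gaussian_mi(cov):
    """Closed-form MI (in nats) of a bivariate Gaussian with 2x2 covariance `cov`."""
    cov = np.asarray(cov, dtype=float)
    # correlation coefficient: rho = cov_xy / (sigma_x * sigma_y)
    rho = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
    return float(-0.5 * np.log(1.0 - rho**2))


cov = np.array([[1, 0.6], [0.6, 1]])
print(f"{gaussian_mi(cov):.4f} nat")  # 0.2231 nat
```

This closed form only holds for jointly Gaussian data, which is what makes it useful as a sanity check for the estimator.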

To set the GMM-MI hyperparameters, we provide three classes: GMMFitParamHolder, SelectComponentsParamHolder and MIDistParamHolder. For example:

from gmm_mi.param_holders import GMMFitParamHolder, SelectComponentsParamHolder, MIDistParamHolder

# parameters for every GMM fit that is being run
gmm_fit_params = GMMFitParamHolder(threshold_fit=1e-5, reg_covar=1e-15)
# parameters to choose the number of components
select_components_params = SelectComponentsParamHolder(n_inits=3, n_folds=2)
# parameters for MI distribution estimation
mi_dist_params = MIDistParamHolder(n_bootstrap=50, MC_samples=1e5)

mi_estimator = EstimateMI(gmm_fit_params=gmm_fit_params,
                          select_components_params=select_components_params,
                          mi_dist_params=mi_dist_params)
MI_mean, MI_std = mi_estimator.fit(X)

Since these are the default values, this is equivalent to the first example and yields (0.21 ± 0.04) nat. More example notebooks, including all results from the paper, are available in notebooks.

Hyperparameter description

Here we list the most important hyperparameters used in GMM-MI.

(controlled by GMMFitParamHolder, passed as gmm_fit_params)
threshold_fit : float, default=1e-5
    The log-likelihood threshold on each GMM fit used to choose when to stop training. Smaller
    values will improve the fit quality and reduce the chances of stopping at a local optimum,
    while making the code considerably slower. This is equivalent to `tol` in sklearn GMMs.
    Note this parameter can be degenerate with `threshold_components`, and the two should be set
    together to reach a good density estimate of the data.
reg_covar : float, default=1e-15
    The constant term added to the diagonal of the covariance matrices to avoid singularities.
    Smaller values will increase the chances of singular matrices, but will have a smaller
    impact on the final MI estimates.
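Since `threshold_fit` is described above as equivalent to `tol` in scikit-learn GMMs, and `reg_covar` is described in the same terms as scikit-learn's parameter of the same name, a single fit with these settings can be sketched with `sklearn.mixture.GaussianMixture`. This only illustrates what the two parameters control; it is not GMM-MI's internal fitting code, and the number of components and `max_iter` below are arbitrary assumptions.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# bivariate Gaussian data, as in the example above
rng = np.random.default_rng(0)
X = rng.multivariate_normal([0, 0], [[1, 0.6], [0.6, 1]], 200)

# tol plays the role of threshold_fit: EM stops once the gain in the average
# log-likelihood lower bound falls below tol. reg_covar is added to the
# diagonal of each covariance matrix to keep it non-singular.
# n_components=2 and max_iter=10_000 are arbitrary choices for this sketch.
gmm = GaussianMixture(n_components=2, tol=1e-5, reg_covar=1e-15,
                      max_iter=10_000, random_state=0).fit(X)
print("converged:", gmm.converged_, "after", gmm.n_iter_, "iterations")
print("mean log-likelihood per sample:", gmm.score(X))
```

Lowering `tol` typically increases `n_iter_`, which is the slowdown mentioned above for small `threshold_fit` values.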

(controlled by SelectComponentsParamHolder, passed as select_components_params)
n_inits : int, default=3
    Number of initializations used to find the best initialization parameters. Higher
    values will decrease the chances of stopping at a local optimum, while making the
    code slower.
n_folds : int, default=2
    Number of folds in the cross-validation (CV) performed to find the best initialization
    parameters. As in every CV procedure, there is no best value. A good value, though,
    should ensure each fold has enough samples to be representative of your training set.
threshold_components : float, default=1e-5
    The metric threshold to decide when to stop adding GMM components. In other words, GMM-MI
    stops adding components either when the metric gets worse, or when the improvement in the
    metric value is less than this threshold. Smaller values ensure that enough components are
    considered and that the data distribution is correctly captured, while taking longer to converge.
    Note this parameter can be degenerate with `threshold_fit`, and the two should be set
    together to reach a good density estimate of the data.
patience : int, default=1
    Number of extra components to "wait" before convergence is declared. Must be at least 1.
    Same concept as patience when training a neural network. Higher values will fit models
    with more GMM components, while taking longer to converge.
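To make the interplay between `n_folds`, `n_inits`, `threshold_components` and `patience` more concrete, here is a simplified sketch of a stopping rule of this kind, written with scikit-learn: GMMs with 1, 2, ... components are scored by their mean validation log-likelihood over `n_folds` folds, and the search stops once `patience` consecutive components fail to improve the best score by more than the threshold. The functions `cv_score` and `select_n_components` are hypothetical, and GMM-MI's actual procedure (for instance, how it uses `n_inits` to pick the initialization) may differ.

```python
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import KFold


def cv_score(X, n_components, n_folds, n_inits, seed=0):
    """Mean validation log-likelihood per sample, averaged over n_folds folds."""
    folds = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    scores = []
    for train_idx, val_idx in folds.split(X):
        # n_init runs EM from several initializations and keeps the best fit
        gmm = GaussianMixture(n_components=n_components, n_init=n_inits,
                              random_state=seed).fit(X[train_idx])
        scores.append(gmm.score(X[val_idx]))
    return float(np.mean(scores))


def select_n_components(X, n_folds=2, n_inits=3, threshold=1e-5, patience=1,
                        max_components=10):
    """Hypothetical rule: stop after `patience` components without enough improvement."""
    best_k, best_score = 1, cv_score(X, 1, n_folds, n_inits)
    waited = 0
    for k in range(2, max_components + 1):
        score = cv_score(X, k, n_folds, n_inits)
        if score - best_score > threshold:
            # improvement above the threshold: accept k and reset the counter
            best_k, best_score, waited = k, score, 0
        else:
            # metric got worse, or improved by less than the threshold
            waited += 1
            if waited >= patience:
                break
    return best_k


rng = np.random.default_rng(0)
X = rng.multivariate_normal([0, 0], [[1, 0.6], [0.6, 1]], 200)
print(select_n_components(X))
```

A larger `patience` lets the loop look past a non-improving component before stopping, at the cost of extra fits, mirroring the trade-off described above.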

(controlled by MIDistParamHolder, passed as mi_dist_params) 
n_bootstrap : int, default=50 
    Number of bootstrap realisations to consider to obtain the MI uncertainty.
    Higher values will return a better estimate of the MI uncertainty, and
    will make the MI distribution more Gaussian-like, but will take longer.
    If less than 1, no bootstrap is performed and a single fit is run on the entire
    dataset; there is no MI uncertainty in this case.
MC_samples : int, default=1e5
    Number of Monte Carlo (MC) samples used to estimate the MI integral. Only used if MI_method == 'MC'.
    Higher values will return less noisy MI estimates, but will take longer.
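The roles of `n_bootstrap` and `MC_samples` can be illustrated with a minimal NumPy/SciPy/scikit-learn sketch: MI is estimated as the MC average of log p(x, y) - log p(x) - log p(y) over samples drawn from a fitted GMM (whose marginals are themselves 1D GMMs with the same weights), and the uncertainty comes from refitting on resampled datasets. The helpers `gmm_mi_mc` and `bootstrap_mi` are hypothetical, the number of components is fixed for simplicity (an assumption; GMM-MI selects it as described above), and this is not GMM-MI's own implementation.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm
from sklearn.mixture import GaussianMixture


def gmm_mi_mc(gmm, mc_samples=100_000):
    """MC estimate (nats) of the MI between the two dimensions of a 2D GMM."""
    samples, _ = gmm.sample(mc_samples)
    log_joint = gmm.score_samples(samples)  # log p(x, y)
    log_marginals = []
    for d in range(2):
        # the marginal of a GMM along dimension d is a 1D GMM with the same weights
        means = gmm.means_[:, d]
        stds = np.sqrt(gmm.covariances_[:, d, d])  # assumes covariance_type='full'
        comp_logpdf = norm.logpdf(samples[:, [d]], loc=means, scale=stds)  # (N, K)
        log_marginals.append(logsumexp(comp_logpdf + np.log(gmm.weights_), axis=1))
    # MI = E[log p(x, y) - log p(x) - log p(y)], averaged over the MC samples
    return float(np.mean(log_joint - log_marginals[0] - log_marginals[1]))


def bootstrap_mi(X, n_bootstrap=50, n_components=1, mc_samples=100_000, seed=0):
    """Hypothetical bootstrap loop returning the mean and std of the MI estimates."""
    if n_bootstrap < 1:
        # single fit on the full dataset, hence no uncertainty
        gmm = GaussianMixture(n_components=n_components, random_state=seed).fit(X)
        return gmm_mi_mc(gmm, mc_samples), None
    rng = np.random.default_rng(seed)
    mis = []
    for b in range(n_bootstrap):
        idx = rng.integers(0, len(X), size=len(X))  # resample rows with replacement
        gmm = GaussianMixture(n_components=n_components, random_state=b).fit(X[idx])
        mis.append(gmm_mi_mc(gmm, mc_samples))
    return float(np.mean(mis)), float(np.std(mis))


rng = np.random.default_rng(0)
X = rng.multivariate_normal([0, 0], [[1, 0.6], [0.6, 1]], 200)
mi_mean, mi_std = bootstrap_mi(X, n_bootstrap=20, mc_samples=20_000)
print(f"MI = ({mi_mean:.2f} ± {mi_std:.2f}) nat")
```

Increasing `mc_samples` reduces the MC noise within each fit, while increasing `n_bootstrap` gives a better-sampled spread of MI values; both add run time, as noted in the parameter descriptions.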

Contributing and contacts

Feel free to fork this repository to work on it; otherwise, please raise an issue or contact Davide Piras.

Citation

If you use GMM-MI, please cite the corresponding paper:

@article{TBC,
    author  = {TBC},
    title   = {TBC},
    journal = {TBC},
    eprint  = {TBC},
    year    = {TBC}
}

License

GMM-MI is released under the GPL-3 license (see LICENSE), subject to a non-commercial use condition (see LICENSE_EXT).

 GMM-MI
 Copyright (C) 2022 Davide Piras & contributors

 This program is released under the GPL-3 license (see LICENSE.txt), 
 subject to a non-commercial use condition (see LICENSE_EXT.txt).

 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.



Download files


Source Distribution

gmm_mi-0.1.2.tar.gz (45.3 kB)


Built Distribution


gmm_mi-0.1.2-py3-none-any.whl (47.1 kB)


File details

Details for the file gmm_mi-0.1.2.tar.gz.

File metadata

  • Download URL: gmm_mi-0.1.2.tar.gz
  • Size: 45.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for gmm_mi-0.1.2.tar.gz
Algorithm Hash digest
SHA256 1a98dc1564d33726da472d6799568f7485502fb2cf4c42d31e246e8df8a12486
MD5 7101356ebd70aab6f0a6fd72c0eebe30
BLAKE2b-256 82ee1c3caefca19569c90ad3e71a9cb03415d44283780981e816dc9632cc6433


File details

Details for the file gmm_mi-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: gmm_mi-0.1.2-py3-none-any.whl
  • Size: 47.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.13

File hashes

Hashes for gmm_mi-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a2a58929d174a805731ef96d1ecd99200ae33f25ec6cf61c32a930f8845ed203
MD5 f36942c9315ac4618045260a0c3f1aee
BLAKE2b-256 85578a3a33b322921343b136b562c2cff0dd63ccb84d3bb3375ac299abe4a1a8

