Estimate mutual information distribution with Gaussian mixture models

GMM-MI

Welcome to GMM-MI (pronounced Jimmie)! This package allows you to calculate mutual information (MI) with its associated uncertainty, combining Gaussian mixture models (GMMs) and bootstrap. GMM-MI is accurate, computationally efficient and written entirely in Python; you can read more about GMM-MI in our paper, published in Machine Learning: Science and Technology. Please cite it if you use it in your work! Also check out the poster accepted at the Machine Learning and the Physical Sciences workshop at NeurIPS 2022, and the accompanying video.

Installation

To install GMM-MI, follow these steps:

  1. (optional) conda create -n gmm_mi python=3.9 jupyter (create a custom conda environment with python 3.9)

  2. (optional) conda activate gmm_mi (activate it)

  3. Install GMM-MI:

     pip install gmm-mi
     python3 -c 'from gmm_mi.mi import EstimateMI'
    

    or alternatively, clone the repository and install it:

     git clone https://github.com/dpiras/GMM-MI.git
     cd GMM-MI
     pip install . 
     pytest 
    

The latter option also gives you access to Jupyter notebooks to get started with GMM-MI. Note that all experiments were run with python==3.9, and that GMM-MI requires python>=3.8.

Usage

To use GMM-MI, you simply need to import the class EstimateMI, choose the hyperparameters and fit your data. You can find an example application of GMM-MI in the next section, and a more complete walkthrough, with common scenarios and possible pitfalls, in the walkthrough notebook. The most important hyperparameters you can tune are described in the Hyperparameter description section below.

Example

Once you have installed GMM-MI, calculating the distribution of mutual information on your data is as easy as:

import numpy as np
from gmm_mi.mi import EstimateMI

# create simple bivariate Gaussian data
mean, cov = np.array([0, 0]), np.array([[1, 0.6], [0.6, 1]])
rng = np.random.default_rng(0)
X = rng.multivariate_normal(mean, cov, 200) # has shape (200, 2)
# calculate MI
mi_estimator = EstimateMI()
MI_mean, MI_std = mi_estimator.fit_estimate(X)
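For a bivariate Gaussian with correlation coefficient ρ, the MI has the closed form I = -½ ln(1 - ρ²). The covariance above corresponds to ρ = 0.6, so the theoretical value can be checked in plain Python (standard library only; `gaussian_mi` is a hypothetical helper, not part of GMM-MI):

```python
import math


def gaussian_mi(rho: float) -> float:
    """Exact mutual information (in nats) of a bivariate Gaussian with correlation rho."""
    if not -1.0 < rho < 1.0:
        raise ValueError("rho must lie strictly between -1 and 1")
    return -0.5 * math.log(1.0 - rho**2)


# correlation implied by cov = [[1, 0.6], [0.6, 1]]
print(f"{gaussian_mi(0.6):.4f} nat")  # 0.2231 nat
```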

This yields (0.21 ± 0.04) nat, in good agreement with the theoretical value of 0.22 nat. GMM-MI can do more: for example, you can pass two 1D arrays instead of a single 2D array, and even calculate the KL divergence between the marginals (as shown in the walkthrough notebook). To visualize the fitted model over your input data, you can run:

import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# X is the array with the input data
ax.scatter(X[:, 0], X[:, 1], label='Input data')
# the extra arguments can be changed
ax = mi_estimator.plot_fitted_model(ax=ax, color='salmon', alpha=0.8, linewidth=4)
ax.tick_params(axis='both', which='both', labelsize=20)
ax.set_xlabel('X1', fontsize=30)
ax.set_ylabel('X2', fontsize=30)
ax.legend(fontsize=25, frameon=False)    

You can also draw contour plots of the input data and of samples drawn from the fitted model. Note that the smoothness of the contours depends heavily on the number of available samples. For example:

fig = mi_estimator.plot_fitted_contours(parameters=['X1', 'X2'], 
                                    shade_alpha=0.4, linewidths=2, 
                                    legend_kwargs = {'loc': 'lower right'},
                                    kde=True, # smooths contours; set this to False to accelerate plotting
                                    )
fig.set_size_inches(7, 7)

To choose the GMM-MI hyperparameters, we provide three classes: GMMFitParamHolder, SelectComponentsParamHolder, and MIDistParamHolder. An example is as follows:

from gmm_mi.param_holders import GMMFitParamHolder, SelectComponentsParamHolder, MIDistParamHolder

# parameters for every GMM fit that is being run
gmm_fit_params = GMMFitParamHolder(threshold_fit=1e-5, reg_covar=1e-15)
# parameters to choose the number of components
select_components_params = SelectComponentsParamHolder(n_inits=3, n_folds=2)
# parameters for MI distribution estimation
mi_dist_params = MIDistParamHolder(n_bootstrap=50, MC_samples=1e5)

mi_estimator = EstimateMI(gmm_fit_params=gmm_fit_params,
                          select_components_params=select_components_params)
mi_estimator.fit(X)
MI_mean, MI_std = mi_estimator.estimate(mi_dist_params=mi_dist_params)

This is equivalent to the first example, and yields (0.21 ± 0.04) nat. More example notebooks, including conditional mutual information and all results from the paper, are available in the notebooks of the GitHub repository.

Hyperparameter description

Here we report the most important hyperparameters that are used in GMM-MI.

(controlled by GMMFitParamHolder, passed as gmm_fit_params)
threshold_fit : float, default=1e-5
    The log-likelihood threshold on each GMM fit used to choose when to stop training. Smaller
    values will improve the fit quality and reduce the chances of stopping at a local optimum,
    while making the code considerably slower. This is equivalent to `tol` in sklearn GMMs.
    Note this parameter can be degenerate with `threshold_components`, and the two should be set
    together to reach a good density estimate of the data.
reg_covar : float, default=1e-15
    The constant term added to the diagonal of the covariance matrices to avoid singularities.
    Smaller values will increase the chances of singular matrices, but will have a smaller
    impact on the final MI estimates.
init_type : {'random', 'minmax', 'kmeans', 'randomized_kmeans', 'random_sklearn', 
             'kmeans_sklearn'}, default='random_sklearn'
    The method used to initialize the weights, the means, the covariances and the precisions
    in each fit during cross-validation. See utils.initializations for more details.
scale : float, default=None
    The scale used for the 'random', 'minmax' and 'randomized_kmeans' initializations.
    It is ignored by all other initializations; if you roughly know the scale of your
    data in advance, setting it can accelerate convergence.
    See utils.initializations for more details.
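To make `threshold_fit` and `reg_covar` concrete, here is a minimal sketch of an expectation-maximization (EM) loop for a one-dimensional Gaussian mixture: training stops once the mean log-likelihood improves by less than `threshold_fit`, and `reg_covar` is added to every variance. It rests on simplifying assumptions (1D data, means initialized at evenly spaced points between the data minimum and maximum, NumPy only); `fit_gmm_1d` is a hypothetical helper for illustration, not GMM-MI's own implementation.

```python
import numpy as np


def fit_gmm_1d(x, n_components=2, threshold_fit=1e-5, reg_covar=1e-15, max_iter=1000):
    """Minimal 1D EM for a Gaussian mixture (illustrative sketch only)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    # simple initialization (an assumption): evenly spaced means, data variance, equal weights
    weights = np.full(n_components, 1.0 / n_components)
    means = np.linspace(x.min(), x.max(), n_components)
    variances = np.full(n_components, x.var() + reg_covar)
    prev_ll = -np.inf
    for it in range(max_iter):
        # E-step: log(weight_k * N(x | mean_k, var_k)) for every sample and component
        log_p = (np.log(weights)
                 - 0.5 * np.log(2 * np.pi * variances)
                 - 0.5 * (x[:, None] - means) ** 2 / variances)
        log_norm = np.logaddexp.reduce(log_p, axis=1)
        resp = np.exp(log_p - log_norm[:, None])
        ll = log_norm.mean()  # mean log-likelihood per sample
        # stop when the improvement falls below the threshold (analogous to sklearn's tol)
        if ll - prev_ll < threshold_fit:
            break
        prev_ll = ll
        # M-step: update weights, means and regularized variances
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp * x[:, None]).sum(axis=0) / nk
        variances = (resp * (x[:, None] - means) ** 2).sum(axis=0) / nk + reg_covar
    return weights, means, variances, ll, it


# usage on two well-separated clusters centred at -3 and +3
rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(-3, 1, 500), rng.normal(3, 1, 500)])
w, mu, var, ll, n_iter = fit_gmm_1d(x, threshold_fit=1e-5)
print(np.sort(mu), n_iter)
```

Lowering `threshold_fit` in this sketch makes the loop run more iterations before stopping, which mirrors the speed/quality trade-off described above.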

(controlled by SelectComponentsParamHolder, passed as select_components_params)
n_inits : int, default=3
    Number of initializations used to find the best initialization parameters. Higher
    values will decrease the chances of stopping at a local optimum, while making the
    code slower.
n_folds : int, default=2
    Number of folds in the cross-validation (CV) performed to find the best initialization
    parameters. As in every CV procedure, there is no best value. A good value, though,
    should ensure each fold has enough samples to be representative of your training set.
threshold_components : float, default=1e-5
    The metric threshold to decide when to stop adding GMM components. In other words, GMM-MI
    stops adding components either when the metric gets worse, or when the improvement in the
    metric value is less than this threshold. Smaller values ensure that enough components are
    considered and that the data distribution is correctly captured, while taking longer to converge.
    Note this parameter can be degenerate with `threshold_fit`, and the two should be set 
    together to reach a good density estimate of the data.
patience : int, default=1
    Number of extra components to "wait" until convergence is declared. Must be at least 1.
    Same concept as patience when training a neural network. Higher values will fit models
    with higher numbers of GMM components, while taking longer to converge.
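The interplay between `threshold_components` and `patience` can be illustrated with a small stopping rule applied to a sequence of validation scores (higher is better), where entry i stands for a model with i + 1 components, e.g. averaged over the `n_folds` cross-validation folds. This is a sketch of the rule as described above, not GMM-MI's code: `select_n_components` is a hypothetical helper, the scores are made-up numbers, and the library may differ in details such as which model it keeps once patience runs out.

```python
def select_n_components(scores, threshold_components=1e-5, patience=1):
    """Choose the number of GMM components from validation scores (higher is better).

    scores[i] is the (hypothetical) mean validation log-likelihood of a GMM
    with i + 1 components.
    """
    if patience < 1:
        raise ValueError("patience must be at least 1")
    best_k, best_score = 1, scores[0]
    waited = 0
    for k, score in enumerate(scores[1:], start=2):
        if score - best_score > threshold_components:
            # meaningful improvement: accept k components and reset the counter
            best_k, best_score = k, score
            waited = 0
        else:
            # the metric got worse, or improved by less than the threshold
            waited += 1
            if waited >= patience:
                break
    return best_k


scores = [-2.90, -2.70, -2.65, -2.66, -2.60]  # made-up values for 1..5 components
print(select_n_components(scores, patience=1))  # stops after the dip at 4 components
print(select_n_components(scores, patience=2))  # waits one more component and finds 5
```

With `patience=1` the dip at four components ends the search, while `patience=2` tolerates it and picks up the later improvement, at the cost of fitting more models.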

(controlled by MIDistParamHolder, passed as mi_dist_params to the `estimate` method) 
n_bootstrap : int, default=50
    Number of bootstrap realisations used to obtain the MI uncertainty.
    Higher values will return a better estimate of the MI uncertainty, and
    will make the MI distribution more Gaussian-like, but will take longer.
    If less than 1, no bootstrap is performed and a single fit is run on the
    entire dataset; in this case, no MI uncertainty is returned.
MC_samples : int, default=1e5
    Number of MC samples to use to estimate the MI integral. Only used if MI_method == 'MC'.
    Higher values will return less noisy estimates of MI, but will take longer.
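The roles of `n_bootstrap` and `MC_samples` can be sketched as follows: for each bootstrap realisation the data are resampled with replacement and a density model is refitted, then the MI integral E[log p(x, y) - log p(x) - log p(y)] is estimated by Monte Carlo with draws from that model; the mean and standard deviation over realisations give the MI estimate and its uncertainty. To keep the sketch short and self-contained, a single bivariate Gaussian stands in for the GMM (an assumption; GMM-MI fits a mixture), and `bootstrap_mi` and `mc_mi_gaussian` are hypothetical helpers, not part of the package.

```python
import numpy as np


def log_normal_1d(x, mu, var):
    """Log density of a 1D Gaussian."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)


def mc_mi_gaussian(mean, cov, mc_samples, rng):
    """Monte Carlo estimate of MI (nats) for a bivariate Gaussian model."""
    samples = rng.multivariate_normal(mean, cov, int(mc_samples))
    diff = samples - mean
    inv = np.linalg.inv(cov)
    log_joint = -0.5 * (np.log((2 * np.pi) ** 2 * np.linalg.det(cov))
                        + np.einsum('ij,jk,ik->i', diff, inv, diff))
    log_x = log_normal_1d(samples[:, 0], mean[0], cov[0, 0])
    log_y = log_normal_1d(samples[:, 1], mean[1], cov[1, 1])
    # MI = E[log p(x, y) - log p(x) - log p(y)] under the fitted joint model
    return np.mean(log_joint - log_x - log_y)


def bootstrap_mi(X, n_bootstrap=50, mc_samples=1e5, seed=0):
    """Bootstrap mean and std of MI; a Gaussian fit stands in for the GMM."""
    rng = np.random.default_rng(seed)
    if n_bootstrap < 1:
        # single fit on the entire dataset: no uncertainty available
        mi = mc_mi_gaussian(X.mean(axis=0), np.cov(X, rowvar=False), mc_samples, rng)
        return mi, None
    estimates = []
    for _ in range(n_bootstrap):
        # resample the data with replacement and refit the density model
        Xb = X[rng.integers(0, len(X), size=len(X))]
        mean, cov = Xb.mean(axis=0), np.cov(Xb, rowvar=False)
        estimates.append(mc_mi_gaussian(mean, cov, mc_samples, rng))
    return np.mean(estimates), np.std(estimates)


# same toy data as in the Example section
rng = np.random.default_rng(0)
X = rng.multivariate_normal([0, 0], [[1, 0.6], [0.6, 1]], 200)
mi_mean, mi_std = bootstrap_mi(X, n_bootstrap=50, mc_samples=1e5)
print(f"MI = {mi_mean:.2f} ± {mi_std:.2f} nat")
```

In this sketch, increasing `mc_samples` reduces the Monte Carlo noise within each realisation, while increasing `n_bootstrap` stabilises the spread across realisations, mirroring the descriptions of the two hyperparameters.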

Contributing and contacts

Feel free to fork this repository to work on it; otherwise, please raise an issue or contact Davide Piras.

Citation

If you use GMM-MI, please cite the corresponding paper:

 @article{Piras23, 
      author = {Davide Piras and Hiranya V Peiris and Andrew Pontzen and 
                Luisa Lucie-Smith and Ningyuan Guo and Brian Nord},
      title = {A robust estimator of mutual information for deep learning interpretability},
      journal = {Machine Learning: Science and Technology},
      doi = {10.1088/2632-2153/acc444},
      url = {https://dx.doi.org/10.1088/2632-2153/acc444},
      year = {2023},
      month = {apr},
      publisher = {IOP Publishing},
      volume = {4},
      number = {2},
      pages = {025006}
}

License

GMM-MI is released under the GPL-3 license (see LICENSE.txt), subject to the non-commercial use condition (see LICENSE_EXT.txt).

 GMM-MI
 Copyright (C) 2022 Davide Piras & contributors

 This program is released under the GPL-3 license (see LICENSE.txt), 
 subject to a non-commercial use condition (see LICENSE_EXT.txt).

 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
