Mist - A PyTorch Mutual Information Estimation toolkit

A mutual information estimation toolkit based on PyTorch.

Installation

The package can be installed via pip as follows:

$ pip install torch_mist

Usage

The torch_mist package provides the basic functionality for sample-based estimation of mutual information between continuous variables using modern neural network architectures.

Here we provide a simple example of how to use the package to estimate mutual information between pairs of observations using the MINE estimator [2].

First, we need to import and instantiate the estimator from the package:

from torch_mist.estimators import mine

# Defining the estimator
estimator = mine(
    x_dim=1,                    # dimension of x
    y_dim=1,                    # dimension of y   
    hidden_dims=[32, 64, 32],   # hidden dimensions of the neural networks
)
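
The snippets below assume a dataloader that returns pairs of x and y. As a minimal sketch (the synthetic correlated Gaussian data is our own example, and we assume the utilities accept a standard PyTorch DataLoader yielding (x, y) batches, as the comments below suggest), such a dataloader could be built as follows:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data: x and y are jointly Gaussian with correlation rho,
# so the ground-truth mutual information is known in closed form.
rho = 0.8
n_samples = 100_000
x = torch.randn(n_samples, 1)
y = rho * x + (1 - rho ** 2) ** 0.5 * torch.randn(n_samples, 1)

# Dataloader returning batches of paired observations (x, y).
dataloader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True)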

Then we can train the estimator:

from torch_mist.utils import optimize_mi_estimator


train_log = optimize_mi_estimator(
    estimator=estimator,        # the estimator to train
    dataloader=dataloader,      # the dataloader returning pairs of x and y
    epochs=10,                  # the number of epochs
    device="cpu",               # the device to use
    return_log=True,            # whether to return the training log
)

Lastly, we can use the trained estimator to estimate the mutual information between pairs of observations:

from torch_mist.utils import estimate_mi
value, std = estimate_mi(
    estimator=estimator,        # the estimator to use
    dataloader=dataloader,      # the dataloader returning pairs of x and y
    device="cpu",               # the device to use
)

print(f"Estimated MI: {value} +- {std}")

Please refer to the documentation for a detailed description of the package and its usage.

Estimators

The basic estimators implemented in this package are summarized in the following table:

Estimator        Type                    Models
NWJ [1]          Discriminative          $f_\phi(x,y)$
MINE [2]         Discriminative          $f_\phi(x,y)$
InfoNCE [3]      Discriminative          $f_\phi(x,y)$
TUBA [4]         Discriminative          $f_\phi(x,y)$, $b_\xi(x)$
AlphaTUBA [4]    Discriminative          $f_\phi(x,y)$, $b_\xi(x)$
JS [5]           Discriminative          $f_\phi(x,y)$
SMILE [6]        Discriminative          $f_\phi(x,y)$
FLO [7]          Discriminative          $f_\phi(x,y)$, $b_\xi(x,y)$
BA [8]           Generative              $q_\theta(y|x)$
DoE [9]          Generative              $q_\theta(y|x)$, $q_\psi(y)$
GM [6]           Generative              $q_\theta(x,y)$, $q_\psi(x)$, $r_\psi(y)$
L1OUT [4][10]    Generative              $q_\theta(y|x)$
CLUB [10]        Generative              $q_\theta(y|x)$
Discrete         Generative (Discrete)   $Q(x)$, $Q(y)$
PQ [11]          Generative (Discrete)   $Q(y)$, $q_\theta(Q(y)|x)$

in which:

  • $f_\phi(x,y)$ is a critic neural network with parameters $\phi$ that maps pairs of observations to a scalar value. Critics can be either joint or separable, depending on whether they parametrize a function of both $x$ and $y$ directly or as the dot product of separate projection heads, $f_\phi(x,y)=h_\phi(x)^\top g_\phi(y)$ (a minimal sketch of a separable critic follows this list).
  • $b_\xi(x)$ is a baseline neural network with parameters $\xi$ that maps observations (or pairs of observations) to a scalar value. When the baseline is a function of both $x$ and $y$, it is referred to as a joint_baseline.
  • $q_\theta(y|x)$ is a conditional variational distribution (q_Y_given_X) with learnable parameters $\theta$, used to approximate $p(y|x)$; it is usually parametrized by a conditional normalizing flow.
  • $q_\psi(y)$ is a marginal variational distribution (q_Y) with learnable parameters $\psi$, used to approximate $p(y)$; it is usually parametrized by a normalizing flow.
  • $q_\theta(x,y)$ is a joint variational distribution (q_XY) with learnable parameters $\theta$, used to approximate $p(x,y)$; it is usually parametrized by a normalizing flow.
  • $Q(x)$ and $Q(y)$ are quantization functions that map observations to a finite set of discrete values.
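
As an illustration of the separable critic mentioned above, here is a minimal PyTorch sketch (the MLP architecture and dimensions are hypothetical choices of ours, not the modules shipped with torch_mist):

import torch
import torch.nn as nn

class SeparableCritic(nn.Module):
    """Separable critic f(x, y) = h(x)^T g(y) built from two small projection heads."""

    def __init__(self, x_dim: int, y_dim: int, hidden_dim: int = 32, embed_dim: int = 16):
        super().__init__()
        self.h = nn.Sequential(nn.Linear(x_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embed_dim))
        self.g = nn.Sequential(nn.Linear(y_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embed_dim))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # The dot product of the two projections yields a scalar score per (x, y) pair.
        return (self.h(x) * self.g(y)).sum(dim=-1)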

Hybrid estimators

The torch_mist package allows combining generative and discriminative estimators into a single hybrid estimator, as proposed in [11][12].
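
At a high level (a sketch of the underlying idea, not necessarily the package's exact formulation), hybrid estimation exploits the decomposition

$$ I(x;y) = \mathbb{E}_{p(x,y)}\left[\log \frac{q_\theta(y|x)}{p(y)}\right] + \mathbb{E}_{p(x)}\left[\mathrm{KL}\big(p(y|x)\,\|\,q_\theta(y|x)\big)\right], $$

where the first term is estimated with a generative model $q_\theta(y|x)$ and the non-negative residual KL term is estimated discriminatively, so that the discriminative component only needs to model the part of the density ratio that the generative proposal misses.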

References

[1] Nguyen, XuanLong, Martin J. Wainwright, and Michael I. Jordan. "Estimating divergence functionals and the likelihood ratio by convex risk minimization." IEEE Transactions on Information Theory 56.11 (2010): 5847-5861.

[2] Belghazi, Mohamed Ishmael, et al. "Mutual information neural estimation." International Conference on Machine Learning. PMLR, 2018.

[3] Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. "Representation learning with contrastive predictive coding." arXiv preprint arXiv:1807.03748 (2018).

[4] Poole, Ben, et al. "On variational bounds of mutual information." International Conference on Machine Learning. PMLR, 2019.

[5] Hjelm, R. Devon, et al. "Learning deep representations by mutual information estimation and maximization." arXiv preprint arXiv:1808.06670 (2018).

[6] Song, Jiaming, and Stefano Ermon. "Understanding the limitations of variational mutual information estimators." arXiv preprint arXiv:1910.06222 (2019).

[7] Guo, Qing, et al. "Tight mutual information estimation with contrastive fenchel-legendre optimization." Advances in Neural Information Processing Systems 35 (2022): 28319-28334.

[8] Barber, David, and Felix Agakov. "The IM algorithm: a variational approach to information maximization." Advances in Neural Information Processing Systems 16.320 (2004): 201.

[9] McAllester, David, and Karl Stratos. "Formal limitations on the measurement of mutual information." International Conference on Artificial Intelligence and Statistics. PMLR, 2020.

[10] Cheng, Pengyu, et al. "CLUB: A contrastive log-ratio upper bound of mutual information." International Conference on Machine Learning. PMLR, 2020.

[11] Federici, Marco, David Ruhe, and Patrick Forré. "On the Effectiveness of Hybrid Mutual Information Estimation." arXiv preprint arXiv:2306.00608 (2023).

[12] Brekelmans, Rob, et al. "Improving mutual information estimation with annealed and energy-based bounds." arXiv preprint arXiv:2303.06992 (2023).

Contributing

Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.

License

torch_mist was created by Marco Federici. It is licensed under the terms of the MIT license.

Credits

torch_mist was created with cookiecutter and the py-pkgs-cookiecutter template.

