Mist - A PyTorch Mutual Information Estimation toolkit
A Mutual Information Estimation toolkit based on PyTorch.
Installation
The package can be installed via pip as follows:
```
$ pip install torch_mist
```
Usage
The `torch_mist` package provides the basic functionality for sample-based continuous mutual information estimation using modern neural network architectures.
Here we provide a simple example of how to use the package to estimate mutual information between pairs of observations using the MINE estimator [2].
First, we need to import and instantiate the estimator from the package:
```python
from torch_mist.estimators import mine

# Defining the estimator
estimator = mine(
    x_dim=1,                   # dimension of x
    y_dim=1,                   # dimension of y
    hidden_dims=[32, 64, 32],  # hidden dimensions of the neural networks
)
```
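The training utility below expects a dataloader that yields batches of paired x and y observations. As a minimal sketch (not part of `torch_mist`), such a dataloader can be built from synthetic, jointly Gaussian samples whose ground-truth mutual information is known in closed form:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data: x and y are jointly Gaussian with correlation rho, so the
# true mutual information is -0.5 * log(1 - rho^2) nats.
rho = 0.9
n_samples = 100_000
x = torch.randn(n_samples, 1)
y = rho * x + (1 - rho**2) ** 0.5 * torch.randn(n_samples, 1)

# The dataloader yields (x, y) batches of paired observations.
dataloader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True)
```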
Then we can train the estimator:
```python
from torch_mist.utils import optimize_mi_estimator

train_log = optimize_mi_estimator(
    estimator=estimator,    # the estimator to train
    dataloader=dataloader,  # the dataloader returning pairs of x and y
    epochs=10,              # the number of epochs
    device="cpu",           # the device to use
    return_log=True,        # whether to return the training log
)
```
Lastly, we can use the trained estimator to estimate the mutual information between pairs of observations:
```python
from torch_mist.utils import estimate_mi

value, std = estimate_mi(
    estimator=estimator,    # the estimator to use
    dataloader=dataloader,  # the dataloader returning pairs of x and y
    device="cpu",           # the device to use
)

print(f"Estimated MI: {value} +- {std}")
```
Please refer to the documentation for a detailed description of the package and its usage.
Estimators
The basic estimators implemented in this package are summarized in the following table:
| Estimator | Type | Models |
|---|---|---|
| NWJ [1] | Discriminative | $f_\phi(x,y)$ |
| MINE [2] | Discriminative | $f_\phi(x,y)$ |
| InfoNCE [3] | Discriminative | $f_\phi(x,y)$ |
| TUBA [4] | Discriminative | $f_\phi(x,y)$, $b_\xi(x)$ |
| AlphaTUBA [4] | Discriminative | $f_\phi(x,y)$, $b_\xi(x)$ |
| JS [5] | Discriminative | $f_\phi(x,y)$ |
| SMILE [6] | Discriminative | $f_\phi(x,y)$ |
| FLO [7] | Discriminative | $f_\phi(x,y)$, $b_\xi(x,y)$ |
| BA [8] | Generative | $q_\theta(y|x)$ |
| DoE [9] | Generative | $q_\theta(y|x)$, $q_\psi(y)$ |
| GM [6] | Generative | $q_\theta(x,y)$, $q_\psi(x)$, $r_\psi(y)$ |
| L1OUT [4] [10] | Generative | $q_\theta(y|x)$ |
| CLUB [10] | Generative | $q_\theta(y|x)$ |
| Discrete | Generative (Discrete) | $Q(x)$, $Q(y)$ |
| PQ [11] | Generative (Discrete) | $Q(y)$, $q_\theta(Q(y)|x)$ |
in which:

- $f_\phi(x,y)$ is a `critic` neural network with parameters $\phi$, which maps pairs of observations to a scalar value. Critics can be either `joint` or `separable`, depending on whether they parametrize a function of both $x$ and $y$ directly or through the product of separate projection heads ( $f_\phi(x,y)=h_\phi(x)^T h_\phi(y)$ ), respectively (see the sketch after this list).
- $b_\xi(x)$ is a `baseline` neural network with parameters $\xi$, which maps observations (or pairs of observations) to a scalar value. When the baseline is a function of both $x$ and $y$ it is referred to as a `joint_baseline`.
- $q_\theta(y|x)$ is a conditional variational distribution `q_Y_given_X` with learnable parameters $\theta$, used to approximate $p(y|x)$. Conditional distributions are usually parametrized by a (conditional) normalizing flow.
- $q_\psi(y)$ is a marginal variational distribution `q_Y` with learnable parameters $\psi$, used to approximate $p(y)$. Marginal distributions are usually parametrized by a normalizing flow.
- $q_\theta(x,y)$ is a joint variational distribution `q_XY` with learnable parameters $\theta$, used to approximate $p(x,y)$. Joint distributions are usually parametrized by a normalizing flow.
- $Q(x)$ and $Q(y)$ are `quantization` functions that map observations to a finite set of discrete values.
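As an illustration of the separable critic described above, the projection heads can be simple MLPs whose outputs are combined with an inner product. This is only a sketch of the general idea, not the implementation used inside `torch_mist`:

```python
import torch
import torch.nn as nn

class SeparableCritic(nn.Module):
    """Sketch of a separable critic f(x, y) = h(x)^T h(y) with MLP projection heads."""

    def __init__(self, x_dim: int, y_dim: int, hidden_dim: int = 32, embed_dim: int = 16):
        super().__init__()
        self.h_x = nn.Sequential(
            nn.Linear(x_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embed_dim)
        )
        self.h_y = nn.Sequential(
            nn.Linear(y_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Inner product of the two projections yields a scalar score per (x, y) pair.
        return (self.h_x(x) * self.h_y(y)).sum(dim=-1)
```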
Hybrid estimators
The `torch_mist` package allows combining Generative and Discriminative estimators into a single hybrid estimator, as proposed in [11] and [12].
References
[1] Nguyen, XuanLong, Martin J. Wainwright, and Michael I. Jordan. "Estimating divergence functionals and the likelihood ratio by convex risk minimization." IEEE Transactions on Information Theory 56.11 (2010): 5847-5861.
[2] Belghazi, Mohamed Ishmael, et al. "Mutual information neural estimation." International Conference on Machine Learning. PMLR, 2018.
[3] Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. "Representation learning with contrastive predictive coding." arXiv preprint arXiv:1807.03748 (2018).
[4] Poole, Ben, et al. "On variational bounds of mutual information." International Conference on Machine Learning. PMLR, 2019.
[5] Hjelm, R. Devon, et al. "Learning deep representations by mutual information estimation and maximization." arXiv preprint arXiv:1808.06670 (2018).
[6] Song, Jiaming, and Stefano Ermon. "Understanding the limitations of variational mutual information estimators." arXiv preprint arXiv:1910.06222 (2019).
[7] Guo, Qing, et al. "Tight mutual information estimation with contrastive Fenchel-Legendre optimization." Advances in Neural Information Processing Systems 35 (2022): 28319-28334.
[8] Barber, David, and Felix Agakov. "The IM algorithm: a variational approach to information maximization." Advances in Neural Information Processing Systems 16.320 (2004): 201.
[9] McAllester, David, and Karl Stratos. "Formal limitations on the measurement of mutual information." International Conference on Artificial Intelligence and Statistics. PMLR, 2020.
[10] Cheng, Pengyu, et al. "CLUB: A contrastive log-ratio upper bound of mutual information." International Conference on Machine Learning. PMLR, 2020.
[11] Federici, Marco, David Ruhe, and Patrick Forré. "On the Effectiveness of Hybrid Mutual Information Estimation." arXiv preprint arXiv:2306.00608 (2023).
[12] Brekelmans, Rob, et al. "Improving mutual information estimation with annealed and energy-based bounds." arXiv preprint arXiv:2303.06992 (2023).
Contributing
Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.
License
`torch_mist` was created by Marco Federici. It is licensed under the terms of the MIT license.
Credits
`torch_mist` was created with cookiecutter and the py-pkgs-cookiecutter template.