Mist - A PyTorch Mutual information Estimation toolkit
Mutual Information Estimation toolkit based on PyTorch. Please refer to the documentation for additional details regarding installation, usage, tutorials, and practical use-case examples.
Installation
The package can be installed via pip as follows:
$ pip install torch_mist
Usage
The torch_mist package provides basic functionality for sample-based continuous mutual information estimation using modern neural network architectures.
Here we provide a simple example of how to use the package to estimate mutual information between pairs of observations using the MINE estimator [2].
Consider the variables $x$ and $y$ of shape [N, x_dim] and [N, y_dim] respectively, sampled from some joint distribution $p(x,y)$.
Mutual information can be estimated directly using the estimate_mi utility function, which takes care of fitting the estimator's parameters and evaluating mutual information.
from torch_mist import estimate_mi

estimated_mi = estimate_mi(
    estimator_name='mine',  # Use MINE
    hidden_dims=[32, 32],   # Hidden dimensions of the neural network
    x=x,                    # The values for x
    y=y,                    # The values for y
)
print(f"Mutual information estimated value: {estimated_mi} nats")
Additional flags that can be used to customize the estimators, training and evaluation procedure are included in the documentation.
Alternatively, it is possible to manually instantiate, train and evaluate the mutual information estimators.
from torch_mist.estimators import mine
from torch_mist.utils.train import train_mi_estimator
from torch_mist.utils import evaluate_mi

# Instantiate the mutual information estimator
estimator = mine(
    x_dim=x.shape[-1],
    y_dim=y.shape[-1],
    hidden_dims=[32, 32],
)

# Train it on the given samples
train_log = train_mi_estimator(
    estimator=estimator,
    x=x,
    y=y,
    batch_size=64,
    verbose=True
)

# Evaluate the estimator on the entirety of the data
estimated_mi = evaluate_mi(
    estimator=estimator,
    x=x,
    y=y,
    batch_size=64
)
print(f"Mutual information estimated value: {estimated_mi} nats")
Note that the two code snippets above perform the same procedure. Please refer to the documentation for a detailed description of the package and its usage.
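To sanity-check either snippet, it can help to run it on synthetic data whose mutual information is known in closed form. The following is a minimal sketch using only the Python standard library; the helper names gaussian_mi and sample_correlated_gaussians are hypothetical and not part of torch_mist. It assumes each coordinate pair $(x_k, y_k)$ is bivariate normal with correlation $\rho$ and that pairs are independent across coordinates, in which case $I(x;y) = -\frac{d}{2}\log(1-\rho^2)$.

```python
import math
import random


def gaussian_mi(rho: float, dim: int) -> float:
    """Closed-form MI (in nats) for `dim` independent pairs (x_k, y_k),
    each bivariate normal with correlation `rho`."""
    return -0.5 * dim * math.log(1.0 - rho ** 2)


def sample_correlated_gaussians(n: int, dim: int, rho: float, seed: int = 0):
    """Draw n samples of (x, y); each sample is a list of `dim` floats."""
    rng = random.Random(seed)
    scale = math.sqrt(1.0 - rho ** 2)
    xs, ys = [], []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # y_k = rho * x_k + sqrt(1 - rho^2) * noise has unit variance
        # and correlation rho with x_k.
        y = [rho * xk + scale * rng.gauss(0.0, 1.0) for xk in x]
        xs.append(x)
        ys.append(y)
    return xs, ys


xs, ys = sample_correlated_gaussians(n=10000, dim=5, rho=0.8)
true_mi = gaussian_mi(rho=0.8, dim=5)
print(f"Ground-truth mutual information: {true_mi:.3f} nats")
```

The nested lists can be converted with torch.tensor(xs) and torch.tensor(ys) before being passed as x and y, and the estimate can then be compared against true_mi.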
Estimators
Each estimator implemented in the library is an instance of MutualInformationEstimator and can be instantiated through a simplified utility function:
############################
# Simplified instantiation #
############################
from torch_mist.estimators import mine

estimator = mine(
    x_dim=x.shape[-1],
    y_dim=y.shape[-1],
    neg_samples=16,
    hidden_dims=[32, 32],
    critic_type='joint'
)
or directly using the corresponding MutualInformationEstimator class:
##########################
# Advanced instantiation #
##########################
from torch_mist.estimators import MINE
from torch_mist.critic import JointCritic
from torch import nn

# First we define the critic architecture
critic = JointCritic(  # Wrapper to concatenate the inputs x and y
    joint_net=nn.Sequential(  # The neural network architecture that maps [x,y] to a scalar
        nn.Linear(x.shape[-1] + y.shape[-1], 32),
        nn.ReLU(True),
        nn.Linear(32, 32),
        nn.ReLU(True),
        nn.Linear(32, 1)
    )
)

# Then we pass it to the MINE constructor
estimator = MINE(
    critic=critic,
    neg_samples=16,
)
Note that the simplified and advanced instantiations in the examples above result in the same model.
The basic estimators implemented in this package are summarized in the following table:
Estimator | Type | Models | Hyperparameters |
---|---|---|---|
NWJ [1] | Discriminative | $f_\phi(x,y)$ | M |
MINE [2] | Discriminative | $f_\phi(x,y)$ | M, $\gamma_{EMA}$ |
InfoNCE [3] | Discriminative | $f_\phi(x,y)$ | M |
TUBA [4] | Discriminative | $f_\phi(x,y)$, $b_\xi(x)$ | M |
AlphaTUBA [4] | Discriminative | $f_\phi(x,y)$, $b_\xi(x)$ | M, $\alpha$ |
JS [5] | Discriminative | $f_\phi(x,y)$ | M |
SMILE [6] | Discriminative | $f_\phi(x,y)$ | M, $\tau$ |
FLO [7] | Discriminative | $f_\phi(x,y)$, $b_\xi(x,y)$ | M |
BA [8] | Generative | $q_\theta(y|x)$ | - |
DoE [9] | Generative | $q_\theta(y|x)$, $q_\psi(y)$ | - |
GM [6] | Generative | $q_\theta(x,y)$, $q_\psi(x)$, $q_\psi(y)$ | - |
L1OUT [4] [10] | Generative | $q_\theta(y|x)$ | - |
CLUB [10] | Generative | $q_\theta(y|x)$ | - |
Binned [13] | Transformed (Generative) | $Q(x)$, $Q(y)$ | - |
PQ [11] | Transformed (Generative) | $Q(y)$, $q_\theta(Q(y)|x)$ | - |
in which the following models are used:
- $f_\phi(x,y)$ is a critic neural network with parameters $\phi$, which maps pairs of observations to a scalar value. Critics can be either joint or separable, depending on whether they parametrize a function of both $x$ and $y$ directly or through the product of separate projection heads ($f_\phi(x,y)=h_\phi(x)^T h_\phi(y)$), respectively.
- $b_\xi(x)$ is a baseline neural network with parameters $\xi$, which maps observations (or pairs of observations) to a scalar value. When the baseline is a function of both $x$ and $y$, it is referred to as a joint_baseline.
- $q_\theta(y|x)$ is a conditional variational distribution q_Y_given_X used to approximate $p(y|x)$ with parameters $\theta$. Conditional distributions may have learnable parameters $\theta$ that are usually parametrized by a (conditional) normalizing flow.
- $q_\psi(y)$ is a marginal variational distribution q_Y used to approximate $p(y)$ with parameters $\psi$. Marginal distributions may have learnable parameters $\psi$ that are usually parametrized by a normalizing flow.
- $q_\theta(x,y)$ is a joint variational distribution q_XY used to approximate $p(x,y)$ with parameters $\theta$. Joint distributions may have learnable parameters $\theta$ that are usually parametrized by a normalizing flow.
- $Q(x)$ and $Q(y)$ are quantization functions that map observations to a finite set of discrete values.
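To make the role of the quantization functions concrete, the sketch below bins two scalar variables and computes the plug-in mutual information of the resulting discrete symbols. It is an illustration under simple assumptions (equal-width bins, scalar inputs) and is not the implementation of the Binned estimator in torch_mist; quantize and discrete_mi are hypothetical helper names.

```python
import math
import random
from collections import Counter


def quantize(values, n_bins, low, high):
    """Hypothetical Q(.): map each scalar to an equal-width bin index."""
    width = (high - low) / n_bins
    return [min(n_bins - 1, max(0, int((v - low) / width))) for v in values]


def discrete_mi(qx, qy):
    """Plug-in mutual information (in nats) between two symbol sequences."""
    n = len(qx)
    joint = Counter(zip(qx, qy))
    px = Counter(qx)
    py = Counter(qy)
    mi = 0.0
    for (a, b), count in joint.items():
        # p(a,b) * log( p(a,b) / (p(a) * p(b)) ), written with raw counts
        mi += (count / n) * math.log(count * n / (px[a] * py[b]))
    return mi


rng = random.Random(0)
x = [rng.uniform(0.0, 1.0) for _ in range(5000)]
y_independent = [rng.uniform(0.0, 1.0) for _ in range(5000)]

qx = quantize(x, n_bins=4, low=0.0, high=1.0)
# y identical to x: the estimate equals the entropy of Q(x), at most log(4)
mi_dependent = discrete_mi(qx, qx)
# y independent of x: the estimate should be close to zero
mi_independent = discrete_mi(qx, quantize(y_independent, 4, 0.0, 1.0))
print(mi_dependent, mi_independent)
```

The number of bins trades off bias and variance: too few bins discard dependence, while too many bins inflate the plug-in estimate on finite samples.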
And the following hyperparameters:
- $M \in [1, N]$ is the number of samples used to estimate the log-normalization constant for each element in the batch.
- $\gamma_{EMA} \in (0,1]$ is the exponential moving average decay used to update the baseline in MINE.
- $\alpha \in [0,1]$ is the weight of the baseline in AlphaTUBA (0 corresponds to InfoNCE, 1 to TUBA).
- $\tau \in [0, \infty)$ defines the interval $[-\tau,\tau]$ to which critic values are clipped in SMILE.
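As an illustration of how $M$ and $\tau$ enter a discriminative estimate, the sketch below computes a Donsker-Varadhan style value from precomputed critic scores, with optional SMILE-style clipping of the scores on negative samples. It is a simplified stand-in written with the standard library only; dv_estimate and log_mean_exp are hypothetical names, and the estimators in torch_mist may compute and average these quantities differently.

```python
import math


def log_mean_exp(values):
    """Numerically stable log(mean(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def dv_estimate(pos_scores, neg_scores, tau=None):
    """Donsker-Varadhan style estimate: E_p[f] - log E_{p(x)p(y)}[exp(f)].

    pos_scores: critic values f(x_i, y_i) on paired samples.
    neg_scores: for each i, a list of M critic values f(x_i, y_j) computed
                on negatives y_j drawn from the marginal.
    tau: if given, negative scores are clipped to [-tau, tau], which is
         equivalent to clipping exp(f) to [exp(-tau), exp(tau)].
    """
    flat = [s for row in neg_scores for s in row]
    if tau is not None:
        flat = [min(tau, max(-tau, s)) for s in flat]
    return sum(pos_scores) / len(pos_scores) - log_mean_exp(flat)


pos = [2.0, 1.5, 2.5]
neg = [[0.1, 5.0], [0.0, 0.2], [-0.1, 0.4]]  # M = 2 negatives per element
print(dv_estimate(pos, neg))            # unclipped: dominated by the 5.0 outlier
print(dv_estimate(pos, neg, tau=1.0))   # clipped: the outlier is capped at 1.0
```

Because the log-normalizer is dominated by the largest scores, a single outlier among the negatives can swing the unclipped estimate; clipping trades some bias for lower variance.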
Hybrid estimators
The torch_mist package allows combining generative and discriminative estimators into a single hybrid estimator, as proposed in [11][12].
Hybrid mutual information estimators combine the flexibility of discriminative mutual information estimators with the lower variance of generative estimators.
from torch_mist.estimators.hybrid import ResampledHybridMIEstimator
from torch_mist.estimators import nwj, doe

# Use the proposal r(y|x) to sample negatives instead of p(y)
estimator = ResampledHybridMIEstimator(
    # Difference of Entropies generative estimator
    generative_estimator=doe(
        x_dim=x.shape[-1],
        y_dim=y.shape[-1],
        hidden_dims=[32, 32],
    ),
    # NWJ discriminative estimator
    discriminative_estimator=nwj(
        x_dim=x.shape[-1],
        y_dim=y.shape[-1],
        hidden_dims=[32, 32],
        neg_samples=16
    )
)
Further details on the available hybrid mutual information estimators are reported in the tutorial available in the documentation.
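The intuition behind resampling negatives from a proposal $r(y|x)$ can be sketched with the following identity, which holds for any $r(y|x)$ whose support covers that of $p(y|x)$. This is a reading of the general decomposition discussed in [11], not a description of how torch_mist computes each term:

$$
I(x;y) = \mathbb{E}_{p(x,y)}\left[\log\frac{p(y|x)}{p(y)}\right]
= \underbrace{\mathbb{E}_{p(x,y)}\left[\log\frac{r(y|x)}{p(y)}\right]}_{\text{generative part}}
+ \underbrace{\mathbb{E}_{p(x,y)}\left[\log\frac{p(y|x)}{r(y|x)}\right]}_{\text{discriminative part}}
$$

The second term is the KL divergence between $p(x,y)$ and $p(x)r(y|x)$. A discriminative estimator can target it by drawing negatives from $r(y|x)$ instead of $p(y)$, and it shrinks as the proposal approaches $p(y|x)$, leaving less for the critic to estimate.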
Training and Evaluation
Most of the estimators included in this package are parametric and require a training procedure for accurate estimation.
The train_mi_estimator utility function accepts raw data x and y, given as either numpy.array or torch.Tensor.
from torch_mist.utils.train import train_mi_estimator

######################################
# Training using tensors for x and y #
######################################
# By default 10% of the data is used for cross-validation and early stopping
train_log = train_mi_estimator(
    estimator=estimator,
    x=x,
    y=y,
    batch_size=64,
    valid_percentage=0.1,
)
Alternatively, it is possible to use a torch.utils.data.DataLoader that returns either batches of pairs (batch_x, batch_y) or dictionaries of batches {'x': batch_x, 'y': batch_y}, with batch_x and batch_y of shape [batch_size, ..., x_dim] and [batch_size, ..., y_dim] respectively.
#############################
# Training with DataLoaders #
#############################
from torch_mist.utils.data import SampleDataset
from torch.utils.data import DataLoader, random_split

# We provide a utility to turn the tensors into a torch.utils.data.Dataset object
# This can be replaced with any other Dataset object that may load the data from disk
dataset = SampleDataset(
    samples={'x': x, 'y': y}
)

# Split into train and validation
train_size = int(len(dataset)*0.9)
valid_size = len(dataset)-train_size
train_set, valid_set = random_split(dataset, [train_size, valid_size])

# Instantiate the dataloaders
train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=8
)
valid_loader = DataLoader(
    valid_set,
    batch_size=64,
    num_workers=8
)

# Train using the specified dataloaders
# Note that the validation set is optional but recommended to prevent overfitting.
train_log = train_mi_estimator(
    estimator=estimator,
    train_loader=train_loader,
    valid_loader=valid_loader,
)
The two options result in the same training procedure, but we recommend using DataLoader for larger datasets.
Both DataLoader and torch.Tensor (or np.array) can be used for the evaluate_mi function.
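The held-out validation data mentioned above is typically used to decide when to stop training. The sketch below shows one common patience-based rule applied to a list of per-epoch validation estimates. It is a generic illustration with a hypothetical helper name and patience parameter; the stopping criterion actually used by train_mi_estimator is described in the documentation and may differ.

```python
def early_stopping_epoch(valid_scores, patience=3):
    """Return the index of the best epoch seen before stopping.

    Training is considered stopped once the validation score has not
    improved for `patience` consecutive epochs.
    """
    best_epoch, best_score = 0, float("-inf")
    for epoch, score in enumerate(valid_scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs
    return best_epoch


# Hypothetical validation MI estimates (nats), one per epoch
valid_scores = [0.1, 0.5, 0.9, 0.8, 0.85, 0.7, 1.5]
print(early_stopping_epoch(valid_scores, patience=3))
print(early_stopping_epoch(valid_scores, patience=10))
```

A small patience stops early and can miss later improvements (as with the last value in the example), while a large patience trains longer at the risk of overfitting the training samples.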
References
[1] Nguyen, XuanLong, Martin J. Wainwright, and Michael I. Jordan. "Estimating divergence functionals and the likelihood ratio by convex risk minimization." IEEE Transactions on Information Theory 56.11 (2010): 5847-5861.
[2] Belghazi, Mohamed Ishmael, et al. "Mutual information neural estimation." International conference on machine learning. PMLR, 2018.
[3] Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. "Representation learning with contrastive predictive coding." arXiv preprint arXiv:1807.03748 (2018).
[4] Poole, Ben, et al. "On variational bounds of mutual information." International Conference on Machine Learning. PMLR, 2019.
[5] Hjelm, R. Devon, et al. "Learning deep representations by mutual information estimation and maximization." arXiv preprint arXiv:1808.06670 (2018).
[6] Song, Jiaming, and Stefano Ermon. "Understanding the limitations of variational mutual information estimators." arXiv preprint arXiv:1910.06222 (2019).
[7] Guo, Qing, et al. "Tight mutual information estimation with contrastive fenchel-legendre optimization." Advances in Neural Information Processing Systems 35 (2022): 28319-28334.
[8] Barber, David, and Felix Agakov. "The IM algorithm: a variational approach to information maximization." Advances in neural information processing systems 16.320 (2004): 201.
[9] McAllester, David, and Karl Stratos. "Formal limitations on the measurement of mutual information." International Conference on Artificial Intelligence and Statistics. PMLR, 2020.
[10] Cheng, Pengyu, et al. "Club: A contrastive log-ratio upper bound of mutual information." International conference on machine learning. PMLR, 2020.
[11] Federici, Marco, David Ruhe, and Patrick Forré. "On the Effectiveness of Hybrid Mutual Information Estimation." arXiv preprint arXiv:2306.00608 (2023).
[12] Brekelmans, Rob, et al. "Improving mutual information estimation with annealed and energy-based bounds." arXiv preprint arXiv:2303.06992 (2023).
[13] Kraskov, Alexander, Harald Stögbauer, and Peter Grassberger. "Estimating mutual information." Physical review E 69.6 (2004): 066138.
Contributing
Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.
License
torch_mist was created by Marco Federici. It is licensed under the terms of the MIT license.
Credits
torch_mist was created with cookiecutter and the py-pkgs-cookiecutter template.