
Mist - A PyTorch Mutual Information Estimation toolkit


Mutual Information Estimation toolkit based on PyTorch. Please refer to the documentation for additional details regarding installation, usage, tutorials, and practical use-case examples.

Installation

The package can be installed via pip as follows:

$ pip install torch_mist

Usage

The torch_mist package provides the core functionality for sample-based mutual information estimation between continuous variables using modern neural network architectures.

Here we provide a simple example of how to use the package to estimate the mutual information between pairs of observations using the MINE estimator [2]. Consider variables $x$ and $y$ of shape [N, x_dim] and [N, y_dim] respectively, sampled from some joint distribution $p(x,y)$. Mutual information can be estimated directly using the estimate_mi utility function, which takes care of fitting the estimator's parameters and evaluating the mutual information.
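
For a quick, self-contained run, $x$ and $y$ can be synthesized. The setup below is purely illustrative (it is not part of the package): it draws the pair from a correlated Gaussian, for which the ground-truth mutual information is known in closed form.

import torch

# Illustrative data: a correlated Gaussian pair with correlation rho.
# For jointly Gaussian scalars, I(x; y) = -0.5 * log(1 - rho^2),
# which is roughly 0.83 nats for rho = 0.9.
N, rho = 10000, 0.9
x = torch.randn(N, 1)
y = rho * x + (1 - rho ** 2) ** 0.5 * torch.randn(N, 1)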

from torch_mist import estimate_mi

estimated_mi = estimate_mi(
    estimator_name='mine',  # Use MINE
    hidden_dims=[32, 32],  # Hidden dimensions of the neural network
    x=x,  # The values for x
    y=y,  # The values for y
)

print(f"Mutual information estimated value: {estimated_mi} nats")

Additional flags for customizing the estimators and the training and evaluation procedures are described in the documentation.

Alternatively, it is possible to manually instantiate, train, and evaluate a mutual information estimator.

from torch_mist.estimators import mine
from torch_mist.utils.train import train_mi_estimator
from torch_mist.utils import evaluate_mi

# Instantiate the mutual information estimator
estimator = mine(
    x_dim=x.shape[-1],
    y_dim=y.shape[-1],
    hidden_dims=[32, 32],
)

# Train it on the given samples
train_log = train_mi_estimator(
    estimator=estimator,
    x=x,
    y=y,
    batch_size=64,
    verbose=True
)

# Evaluate the estimator on the entire dataset
estimated_mi = evaluate_mi(
    estimator=estimator,
    x=x,
    y=y,
    batch_size=64
)

print(f"Mutual information estimated value: {estimated_mi} nats")

Note that the two code snippets above perform the same procedure. Please refer to the documentation for a detailed description of the package and its usage.

Estimators

Each estimator implemented in the library is an instance of MutualInformationEstimator and can be instantiated through a simplified utility function:

############################
# Simplified instantiation #
############################
from torch_mist.estimators import mine

estimator = mine(
    x_dim=x.shape[-1],
    y_dim=y.shape[-1],
    neg_samples=16,
    hidden_dims=[32, 32],
    critic_type='joint'
)

or directly using the corresponding MutualInformationEstimator class:

##########################
# Advanced instantiation #
##########################
from torch_mist.estimators import MINE
from torch_mist.critic import JointCritic
from torch import nn

# First we define the critic architecture
critic = JointCritic(  # Wrapper that concatenates the inputs x and y
    joint_net=nn.Sequential(  # Neural network that maps [x, y] to a scalar
        nn.Linear(x.shape[-1] + y.shape[-1], 32),
        nn.ReLU(True),
        nn.Linear(32, 32),
        nn.ReLU(True),
        nn.Linear(32, 1)
    )
)

# Then we pass it to the MINE constructor
estimator = MINE(
    critic=critic,
    neg_samples=16,
)

Note that the simplified and advanced instantiations shown above result in the same model.

The basic estimators implemented in this package are summarized in the following table:

| Estimator | Type | Models | Hyperparameters |
| --- | --- | --- | --- |
| NWJ [1] | Discriminative | $f_\phi(x,y)$ | $M$ |
| MINE [2] | Discriminative | $f_\phi(x,y)$ | $M$, $\gamma_{EMA}$ |
| InfoNCE [3] | Discriminative | $f_\phi(x,y)$ | $M$ |
| TUBA [4] | Discriminative | $f_\phi(x,y)$, $b_\xi(x)$ | $M$ |
| AlphaTUBA [4] | Discriminative | $f_\phi(x,y)$, $b_\xi(x)$ | $M$, $\alpha$ |
| JS [5] | Discriminative | $f_\phi(x,y)$ | $M$ |
| SMILE [6] | Discriminative | $f_\phi(x,y)$ | $M$, $\tau$ |
| FLO [7] | Discriminative | $f_\phi(x,y)$, $b_\xi(x,y)$ | $M$ |
| BA [8] | Generative | $q_\theta(y\|x)$ | - |
| DoE [9] | Generative | $q_\theta(y\|x)$, $q_\psi(y)$ | - |
| GM [6] | Generative | $q_\theta(x,y)$, $q_\psi(x)$, $q_\psi(y)$ | - |
| L1OUT [4][10] | Generative | $q_\theta(y\|x)$ | - |
| CLUB [10] | Generative | $q_\theta(y\|x)$ | - |
| Binned [13] | Transformed (Generative) | $Q(x)$, $Q(y)$ | - |
| PQ [11] | Transformed (Generative) | $Q(y)$, $q_\theta(Q(y)\|x)$ | - |

in which the following models are used:

  • $f_\phi(x,y)$ is a critic neural network with parameters $\phi$, which maps pairs of observations to a scalar value. Critics can be either joint or separable, depending on whether they parametrize a function of both $x$ and $y$ directly or through the product of separate projection heads ( $f_\phi(x,y)=h_\phi(x)^T h_\phi(y)$ ); see the sketch after this list.
  • $b_\xi(x)$ is a baseline neural network with parameters $\xi$, which maps observations (or pairs of observations) to a scalar value. When the baseline is a function of both $x$ and $y$ it is referred to as a joint_baseline.
  • $q_\theta(y|x)$ is a conditional variational distribution q_Y_given_X with learnable parameters $\theta$, used to approximate $p(y|x)$. It is usually parametrized by a (conditional) normalizing flow.
  • $q_\psi(y)$ is a marginal variational distribution q_Y with learnable parameters $\psi$, used to approximate $p(y)$. It is usually parametrized by a normalizing flow.
  • $q_\theta(x,y)$ is a joint variational distribution q_XY with learnable parameters $\theta$, used to approximate $p(x,y)$. It is usually parametrized by a normalizing flow.
  • $Q(x)$ and $Q(y)$ are quantization functions that map observations to a finite set of discrete values.
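
The separable parametrization above can be illustrated with a minimal sketch in plain PyTorch (this is not a class exported by the package; the projection heads here are hypothetical):

import torch
from torch import nn

class SeparableCriticSketch(nn.Module):
    """Minimal separable critic: f(x, y) = h(x)^T h'(y)."""

    def __init__(self, x_dim: int, y_dim: int, embed_dim: int = 32):
        super().__init__()
        # Two projection heads (they may also share parameters)
        self.head_x = nn.Sequential(
            nn.Linear(x_dim, 32), nn.ReLU(), nn.Linear(32, embed_dim)
        )
        self.head_y = nn.Sequential(
            nn.Linear(y_dim, 32), nn.ReLU(), nn.Linear(32, embed_dim)
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Inner product of the two projections yields one scalar per pair
        return (self.head_x(x) * self.head_y(y)).sum(dim=-1)

Separable critics are particularly convenient with many negative samples, since the projections of $x$ and $y$ can be computed once and recombined cheaply.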

And the following hyperparameters:

  • $M \in [1, N]$ is the number of samples used to estimate the log-normalization constant for each element in the batch.
  • $\gamma_{EMA} \in (0,1]$ is the exponential moving average decay used to update the baseline in MINE.
  • $\alpha \in [0,1]$ is the weight of the baseline in AlphaTUBA (0 corresponds to InfoNCE, 1 to TUBA).
  • $\tau \in [0, \infty)$ defines the interval $[-\tau,\tau]$ to which critic values are clipped in SMILE.
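
For intuition, $\gamma_{EMA}$ controls an exponential moving average $Z$ of the partition term, which MINE uses to reduce the bias of its gradient estimate. A sketch of a standard EMA update in the notation above (conventions for the decay direction vary across implementations):

$$Z \leftarrow \gamma_{EMA} \, Z + (1 - \gamma_{EMA}) \, \frac{1}{M} \sum_{j=1}^{M} e^{f_\phi(x, y'_j)}$$

where $y'_1, \dots, y'_M$ are negative samples drawn from the marginal $p(y)$.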

Hybrid estimators

The torch_mist package makes it possible to combine generative and discriminative estimators into a single hybrid estimator, as proposed in [11][12]. Hybrid mutual information estimators combine the flexibility of discriminative estimators with the lower variance of generative ones.

from torch_mist.estimators.hybrid import ResampledHybridMIEstimator
from torch_mist.estimators import nwj, doe

# Use the proposal r(y|x) to sample negatives instead of p(y)
estimator = ResampledHybridMIEstimator(
    # Difference of Entropies generative estimator
    generative_estimator=doe(
        x_dim=x.shape[-1],
        y_dim=y.shape[-1],
        hidden_dims=[32, 32],
    ),
    # NWJ discriminative estimator
    discriminative_estimator=nwj(
        x_dim=x.shape[-1],
        y_dim=y.shape[-1],
        hidden_dims=[32, 32],
        neg_samples=16
    )
)
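
Hybrid estimators expose the same interface as the other estimators in the package, so they can be trained and evaluated with the utilities shown above:

from torch_mist.utils.train import train_mi_estimator
from torch_mist.utils import evaluate_mi

# Train and evaluate the hybrid estimator like any other estimator
train_log = train_mi_estimator(estimator=estimator, x=x, y=y, batch_size=64)
estimated_mi = evaluate_mi(estimator=estimator, x=x, y=y, batch_size=64)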

Further details on the available hybrid mutual information estimators are reported in the tutorial available in the documentation.

Training and Evaluation

Most of the estimators included in this package are parametric and require a training procedure for accurate estimation. The train_mi_estimator utility function accepts raw data x and y either as numpy.ndarray or as torch.Tensor objects.

from torch_mist.utils.train import train_mi_estimator

######################################
# Training using tensors for x and y #
######################################
# By default 10% of the data is used for cross-validation and early stopping
train_log = train_mi_estimator(
    estimator=estimator,
    x=x,
    y=y,
    batch_size=64,
    valid_percentage=0.1,
)

Alternatively, it is possible to use a torch.utils.data.DataLoader that returns either batches of pairs (batch_x, batch_y) or dictionaries of batches {'x': batch_x, 'y': batch_y}, with batch_x and batch_y of shape [batch_size, ..., x_dim] and [batch_size, ..., y_dim] respectively.

#############################
# Training with DataLoaders #
#############################
from torch_mist.utils.data import SampleDataset
from torch.utils.data import DataLoader, random_split

# We provide a utility to wrap the tensors in a torch.utils.data.Dataset object.
# This can be replaced with any other Dataset object, e.g. one that loads the data from disk.
dataset = SampleDataset(
    samples={'x': x, 'y': y}
)

# Split into train and validation
train_size = int(len(dataset)*0.9)
valid_size = len(dataset)-train_size
train_set, valid_set = random_split(dataset, [train_size, valid_size])

# Instantiate the dataloaders
train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=8
)

valid_loader = DataLoader(
    valid_set,
    batch_size=64,
    num_workers=8
)

# Train using the specified dataloaders
# Note that the validation set is optional but recommended to prevent overfitting.

train_log = train_mi_estimator(
    estimator=estimator,
    train_loader=train_loader,
    valid_loader=valid_loader,
)

The two options result in the same training procedure, but we recommend using a DataLoader for larger datasets.

Both DataLoader and torch.Tensor (or numpy.ndarray) inputs can be used with the evaluate_mi function.
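
For example, evaluation on the validation loader defined above could look like the following sketch (the dataloader keyword below is an assumption made for illustration; check the documentation for the exact signature):

from torch_mist.utils import evaluate_mi

# NOTE: the 'dataloader' keyword is hypothetical; consult the
# documentation for the actual argument name accepted by evaluate_mi.
estimated_mi = evaluate_mi(
    estimator=estimator,
    dataloader=valid_loader,
)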

References

[1] Nguyen, XuanLong, Martin J. Wainwright, and Michael I. Jordan. "Estimating divergence functionals and the likelihood ratio by convex risk minimization." IEEE Transactions on Information Theory 56.11 (2010): 5847-5861.

[2] Belghazi, Mohamed Ishmael, et al. "Mutual information neural estimation." International Conference on Machine Learning. PMLR, 2018.

[3] Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. "Representation learning with contrastive predictive coding." arXiv preprint arXiv:1807.03748 (2018).

[4] Poole, Ben, et al. "On variational bounds of mutual information." International Conference on Machine Learning. PMLR, 2019.

[5] Hjelm, R. Devon, et al. "Learning deep representations by mutual information estimation and maximization." arXiv preprint arXiv:1808.06670 (2018).

[6] Song, Jiaming, and Stefano Ermon. "Understanding the limitations of variational mutual information estimators." arXiv preprint arXiv:1910.06222 (2019).

[7] Guo, Qing, et al. "Tight mutual information estimation with contrastive Fenchel-Legendre optimization." Advances in Neural Information Processing Systems 35 (2022): 28319-28334.

[8] Barber, David, and Felix Agakov. "The IM algorithm: a variational approach to information maximization." Advances in Neural Information Processing Systems 16.320 (2004): 201.

[9] McAllester, David, and Karl Stratos. "Formal limitations on the measurement of mutual information." International Conference on Artificial Intelligence and Statistics. PMLR, 2020.

[10] Cheng, Pengyu, et al. "CLUB: A contrastive log-ratio upper bound of mutual information." International Conference on Machine Learning. PMLR, 2020.

[11] Federici, Marco, David Ruhe, and Patrick Forré. "On the effectiveness of hybrid mutual information estimation." arXiv preprint arXiv:2306.00608 (2023).

[12] Brekelmans, Rob, et al. "Improving mutual information estimation with annealed and energy-based bounds." arXiv preprint arXiv:2303.06992 (2023).

[13] Kraskov, Alexander, Harald Stögbauer, and Peter Grassberger. "Estimating mutual information." Physical Review E 69.6 (2004): 066138.

Contributing

Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.

License

torch_mist was created by Marco Federici. It is licensed under the terms of the MIT license.

Credits

torch_mist was created with cookiecutter and the py-pkgs-cookiecutter template.
