Skip to main content

Mutual Information Estimation toolkit based on pytorch

Project description

Mist - A PyTorch Mutual information Estimation toolkit

arXiv PyPI version Build workflow codecov Documentation Status License: MIT Code style: black

Mutual Information Estimation toolkit based on pytorch. Please refer to the documentation for additional details regarding installation, usage, tutorials and pracical use-case example.

Installation

The package can be installed via pip as follows:

$ pip install torch_mist

Usage

The torch_mist package provides the basic functionalities for sample-based continuous mutual information estimation using modern neural network architectures.

Here we provide a simple example of how to use the package to estimate mutual information between pairs of observations using the MINE estimator [2]. Consider the variables $x$ and $y$ as of shape [N, x_dim], [N, y_dim] respectively sampled from some joint distribution $p(x,y)$. Mutual information can be estimated directly using the estimate_mi utility function that takes care of fitting the estimator's parameters and evaluating mutual information.

from torch_mist import estimate_mi
from sklearn.datasets import load_iris

# Load the Iris Dataset as a pandas DataFrame
iris_dataset = load_iris(as_frame=True)['data']

# Estimate how much information the petal length and its width have in common
estimated_mi, estimator, train_log = estimate_mi(
    data=iris_dataset,          # The dataset (as a pandas.DataFrame, many other formats are supported)
    x_key='petal length (cm)',  # Consider the 'petal length' column as x
    y_key='petal width (cm)',   # And the 'petal witdh` as y
    estimator_name='mine',      # Use the MINE mutual information estimator
    max_iterations=1000,        # Number of maximum train iterations 
)

print(f"Mutual information estimated value: {estimated_mi} nats")

The estimate_mi function supports additional data formats such as tuples (x, y) or dictionaries {'x':x, 'y':y} of tensors, but also datasets (see torch.utils.data.Dataset) and torch DataLoaders.

Additional flags can be used to customize the estimators, training and evaluation procedure, as detailed in the documentation.

Command line

The torch_mist package provides basic functionality to estimate mutual information directly from the command line. Given a file iris.csv containing the columns (sepal_1, sepal_2, petal_1, petal_2), one can estimate mutual information between the sepal and petal 2-dimensional features with:

mist data=csv data.filepath=iris.csv mi_estimator=js x_key=sepal y_key=petal

The same flags and options provided by the estimate_mi function are also available from command line.

Additionally, internal properties of the estimator can also be easily specified:

mist data=csv data.filepath=iris.csv mi_estimator=js x_key=sepal y_key=petal \ 
  # Train on GPU
  device=cuda \
  # Use AdamW for the optimization
  estimation.optimizer_class._target_=torch.optim.AdamW \
  # Use ELU as nonlinearities
  +mi_estimator.nonlinearity=torch.nn.ELU \
  # Change the batch size to 256
  params.batch_size=256
  # Log on weights and bias
  logger=wandb

To visualize the full list use:

mist data=csv --help

The mist CLI is implemented using hydra and the full configuration can be accessed here.

Advanced usage

It is possible to manually instantiate, train and evaluate the mutual information estimators.

from torch_mist.estimators import mine
from torch_mist.utils.train import train_mi_estimator
from torch_mist.utils import evaluate_mi

# Instantiate the JS mutual information estimator
estimator = mine(
    x_dim=1,
    y_dim=1,
    hidden_dims=[64, 32],
)

# Define x and y as the vectors of petal lengths and widths, respectively
x = iris_dataset['petal length (cm)'].values 
y = iris_dataset['petal width (cm)'].values

# Train it on the given samples
train_log = train_mi_estimator(
    estimator=estimator,
    data=(x, y),
    batch_size=64,
    max_iterations=1000
)

# Evaluate the estimator on the entirety of the data
estimated_mi = evaluate_mi(
    estimator=estimator,
    data=(x, y),
    batch_size=64
)

print(f"Mutual information estimated value: {estimated_mi} nats")

Note that the two code snippets above perform the same procedure. Please refer to the documentation for a detailed description of the package and its usage.

Estimators

Each estimator implemented in the library is an instance of MutualInformationEstimator and can be instantiated through a simplified utility functions

############################
# Simplified instantiation #
############################
from torch_mist.estimators import mine

estimator = mine(
    x_dim=1,
    y_dim=1,
    neg_samples=16,
    hidden_dims=[64, 32],
    critic_type='joint'
)

or directly using the corresponding MutualInformationEstimator class

##########################
# Advanced instantiation #
##########################
from torch_mist.estimators import MINE
from torch_mist.critic import JointCritic
from torch import nn

# First we define the critic architecture
critic = JointCritic(  # Wrapper to concatenate the inputs x and y 
    joint_net=nn.Sequential(  # The neural network architectures that maps [x,y] to a scalar
        nn.Linear(x.shape[-1] + y.shape[-1], 32),
        nn.ReLU(True),
        nn.Linear(32, 32),
        nn.ReLU(True),
        nn.Linear(32, 1)
    )
)

# Then we pass it to the MINE constructor
estimator = MINE(
    critic=critic,
    neg_samples=16,
)

Note that the simplified and advanced instantiation reported in the example above result in the same model.

The basic estimators implemented in this package are summarized in the following table:

Estimator Type Models Hyperparameters
NWJ [1] Discriminative $f_\phi(x,y)$ M
MINE [2] Discriminative $f_\phi(x,y)$ M, $\gamma_{EMA}$
InfoNCE [3] Discriminative $f_\phi(x,y)$ M
TUBA [4] Discriminative $f_\phi(x,y)$, $b_\xi(x)$ M
AlphaTUBA [4] Discriminative $f_\phi(x,y)$, $b_\xi(x)$ M, $\alpha$
JS [5] Discriminative $f_\phi(x,y)$ M
SMILE [6] Discriminative $f_\phi(x,y)$ M, $\tau$
FLO [7] Discriminative $f_\phi(x,y)$, $b_\xi(x,y)$ M
BA [8] Generative $q_\theta(y|x)$ -
DoE [9] Generative $q_\theta(y|x)$, $q_\psi(y)$ -
GM [6] Generative $q_\theta(x,y)$, $q_\psi(x)$, $q_\psi(y)$ -
L1OUT [4] [10] Generative $q_\theta(y|x)$ -
CLUB [10] Generative $q_\theta(y|x)$ -
Binned [13] Transformed (Generative) $Q(x)$, $Q(y)$ -
PQ [11] Transformed (Generative) $Q(y)$, $q_\theta(Q(y)|x)$ -

in which the following models are used:

  • $f_\phi(x,y)$ is a critic neural network with parameters $\phi, which maps pairs of observations to a scalar value. Critics can be either joint or separable depending on whether they parametrize function of both $x$ and $y$ directly, or through the product of separate projection heads ( $f_\phi(x,y)=h_\phi(x)^T h_\phi(y)$ ) respectively.
  • $b_\xi(x)$ is a baseline neural network with parameters $\xi$, which maps observations (or paris of observations) to a scalar value. When the baseline is a function of both $x$ and $y$ it is referred to as a joint_baseline.
  • $q_\theta(y|x)$ is a conditional variational distribution q_Y_given_X used to approximate $p(y|x)$ with parameters $\theta$. Conditional distributions may have learnable parameters $\theta$ that are usually parametrized by a (conditional) normalizing flow.
  • $q_\psi(y)$ is a marginal variational distribution q_Y used to approximate $p(y)$ with parameters $\psi$. Marginal distributions may have learnable parameters $\psi$ that are usually parametrized by a normalizing flow.
  • $q_\theta(x,y)$ is a joint variational distribution q_XY used to approximate $p(x,y)$ with parameters $\theta$. Joint distributions may have learnable parameters $\theta$ that are usually parametrized by a normalizing flow.
  • $Q(x)$ and $Q(y)$ are quantization functions that map observations to a finite set of discrete values.

And the following hyperparameters:

  • $M \in [1, N]$ is the number of samples used to estimate the log-normalization constant for each element in the batch.
  • $\gamma_{EMA} \in (0,1]$ is the exponential moving average decay used to update the baseline in MINE.
  • $\alpha \in [0,1]$ is the weight of the baseline in AlphaTUBA (0 corresponds to InfoNCE, 1 to TUBA).
  • $\tau \in [0..]$ is used to define the interval $[-\tau,\tau]$ in which critic values are clipped in SMILE.

Hybrid estimators

The torch_mist package allows to combine Generative and Discriminative estimators in a single hybrid estimators as proposed in [11][12]. Hybrid mutual information estimators combine the flexibility of discriminative mutual information estimators with the lower variance of generative estimators.

from torch_mist.estimators.hybrid import ResampledHybridMIEstimator
from torch_mist.estimators import js, doe

# Use the proposal r(y|x) to sample negatives instead of p(y)
estimator = ResampledHybridMIEstimator(
    # Difference of Entropies generative estimator
    generative_estimator=doe(
        x_dim=x.shape[-1],
        y_dim=y.shape[-1],
        hidden_dims=[32, 32],
    ),
    # NWJ discriminative estimator
    discriminative_estimator=js(
        x_dim=x.shape[-1],
        y_dim=y.shape[-1],
        hidden_dims=[32, 32],
        neg_samples=16
    )
)

Further details on the available hybrid mutual information estimators and additional details are reported in the tutorial available in the documentation.

Training and Evaluation

Most of the estimators included in this package are parametric and require a training procedure for accurate estimation. The train_mi_estimator utility function supports either row data x and y as numpy.array or torch.Tensor.

from torch_mist.utils.train import train_mi_estimator

######################################
# Training using tensors for x and y #
######################################
# By default 10% of the data is used for cross-validation and early stopping
train_log = train_mi_estimator(
    estimator=estimator,
    data=(x,y),
    batch_size=64,
    valid_percentage=0.1,
)

Alternatively, it is possible to use a torch.utils.DataLoader that returns eiter batches of pairs (batch_x, batch_y) or dictionaries of batches {'x': batch_x, 'y': batch_y}, with batch_x of shape [batch_size, ..., x_dim] and [batch_size, ..., y_dim] respectively.

#############################
# Training with DataLoaders #
#############################
from torch_mist.utils.data import SampleDataset
from torch.utils.data import DataLoader, random_split

# We provide an utility to make the tensors into a torch.utils.data.Dataset object
# This can be replaced with any other Dataset object that may load the data from disk
dataset = SampleDataset(
    samples={'x': x, 'y': y}
)

# Split into train and validation
train_size = int(len(dataset)*0.9)
valid_size = len(dataset)-train_size
train_set, valid_set = random_split(dataset, [train_size, valid_size])

# Instantiate the dataloaders
train_loader = DataLoader(
    train_set,
    batch_size=64,
    shuffle=True,
    num_workers=8
)

valid_loader = DataLoader(
    valid_set,
    batch_size=64,
    num_workers=8
)

# Train using the specified dataloaders
# Note that the validation set is optional but recommended to prevent overfitting.

train_log = train_mi_estimator(
    estimator=estimator,
    data=train_loader,
    valid_data=valid_loader,
)

The two options result in the same training procedure, but we recommend using DataLoader for larger datasets.

Both DataLoader and torch.Tensor (or np.array) can be used for the evaluate_mi function.

References

[1] Nguyen, XuanLong, Martin J. Wainwright, and Michael I. Jordan. "Estimating divergence functionals and the likelihood ratio by convex risk minimization." IEEE Transactions on Information Theory 56.11 (2010): 5847-5861.

[2] Belghazi, Mohamed Ishmael, et al. "Mutual information neural estimation." International conference on machine learning. PMLR, 2018.

[3] Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. "Representation learning with contrastive predictive coding." arXiv preprint arXiv:1807.03748 (2018).

[4] Poole, Ben, et al. "On variational bounds of mutual information." International Conference on Machine Learning. PMLR, 2019.

[5] Hjelm, R. Devon, et al. "Learning deep representations by mutual information estimation and maximization." arXiv preprint arXiv:1808.06670 (2018).

[6] Song, Jiaming, and Stefano Ermon. "Understanding the limitations of variational mutual information estimators." arXiv preprint arXiv:1910.06222 (2019).

[7] Guo, Qing, et al. "Tight mutual information estimation with contrastive fenchel-legendre optimization." Advances in Neural Information Processing Systems 35 (2022): 28319-28334.

[8] Barber, David, and Felix Agakov. "The im algorithm: a variational approach to information maximization." Advances in neural information processing systems 16.320 (2004): 201.

[9] McAllester, David, and Karl Stratos. "Formal limitations on the measurement of mutual information." International Conference on Artificial Intelligence and Statistics. PMLR, 2020.

[10] Cheng, Pengyu, et al. "Club: A contrastive log-ratio upper bound of mutual information." International conference on machine learning. PMLR, 2020.

[11] Federici, Marco, David Ruhe, and Patrick Forré. "On the Effectiveness of Hybrid Mutual Information Estimation." arXiv preprint arXiv:2306.00608 (2023).

[12] Brekelmans, Rob, et al. "Improving mutual information estimation with annealed and energy-based bounds." arXiv preprint arXiv:2303.06992 (2023).

[13] Kraskov, Alexander, Harald Stögbauer, and Peter Grassberger. "Estimating mutual information." Physical review E 69.6 (2004): 066138.

Contributing

Interested in contributing? Check out the contributing guidelines. Please note that this project is released with a Code of Conduct. By contributing to this project, you agree to abide by its terms.

License

torch_mist was created by Marco Federici. It is licensed under the terms of the MIT license.

Credits

torch_mist was created with cookiecutter and the py-pkgs-cookiecutter template.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_mist-0.2.3.tar.gz (64.6 kB view details)

Uploaded Source

Built Distribution

torch_mist-0.2.3-py3-none-any.whl (104.3 kB view details)

Uploaded Python 3

File details

Details for the file torch_mist-0.2.3.tar.gz.

File metadata

  • Download URL: torch_mist-0.2.3.tar.gz
  • Upload date:
  • Size: 64.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.8.1 Linux/5.4.0-171-generic

File hashes

Hashes for torch_mist-0.2.3.tar.gz
Algorithm Hash digest
SHA256 2c3fcffd769d56508350dae4e7e67d24fc8f76de4cedd2d1e2058a9eca3860a3
MD5 c6803356afee76a75d466c36e8180986
BLAKE2b-256 b07ce1ba8bdd43b85a61fce11bd3dfd57df42fd0311cc4d83b99805fe84717b6

See more details on using hashes here.

File details

Details for the file torch_mist-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: torch_mist-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 104.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.5.1 CPython/3.8.1 Linux/5.4.0-171-generic

File hashes

Hashes for torch_mist-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 1fb583197958eb4a584667f8595789f215a04b2f28ff22da84faba3dd3b61a2d
MD5 0172bcb7848e4fda28187f89ef692f96
BLAKE2b-256 7b93c8498c466f701926154b7e081bbb9abc319b7c3449d0ba265b0c32cf4b57

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page