
Neural mutual information estimators with quantile-based uncertainty

Project description

mist

Mutual Information estimation via Supervised Training


MIST is a framework for fully data-driven mutual information (MI) estimation. It leverages neural networks trained on large meta-datasets of distributions to learn flexible, differentiable MI estimators that generalize across sample sizes, dimensions, and modalities. The framework supports uncertainty quantification via quantile regression and provides fast, well-calibrated inference suitable for integration into modern ML pipelines.

This repository contains the reference implementation for the preprint "Mutual Information via Supervised Training". It includes scripts to reproduce our experiments as well as tools for training and evaluating MIST-style MI estimators.

Installation

Install with pip

pip install mist-statinf

Install with conda

conda env create -f environment.yml
conda activate mist-statinf

Install from sources

Alternatively, clone the latest version of the repository and install it from source:

pip install -e .       

Quickstart: MI on your (X, Y)

If you want to estimate MI or obtain confidence intervals on your own data using the MIST or MIST-QR models described in the paper, use MISTQuickEstimator.

Point MI estimate with MIST

from mist_statinf import MISTQuickEstimator

X, Y = <your data>

mist = MISTQuickEstimator(
    loss="mse",
    checkpoint="checkpoints/mist/weights.ckpt",
)

mi = mist.estimate_point(X, Y)
print("MIST estimate:", mi)
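To sanity-check a point estimate, it can help to run the estimator on synthetic data whose MI is known in closed form. The sketch below uses two hypothetical helpers (not part of mist_statinf) that draw correlated Gaussian pairs with correlation rho, for which I(X; Y) = -1/2 log(1 - rho^2) nats. The array shape (n_samples, 1) and the nats unit are assumptions; check what MISTQuickEstimator accepts and reports before comparing.

```python
import numpy as np


def correlated_gaussian(n: int, rho: float, seed: int = 0):
    """Draw n pairs (X, Y) from a standard bivariate Gaussian with correlation rho.

    Hypothetical helper for sanity checks; not part of mist_statinf.
    Returns two arrays of shape (n, 1).
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    noise = rng.standard_normal(n)
    y = rho * x + np.sqrt(1.0 - rho**2) * noise  # Var(y) = 1, Corr(x, y) = rho
    return x[:, None], y[:, None]


def gaussian_mi(rho: float) -> float:
    """Closed-form MI in nats for a bivariate Gaussian with correlation rho."""
    return -0.5 * np.log(1.0 - rho**2)


X, Y = correlated_gaussian(n=2000, rho=0.8)
print("ground-truth MI (nats):", gaussian_mi(0.8))
```

With the estimator from the snippet above, mist.estimate_point(X, Y) can then be compared against gaussian_mi(0.8), assuming the input format matches.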

Median MI estimate and quantile-based confidence intervals with MIST-QR

from mist_statinf import MISTQuickEstimator 

X, Y = <your data>

mist_qr = MISTQuickEstimator(
    loss="qr",
    checkpoint="checkpoints/mist_qr/weights.ckpt", 
)

mi_median = mist_qr.estimate_point(X, Y)
print("Median MI:", mi_median)

mi_q90 = mist_qr.estimate_point(X, Y, tau=0.90)
print("q90 MI estimate:", mi_q90)

# --- fast quantile-based uncertainty interval ---
interval = mist_qr.estimate_interval_qr(X, Y, lower=0.05, upper=0.95)
print(interval)
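MIST-QR is described as using quantile regression. As background, a model trained with the pinball (quantile) loss at level tau learns the tau-quantile of its target. The minimal NumPy sketch below shows that loss and checks that, over constant predictions, its minimizer sits at the empirical quantile. It illustrates the general technique only; the exact training objective in mist_statinf (including how tau is passed to the network) is not described here.

```python
import numpy as np


def pinball_loss(y_true: np.ndarray, y_pred: np.ndarray, tau: float) -> float:
    """Mean quantile (pinball) loss at level tau in (0, 1)."""
    diff = y_true - y_pred
    # Under-prediction is weighted by tau, over-prediction by (1 - tau).
    return float(np.mean(np.maximum(tau * diff, (tau - 1.0) * diff)))


rng = np.random.default_rng(0)
targets = rng.standard_normal(5000)

# Scan constant predictions: the best constant lands near the tau-quantile.
grid = np.linspace(-3.0, 3.0, 601)
tau = 0.90
losses = [pinball_loss(targets, np.full_like(targets, c), tau) for c in grid]
best = grid[int(np.argmin(losses))]
print(best, np.quantile(targets, tau))
```

Under this reading, an interval built from the 0.05 and 0.95 quantiles, as in estimate_interval_qr(X, Y, lower=0.05, upper=0.95), has a nominal coverage of 90%.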

By default, MISTQuickEstimator loads the pretrained models used in the paper from the package’s checkpoints/ directory, using the architecture defined in configs/inference/quickstart.yaml. You can override both the checkpoint and the architecture if you have your own trained models.

Evaluating estimators on test sets

If you want to reproduce the experiments from the paper, we recommend evaluating our trained estimators on the provided test sets (M_test and M_test_extended).

Since the test sets take a considerable amount of storage space, we publish them separately on Zenodo.
Before running inference, download the desired subset (either M_test or M_test_extended).
Below we show an example using M_test, as it is significantly lighter.

mist-statinf get-data --preset m_test_imd --dir data/test_imd_data
mist-statinf get-data --preset m_test_oomd --dir data/test_oomd_data

The simplest way to run inference on these datasets is:

mist-statinf infer configs/inference/mist_inference.yaml "checkpoints/mist/" 

NOTE: The file mist_inference.yaml allows you to configure the evaluation mode (bootstrap or QCQR calibration), select the specific test subset, and specify which quantiles to compute.
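For readers unfamiliar with the bootstrap mode, the sketch below shows a generic percentile bootstrap: resample (X, Y) pairs with replacement, re-run an estimator on each resample, and read off empirical quantiles. As a stand-in estimator it uses a hypothetical Gaussian plug-in, -1/2 log(1 - r^2) from the sample correlation r. This is not MIST, and the package's own bootstrap settings are configured in mist_inference.yaml rather than here.

```python
import numpy as np


def gaussian_plugin_mi(x: np.ndarray, y: np.ndarray) -> float:
    """Hypothetical stand-in estimator: Gaussian MI (nats) from the sample correlation."""
    r = np.corrcoef(x, y)[0, 1]
    return float(-0.5 * np.log(1.0 - r**2))


def bootstrap_interval(x, y, estimator, n_boot=500, lower=0.05, upper=0.95, seed=0):
    """Percentile bootstrap interval for estimator(x, y), resampling pairs jointly."""
    rng = np.random.default_rng(seed)
    n = len(x)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample row indices with replacement
        stats[b] = estimator(x[idx], y[idx])  # keep (x_i, y_i) pairs together
    return float(np.quantile(stats, lower)), float(np.quantile(stats, upper))


rng = np.random.default_rng(1)
x = rng.standard_normal(1000)
y = 0.6 * x + 0.8 * rng.standard_normal(1000)  # Corr = 0.6, true MI ~ 0.223 nats
print(bootstrap_interval(x, y, gaussian_plugin_mi))
```

Resampling indices jointly is the key design choice: resampling X and Y independently would destroy their dependence and pull the interval toward zero.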

Below we show the results we obtained on M_test:

[Figure: MIST results on M_test]

Train your own MIST estimators

If you want to reproduce the full training pipeline from the paper, possibly with your own modifications, we recommend following the workflow below.

1. Data Generation

mist-statinf generate configs/data_generation/train.yaml # repeat with the test and val configs

The generated datasets and their corresponding configuration files will appear under data/train_data (and in the analogous directories for the test and validation sets).
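The configs under configs/data_generation/ define the actual meta-datasets. To illustrate what labeled training data for MI looks like, the sketch below builds a tiny meta-dataset from one family with closed-form MI: X and Y are d-dimensional, coordinate pair i has correlation rho_i, and different pairs are independent, so I(X; Y) = -1/2 sum_i log(1 - rho_i^2) nats. The family, dimensions, and sample sizes are illustrative assumptions, not the paper's generator.

```python
import numpy as np


def sample_task(rng: np.random.Generator, n: int, dim: int):
    """Draw one (X, Y, MI) task from a Gaussian family with known MI (nats).

    Coordinate i of X and Y has correlation rho_i; different coordinates are independent.
    """
    rho = rng.uniform(-0.95, 0.95, size=dim)
    x = rng.standard_normal((n, dim))
    y = rho * x + np.sqrt(1.0 - rho**2) * rng.standard_normal((n, dim))
    mi = float(-0.5 * np.sum(np.log(1.0 - rho**2)))
    return x, y, mi


def make_meta_dataset(num_tasks: int, seed: int = 0):
    """Tasks with varying sample size and dimension, each labeled with its true MI."""
    rng = np.random.default_rng(seed)
    tasks = []
    for _ in range(num_tasks):
        n = int(rng.integers(100, 1001))  # sample size varies across tasks
        dim = int(rng.integers(1, 6))  # dimension varies across tasks
        tasks.append(sample_task(rng, n, dim))
    return tasks


meta = make_meta_dataset(num_tasks=8)
for x, y, mi in meta[:3]:
    print(x.shape, y.shape, round(mi, 3))
```

Varying n and dim across tasks mirrors the stated goal of estimators that generalize across sample sizes and dimensions.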

2. Train a MIST Model

mist-statinf train configs/train/mist_train.yaml

Inside the training config you can switch between MSE training and QCQR training. After training, logs, configs, and the saved model checkpoint will be stored under: logs/mist_train/run_YYYYmmdd-HHMMSS

3. Running Baselines

mist-statinf baselines configs/inference/baselines.yaml

Baseline results, logs, and configs will be saved to: logs/bmi_baselines.
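Which estimators configs/inference/baselines.yaml runs is not listed here. As an illustration of a classical non-neural MI estimator, the sketch below implements the estimator of Kraskov, Stögbauer and Grassberger (KSG, algorithm 1) with max-norm distances. Whether the package's baselines include KSG is an assumption; this version uses brute-force O(N^2) distances and is meant for small samples only.

```python
import numpy as np

EULER_GAMMA = 0.5772156649015329


def digamma_int(m) -> np.ndarray:
    """Digamma at positive integers: psi(m) = -gamma + sum_{j=1}^{m-1} 1/j."""
    m = np.asarray(m, dtype=int)
    harmonic = np.concatenate(([0.0], np.cumsum(1.0 / np.arange(1, m.max()))))
    return -EULER_GAMMA + harmonic[m - 1]


def ksg_mi(x: np.ndarray, y: np.ndarray, k: int = 3) -> float:
    """KSG estimator (algorithm 1) of I(X; Y) in nats; x, y have shape (n, d)."""
    n = x.shape[0]
    # Pairwise max-norm distances in each marginal space.
    dx = np.max(np.abs(x[:, None, :] - x[None, :, :]), axis=-1)
    dy = np.max(np.abs(y[:, None, :] - y[None, :, :]), axis=-1)
    dz = np.maximum(dx, dy)  # max-norm in the joint space
    np.fill_diagonal(dz, np.inf)  # exclude each point from its own neighbours
    eps = np.sort(dz, axis=1)[:, k - 1]  # distance to the k-th nearest neighbour
    # Count marginal neighbours strictly within eps (subtract 1 for the point itself).
    nx = np.sum(dx < eps[:, None], axis=1) - 1
    ny = np.sum(dy < eps[:, None], axis=1) - 1
    psi_marg = digamma_int(nx + 1) + digamma_int(ny + 1)
    return float(digamma_int(k) + digamma_int(n) - np.mean(psi_marg))


rng = np.random.default_rng(0)
x = rng.standard_normal((500, 1))
y = 0.8 * x + 0.6 * rng.standard_normal((500, 1))  # Corr = 0.8, true MI ~ 0.511 nats
print(ksg_mi(x, y, k=3))
```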

4. Test Stage

mist-statinf infer configs/inference/mist_inference.yaml "logs/mist_train/run_YYYYmmdd-HHMMSS"

This will produce CSV predictions and a JSON summary in the same run directory: logs/mist_train/run_YYYYmmdd-HHMMSS.
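Once interval predictions are available, calibration can be checked by empirical coverage: the fraction of test tasks whose true MI falls inside the predicted interval, which should be close to the nominal level (0.90 for the 0.05 and 0.95 quantiles). The CSV column names are not documented here, so the sketch below works on plain arrays; mapping CSV columns to true_mi, lower and upper is left as an assumption.

```python
import numpy as np


def empirical_coverage(true_mi, lower, upper) -> float:
    """Fraction of tasks whose true MI lies in [lower, upper]."""
    true_mi, lower, upper = (np.asarray(a, dtype=float) for a in (true_mi, lower, upper))
    return float(np.mean((lower <= true_mi) & (true_mi <= upper)))


def mean_width(lower, upper) -> float:
    """Average interval width; narrower is better at equal coverage."""
    return float(np.mean(np.asarray(upper, dtype=float) - np.asarray(lower, dtype=float)))


# Toy numbers for illustration only (not results from the paper).
truth = [0.2, 0.5, 1.0, 1.5]
lo = [0.1, 0.55, 0.8, 1.2]
hi = [0.3, 0.9, 1.3, 1.6]
print(empirical_coverage(truth, lo, hi), mean_width(lo, hi))
```

Reporting width alongside coverage matters: trivially wide intervals always reach the nominal coverage.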

(Optional) Hyperparameter Search

mist-statinf tune logs/mist_train/run_YYYYmmdd-HHMMSS --model-type MSE --n-trials 30

This runs a hyperparameter search (via Optuna) starting from the given training run.

Citation

If you use MIST or MIST-QR in your work, please cite:

@article{mist2025,
  title   = {Mutual Information via Supervised Training},
  author  = {German Gritsai and Megan Richards and Maxime Meloux and Kyunghyun Cho and Maxime Peyrard},
  journal = {arXiv preprint arXiv:XXXX.XXXXX},
  year    = {2025},
}

Authors

German Gritsai, Megan Richards, Maxime Meloux, Kyunghyun Cho, Maxime Peyrard.


Download files


Source Distribution

mist_statinf-0.1.0.tar.gz (46.1 kB)

Built Distribution

mist_statinf-0.1.0-py3-none-any.whl (55.3 kB)

File details

Details for the file mist_statinf-0.1.0.tar.gz.

File metadata

  • Download URL: mist_statinf-0.1.0.tar.gz
  • Upload date:
  • Size: 46.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.14

File hashes

Hashes for mist_statinf-0.1.0.tar.gz
Algorithm Hash digest
SHA256 8b5c4811fadad99ffd61a968f70c620d8ae694c75b957edf2e5edab29b5ef2cc
MD5 4501fead6c4ee4cfa174473da6df9980
BLAKE2b-256 41f10cc5b918fa9c398b7c9cff71c41006deb80c34343b53ac296e327c00cf5c


File details

Details for the file mist_statinf-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: mist_statinf-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 55.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.14

File hashes

Hashes for mist_statinf-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bded6c33dec8eaa075b58b8d80bc6493f6dbe4524667b2e12005351984b47f1c
MD5 367cdebd6836e0031f1b620d94e051a6
BLAKE2b-256 56d3f2defe6f8b8756462b988feaac8c92e236aac0ef9856db09c5472e4c3c4e

