
An implementation of the divergence framework described in https://arxiv.org/abs/2405.06397

Project description

IWPC - I Will Prove Convergence

This package implements the methods described in the research paper https://arxiv.org/abs/2405.06397 for estimating a lower bound on the divergence between any two distributions, p and q, using samples from each distribution.

Install using pip install iwpc

The machine learning code in this package is organised using the fantastic PyTorch Lightning package. Some familiarity with the structure of Lightning is recommended.

Basic Usage

The most basic usage of this package is calculating an estimator for a lower bound on an f-divergence (such as the KL-divergence) between two distributions, p and q. Each example below assumes one is provided with a set of samples drawn from distribution p and a set drawn from distribution q, labelled 0 and 1 respectively.

calculate_divergence

The example script examples/continuous_example_2D.py shows the most basic usage of the calculate_divergence function, run on the components of 2D vectors drawn from the distribution N(r | 0, 0.1) * (1 + eps * cos(theta)) / (2 * pi) for the two values eps = 0.0 and eps = 0.2. The script shows how to calculate estimates for lower bounds on both the Kullback-Leibler divergence and the Jensen-Shannon divergence between the two distributions and compares these to numerically integrated values. At the most basic level, all calculate_divergence requires is a LightningDataModule, in this case an instance of BinaryPandasDataModule, to provide the data, and an instance of FDivergenceEstimator, in this case a GenericNaiveVariationalFDivergenceEstimator, to provide the machine learning model.
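In outline, the call looks like the sketch below; the import path and any arguments beyond the data module and estimator are assumptions rather than the package's documented signature:

    # Sketch only: import path and exact signature are assumptions, not the documented API.
    from iwpc import calculate_divergence

    # 'datamodule' and 'estimator' are constructed as sketched in the next two subsections.
    result = calculate_divergence(datamodule, estimator)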

BinaryPandasDataModule

Given two Pandas dataframes containing the samples from p and q, the BinaryPandasDataModule class provides a convenient wrapper that casts the data into the form expected by calculate_divergence. In addition to the two dataframes, BinaryPandasDataModule requires the user to specify which feature columns to use, the two cartesian components 'x' and 'y' in this case, as well as the name of a weight column if one exists. By default, all data modules in this package provide a 50-50 train-validation split.
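A minimal sketch of constructing such a data module for the 2D example might look as follows; the import path and the keyword names (features, weight_column) are illustrative assumptions rather than the exact signature:

    import numpy as np
    import pandas as pd
    # Sketch only: the import path and keyword names below are assumptions.
    from iwpc import BinaryPandasDataModule

    rng = np.random.default_rng(0)

    def toy_samples(eps, n=100_000):
        """Toy (x, y) samples with radius |N(0, 0.1)| and angle density (1 + eps*cos(theta)) / (2*pi)."""
        r = np.abs(rng.normal(0.0, 0.1, size=n))
        theta = rng.uniform(-np.pi, np.pi, size=n)
        keep = rng.uniform(0.0, 1.0 + abs(eps), size=n) < 1.0 + eps * np.cos(theta)  # accept-reject in theta
        r, theta = r[keep], theta[keep]
        return pd.DataFrame({"x": r * np.cos(theta), "y": r * np.sin(theta)})

    p_df = toy_samples(eps=0.0)  # samples from p (label 0)
    q_df = toy_samples(eps=0.2)  # samples from q (label 1)

    datamodule = BinaryPandasDataModule(
        p_df, q_df,
        features=["x", "y"],       # assumed keyword: which feature columns to use
        # weight_column="weight",  # assumed keyword: optional per-sample weight column
    )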

GenericNaiveVariationalFDivergenceEstimator

For generic data without any specific structure to inspire the topology of the machine learning model, the LightningModule subclass GenericNaiveVariationalFDivergenceEstimator provides the generic N-to-1-dimensional function approximator needed for the divergence calculation. The only information required is the number of training features and a DifferentiableFDivergence instance to tell the module which divergence to calculate.
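As a sketch (the import paths, the name of the concrete divergence class, and the keyword names are assumptions, not the exact API), constructing the estimator for the two-feature example above might look like:

    # Sketch only: import paths, divergence class name, and keyword names are assumptions.
    from iwpc import GenericNaiveVariationalFDivergenceEstimator
    from iwpc.divergences import KLDivergence  # some DifferentiableFDivergence subclass

    estimator = GenericNaiveVariationalFDivergenceEstimator(
        n_features=2,               # assumed keyword: the 'x' and 'y' columns
        divergence=KLDivergence(),  # assumed keyword: which f-divergence lower bound to optimise
    )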

Output

calculate_divergence trains the provided estimator on the provided LightningDataModule while logging, by default, to a lightning_logs directory placed in the same directory as the main script. A subdirectory containing the training results, logs, and model checkpoints is created inside lightning_logs each time calculate_divergence is run. The progress of the training may be monitored in your browser using tensorboard --logdir .../lightning_logs. calculate_divergence returns an instance of DivergenceResult, which contains the final divergence lower-bound estimate and its error, as well as the best version of the trained model and some other useful properties.
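Putting the pieces together, with the caveat that the attribute names on DivergenceResult below are assumptions based on the description above rather than the exact API:

    # Sketch only: import path and DivergenceResult attribute names are assumptions.
    from iwpc import calculate_divergence

    result = calculate_divergence(datamodule, estimator)

    print(result.value, result.error)  # final lower-bound estimate and its uncertainty (assumed names)
    best_model = result.model          # best version of the trained model (assumed name)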

The script renders these results as a function of the number of samples provided in two plots:

[Figure: KL-divergence-sample-size.png]

run_reweight_loop

The example script example_reweight_loop.py shows a more sophisticated use of the divergence framework. In practice, datasets are often too large to fit into memory at once, and the distributions are complicated enough that machine learning models tend to get caught in local minima during training. To alleviate these two problems, this script demonstrates the use of PandasDirDataModule, which splits datasets into manageable chunks that are dynamically loaded into memory, and the run_reweight_loop function, which iteratively reweights the data when training stagnates and restarts training afresh so that the network can focus on other features. The data in this example are drawn from the same distribution as in the previous example, so the reweight loop is certainly overkill for the given data complexity; however, the reweighting procedure has proved very useful in our own work with significantly more complicated data.
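A very rough sketch of the loop's shape follows; the actual signature of run_reweight_loop is not reproduced here and the keyword shown is purely illustrative, so consult example_reweight_loop.py for the real usage:

    # Sketch only: the signature of run_reweight_loop is an assumption; see example_reweight_loop.py.
    from iwpc import run_reweight_loop

    result = run_reweight_loop(
        datamodule,          # e.g. a PandasDirDataModule (see below) for larger-than-memory datasets
        estimator,
        max_iterations=5,    # assumed keyword: number of reweight-and-retrain rounds
    )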

PandasDirDataModule

The PandasDirDataModule class provides an extremely generic implementation of a dataset stored on disk as separate pickle files containing Pandas dataframes, which are automatically and efficiently loaded into memory as needed. See the PandasDirDataModule docstring for more information. This DataModule is recommended even when working with smaller datasets, as the data is stored in a convenient, portable form alongside the relevant metadata.
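For illustration only (the import path and constructor arguments are assumptions; the docstring is the authoritative reference), constructing such a data module might look like:

    # Sketch only: import path and constructor arguments are assumptions; see the docstring.
    from iwpc import PandasDirDataModule

    datamodule = PandasDirDataModule(
        "path/to/dataset_dir",  # directory of pickled Pandas dataframes plus metadata
        features=["x", "y"],    # assumed keyword: feature columns to train on
    )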

BinnedDfAccumulator

Once a model has been trained to find a difference between p and q, the natural question becomes: how exactly is it telling the two distributions apart? This is an extremely important question, as the conclusion that some difference exists is typically uninformative on its own; the network is free to pick up on any difference between p and q, including those we might deem uninteresting. The BinnedDfAccumulator class assists in this process by calculating how much of the obtained divergence can be explained by the marginal distribution in a given set of variables alone, as well as how the divergence changes as a function of these variables. Although BinnedDfAccumulator is written generically for any number of dimensions, the plotting features are currently only implemented in 1D and 2D. As a result, the examples given in example_reweight_loop.py are 1D and 2D.
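As a concrete illustration of the decomposition being quantified here (for the KL divergence specifically, computed by direct numerical integration rather than with the iwpc estimator): the total divergence splits into the divergence of the marginal in r plus the p-expectation of the divergence of the conditional theta distributions, and for the toy distributions above the marginal term vanishes.

    import numpy as np

    # Chain rule for the KL divergence (illustration, not the iwpc API):
    #   D_KL(p || q) = D_KL(p(r) || q(r)) + E_{r ~ p}[ D_KL(p(theta | r) || q(theta | r)) ]
    # p and q share the same r marginal, so the marginal term is 0 and the whole
    # divergence lives in the theta distribution at every value of r.
    eps = 0.2
    theta = np.linspace(-np.pi, np.pi, 1_000_000, endpoint=False)
    dtheta = theta[1] - theta[0]
    p_theta = np.full_like(theta, 1.0 / (2.0 * np.pi))     # uniform angle (eps = 0.0)
    q_theta = (1.0 + eps * np.cos(theta)) / (2.0 * np.pi)  # modulated angle (eps = 0.2)
    kl_theta = np.sum(p_theta * np.log(p_theta / q_theta)) * dtheta
    print(kl_theta)  # ~ eps**2 / 4 ≈ 0.01 nats, independent of r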

The first plot below shows how the divergence behaves as a function of the radius, r, of the samples. For 1D plots, the top left panel shows the (weighted) distribution of the variable in the validation dataset. Since the radii of the samples from p and q were drawn from the same Gaussian, these two histograms unsurprisingly show no sign of disagreement, and the estimated divergence in r alone is consistent with 0. The top right panel shows an estimate, derived from the same network trained in the reweight loop, of the divergence between the distributions of the samples that landed within each bin. In this case, since the only variable orthogonal to r is theta, this amounts to the divergence between the theta distributions at each fixed value of r. This plot suggests that the divergence is constant as a function of r, once again unsurprising given the form of p and q.

[Figure: divergence_vs_r.png]

The source of the divergence is obvious in the theta plots below. The estimate of the marginalised divergence (i.e. the divergence attributable to the theta distribution alone) is consistent with the global divergence, confirming this to be the only source of divergence between p and q. This is also clearly visible in the top right panel, as the divergence conditioned on theta is consistent with 0 in all cases. The bottom left and bottom right plots confirm that the network has in fact learnt the features in the data. The full explanation of how these plots are calculated is a little involved, but suffice it to say, they demonstrate what the network believes the distributions look like in the given variable. The error bars on the 'learned' quantities indicate the uncertainty on how well we are able to reconstruct what the network believes the distribution to be. They should not be interpreted as error bars indicating how far the truth may be from what the network has learnt, and these plots may well contain hallucinations.

[Figure: divergence_vs_theta.png]

The 2D plot in theta and r is mostly redundant in this case, but many of the features discussed in the 1D case are clearly visible. The top left panel shows the ratio of the two distributions in validation. The top right once again shows the divergence within each bin. The bottom left once again shows an estimate of what the network believes the ratio of the two distributions to be, and the bottom right is simply a histogram of the p distribution in validation.

[Figure: divergence_vs_r_theta.png]

Help and Suggestions

For any suggestions or questions please reach out to Jeremy J. H. Wilkinson.

If this tool has been helpful in your research, please consider citing https://arxiv.org/abs/2405.06397.
