IWPC - I Will Prove Convergence
This package implements the methods described in the research paper https://arxiv.org/abs/2405.06397 for estimating a lower bound on the divergence between any two distributions, p and q, using samples from each distribution.
Install using `pip install iwpc`.
The machine learning code in this package is organised using the fantastic PyTorch Lightning package. Some familiarity with the structure of Lightning is recommended.
Basic Usage
The most basic usage of this package is calculating an estimate of a lower bound on an f-divergence (such as the KL-divergence) between two distributions, p and q. Each example below assumes one is provided with a set of samples drawn from distribution p and a set drawn from distribution q, labelled 0 and 1 respectively.
calculate_divergence
The example script examples/continuous_example_2D.py shows the most basic usage of the calculate_divergence function, run on the components of 2D vectors drawn from the distribution N(r | 0, 0.1) * (1 + eps * cos(theta)) / 2 / pi for the two values eps = 0.0 and eps = 0.2. The script shows how to calculate estimates for lower bounds on both the Kullback-Leibler divergence and the Jensen-Shannon divergence between the two distributions and compares these to numerically integrated values. At the most basic level, all calculate_divergence requires is a LightningDataModule, in this case an instance of BinaryPandasDataModule, to provide the data, and an instance of FDivergenceEstimator, in this case an instance of GenericNaiveVariationalFDivergenceEstimator, to provide the machine learning model.
BinaryPandasDataModule
Given two Pandas data frames containing the samples from p and q, the BinaryPandasDataModule class provides a convenient wrapper that casts the data into the form expected by calculate_divergence. In addition to the two dataframes, BinaryPandasDataModule requires the user to specify which feature columns to use, the two Cartesian components 'x' and 'y' in this case, as well as the name of a weight column if one exists. By default, all data modules in this package provide a 50-50 train-validation split.
GenericNaiveVariationalFDivergenceEstimator
For generic data without any specific structure to inspire the topology of the machine learning model, the LightningModule subclass GenericNaiveVariationalFDivergenceEstimator provides a generic N-to-1-dimensional function approximator needed for the divergence calculation. The only information required is the number of training features and a DifferentiableFDivergence instance to tell the module which divergence to calculate.
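As a rough illustration of how these pieces fit together, the sketch below mirrors the structure of continuous_example_2D.py. The import paths, the constructor signatures, the KLDivergence class name, and the attributes read from the returned DivergenceResult are assumptions rather than the package's documented API; consult the example script for the authoritative version.

```python
# Hedged sketch of the basic workflow. Import paths, signatures and the
# KLDivergence class name are assumptions; see examples/continuous_example_2D.py
# for the real usage.
import numpy as np
import pandas as pd

from iwpc import calculate_divergence                                    # assumed import path
from iwpc.data import BinaryPandasDataModule                             # assumed import path
from iwpc.estimators import GenericNaiveVariationalFDivergenceEstimator  # assumed import path
from iwpc.divergences import KLDivergence                                # assumed DifferentiableFDivergence

rng = np.random.default_rng(0)


def sample(eps: float, n: int = 100_000) -> pd.DataFrame:
    """Toy 2D samples with radius |N(0, 0.1)| and angular density (1 + eps*cos(theta)) / 2 / pi."""
    theta = rng.uniform(0.0, 2.0 * np.pi, 4 * n)
    # Accept-reject sampling of the angular density
    theta = theta[rng.uniform(0.0, 1.0 + abs(eps), theta.size) < 1.0 + eps * np.cos(theta)][:n]
    r = np.abs(rng.normal(0.0, 0.1, theta.size))
    return pd.DataFrame({"x": r * np.cos(theta), "y": r * np.sin(theta)})


p_df = sample(eps=0.0)  # samples from p, labelled 0
q_df = sample(eps=0.2)  # samples from q, labelled 1

# Data module wrapping the two dataframes and the feature columns to use
datamodule = BinaryPandasDataModule(p_df, q_df, ["x", "y"])

# Generic N -> 1 function approximator for the chosen f-divergence
estimator = GenericNaiveVariationalFDivergenceEstimator(2, KLDivergence())

result = calculate_divergence(datamodule, estimator)
# The attribute names below are assumptions about DivergenceResult
print(result.value, result.error)
```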
Output
calculate_divergence trains the provided model on the provided LightningDataModule while logging, by default, to a lightning_logs directory placed in the same directory as the main script. A subdirectory is created inside the lightning_logs directory each time the calculate_divergence function is run, containing the training results, logs, and model checkpoints for that run. The progress of the training may be monitored in your browser using tensorboard --logdir .../lightning_logs. calculate_divergence returns an instance of DivergenceResult, which contains the final divergence lower-bound estimate and its error, as well as the best version of the trained model and some other useful properties.
The script renders these results as a function of the number of samples provided in two plots.
run_reweight_loop
The example script example_reweight_loop.py shows a more sophisticated use of the divergence framework. Typically, datasets are too large to fit into memory at once and complicated enough that machine learning models tend to get caught in local minima during training. To alleviate these two problems, this script demonstrates the use of PandasDirDataModule, which splits datasets into manageable chunks that are dynamically loaded into memory, and the run_reweight_loop function, which iteratively reweights the data when training stagnates and restarts training afresh so that the network can focus on other features. The data in this example is drawn from the same distribution as in the previous example, so the reweight loop is certainly overkill for the given data complexity; however, the reweighting procedure has been very useful in our own work with significantly more complicated data.
PandasDirDataModule
The PandasDirDataModule class provides an extremely generic implementation of a dataset which is stored on disk in separate pickle files containing Pandas dataframes that are automatically and efficiently loaded into memory as needed. See the PandasDirDataModule docstring for more information. This DataModule is recommended, even when working with smaller datasets, as the data is stored in a convenient portable form with relevant metadata.
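The sketch below is a hedged outline of how these two pieces might be combined. The import paths, constructor arguments, divergence class name, and the run_reweight_loop call signature are all assumptions, so defer to example_reweight_loop.py and the PandasDirDataModule docstring for the actual interface.

```python
# Hedged sketch of the reweight-loop workflow. Import paths and signatures are
# assumptions; see example_reweight_loop.py for the real usage.
from iwpc import run_reweight_loop                                        # assumed import path
from iwpc.data import PandasDirDataModule                                 # assumed import path
from iwpc.estimators import GenericNaiveVariationalFDivergenceEstimator   # assumed import path
from iwpc.divergences import JensenShannonDivergence                      # assumed DifferentiableFDivergence

# A directory of pickled pandas DataFrames (one file per chunk), prepared
# beforehand; chunks are loaded into memory only as needed during training.
datamodule = PandasDirDataModule("samples_dir", ["x", "y"])  # constructor arguments are assumptions

estimator = GenericNaiveVariationalFDivergenceEstimator(2, JensenShannonDivergence())

# Train, reweight the data whenever training stagnates, and restart afresh so
# the network can focus on features it has not yet exploited.
results = run_reweight_loop(datamodule, estimator)           # call signature is an assumption
```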
BinnedDfAccumulator
Once a model has been trained to find a difference between p and q, the natural question becomes: how exactly is the machine learning model telling the two distributions apart? This is an extremely important question, as the conclusion that some difference exists is typically uninformative on its own; the network is free to pick up on any differences between p and q, including those we might deem uninteresting. The BinnedDfAccumulator class assists in this process by calculating the extent to which the obtained divergence can be explained by the marginal distribution in a given set of variables alone, as well as how the degree of divergence changes as a function of those variables. Although BinnedDfAccumulator is written generically for any number of dimensions, the plotting features are currently only implemented in 1D and 2D. As a result, the examples given in example_reweight_loop.py are 1D and 2D.
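As an indication of where BinnedDfAccumulator sits in the workflow, the fragment below sketches binning the validation samples in the radius r. Every import path, argument, and method name here is an assumption about the interface; the class docstring and example_reweight_loop.py should be treated as the reference.

```python
# Purely illustrative sketch; argument and method names are assumptions about
# the BinnedDfAccumulator interface, not its documented API.
import numpy as np

from iwpc.accumulators import BinnedDfAccumulator  # assumed import path

# Bin the validation samples in the radius r = sqrt(x^2 + y^2) and accumulate,
# per bin, the marginal divergence in r and the divergence of the distributions
# orthogonal to r, using the network trained by the reweight loop above.
r_bins = np.linspace(0.0, 0.5, 21)
accumulator = BinnedDfAccumulator(lambda df: np.hypot(df["x"], df["y"]), [r_bins])
# Hypothetical calls: accumulate over the validation data with the trained
# model, then draw the 1D summary plots discussed below.
# accumulator.accumulate(datamodule.val_dataloader(), results.model)
# accumulator.plot_1d()
```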
The first plot below shows how the divergence behaves as a function of the radius, r, of the samples. For 1D plots, the top left panel shows the (weighted) distribution of the variable in the validation dataset. Since the radius of the samples from p and q was drawn from the same Gaussian, these two histograms unsurprisingly show no sign of disagreement, and the estimated divergence in r alone is consistent with 0. The top right panel shows an estimate, derived from the same network trained in the reweight loop, of the divergence of the distributions of the samples that landed within each bin. In this case, since the only variable orthogonal to r is theta, this amounts to the divergence of the theta distributions at each fixed value of r. This plot suggests that the divergence is constant as a function of r, once again unsurprising given the form of p and q.
The source of the divergence is obvious in the theta plots below. The estimate of the marginalised divergence (i.e. the divergence attributable to the theta distribution alone) is consistent with the global divergence, confirming this to be the only source of divergence between p and q. This is also clearly visible in the top right, as the divergence conditioned on theta is consistent with 0 in all cases. The bottom left and bottom right panels confirm that the network has in fact learnt the features in the data. The full explanation of how these plots are calculated is a little involved, but suffice it to say, they demonstrate what the network believes the distributions look like in the given variable. The error bars on the 'learned' quantities indicate the uncertainty on how well we are able to reconstruct what the network believes the distribution to be. They should not be interpreted as error bars indicating how far the truth may be from what the network has learnt, and these plots may well demonstrate hallucinations.
The 2D plot in theta and r is mostly redundant in this case, but many of the features discussed in the 1D case are clearly visible. Top left shows the ratio of the two distributions in validation. Top right once again shows the divergence within each bin. Bottom left once again shows an estimate of what the network believes the ratio of the two distributions to be, and bottom right is simply a histogram of the p distribution in validation.
Help and Suggestions
For any suggestions or questions please reach out to Jeremy J. H. Wilkinson.
If this tool has been helpful in your research, please consider citing https://arxiv.org/abs/2405.06397.