Prediction-Powered Inference

Prediction-powered inference (PPI) is a framework for statistically rigorous scientific discovery using machine learning. Given a small amount of data with gold-standard labels and a large amount of unlabeled data, prediction-powered inference allows for the estimation of population parameters such as the mean outcome, the median outcome, and linear and logistic regression coefficients. Prediction-powered inference can be used to produce both better point estimates of these quantities and tighter confidence intervals and more powerful p-values. The methods work both in the i.i.d. setting and for certain classes of distribution shifts.

This package is actively maintained, and contributions from the community are welcome.

Getting Started

To install the package, run

pip install ppi-python

This will build and install the most recent version of the package.

Warmup: estimating the mean

To test your installation, you can try running the prediction-powered mean estimation algorithm on the galaxies dataset. The gold-standard labels and model predictions from the dataset will be downloaded into a folder called ./data/. The labels, $Y$, are binary indicators of whether or not the galaxy is a spiral galaxy. The model predictions, $\hat{Y}$, are the model's estimated probabilities that the galaxy images have spiral arms. The inference target is $\theta^* = \mathbb{E}[Y]$, the fraction of spiral galaxies. You will produce a confidence interval, $\mathcal{C}^{\mathrm{PP}}_\alpha$, which contains $\theta^*$ with probability $1-\alpha=0.9$, i.e.,

$$\mathbb{P}\left( \theta^* \in \mathcal{C}^{\mathrm{PP}}_\alpha\right) \geq 0.9.$$

The code for this is below. It can be copy-pasted directly into the Python REPL.

# Imports
import numpy as np
from ppi_py import ppi_mean_ci
from ppi_py.datasets import load_dataset
np.random.seed(0) # For reproducibility's sake
# Download and load dataset
data = load_dataset("./data/", "galaxies")
Y_total = data["Y"]; Yhat_total = data["Yhat"]
# Set up the inference problem
alpha = 0.1 # Error rate
n = 1000 # Number of labeled data points
rand_idx = np.random.permutation(Y_total.shape[0])
Y = Y_total[rand_idx[:n]]
Yhat = Yhat_total[rand_idx[:n]]
Yhat_unlabeled = Yhat_total[rand_idx[n:]] # Disjoint from the labeled set
# Produce the prediction-powered confidence interval
ppi_ci = ppi_mean_ci(Y, Yhat, Yhat_unlabeled, alpha=alpha)
# Print the results
print(f"theta={Y_total.mean():.3f}, CPP={ppi_ci}")

The expected output looks like this $^*$:

theta=0.259, CPP=(0.235677274705698, 0.26595223970754855)

($^*$ these results were produced with numpy=1.26.0; exact values may differ in other environments or with a different random split.)

If you have reached this stage, congratulations! You have constructed a prediction-powered confidence interval. See the documentation for more usages of prediction-powered inference.
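For intuition about what an interval of this kind computes, the classical (non-power-tuned) PPI mean estimate averages the predictions on the unlabeled data and corrects that average with the mean "rectifier" $Y - \hat{Y}$ from the labeled data. The sketch below is a from-scratch illustration on synthetic data, not the package's implementation (all names here are illustrative):

```python
import numpy as np
from statistics import NormalDist

def ppi_mean_ci_sketch(Y, Yhat, Yhat_unlabeled, alpha=0.1):
    """Classical PPI interval for the mean (no power tuning).

    Point estimate: mean of predictions on the unlabeled data, plus the
    "rectifier" mean(Y - Yhat) estimated on the labeled data.
    """
    n, N = len(Y), len(Yhat_unlabeled)
    rectifier = Y - Yhat
    theta_pp = Yhat_unlabeled.mean() + rectifier.mean()
    # Normal approximation: the variance combines both sources of randomness
    se = np.sqrt(Yhat_unlabeled.var(ddof=1) / N + rectifier.var(ddof=1) / n)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return theta_pp - z * se, theta_pp + z * se

# Synthetic demo: binary labels with a deliberately biased "model"
rng = np.random.default_rng(0)
theta_true, n, N = 0.3, 500, 10_000

def predict(y, rng):
    # Shrinks the labels and adds noise, so the predictions are miscalibrated
    return np.clip(0.7 * y + 0.15 + rng.normal(0, 0.1, len(y)), 0, 1)

Y = rng.binomial(1, theta_true, n).astype(float)
Yhat = predict(Y, rng)
Yhat_unlabeled = predict(rng.binomial(1, theta_true, N).astype(float), rng)

lo, hi = ppi_mean_ci_sketch(Y, Yhat, Yhat_unlabeled)
print(f"CI: ({lo:.3f}, {hi:.3f})")
```

Note how the rectifier makes the estimate unbiased even though the predictions themselves are not; the predictions only serve to shrink the interval.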

Examples

The package comes with a suite of examples on real data; the notebooks live in the ./examples/ folder of the repository.

Usage and Documentation

There is a common template that all PPI confidence intervals follow.

ppi_[ESTIMAND]_ci(X, Y, Yhat, X_unlabeled, Yhat_unlabeled, alpha=0.1)

You can replace [ESTIMAND] with the estimand of your choice. For certain estimands, not all arguments are required, in which case they are omitted from the signature. For example, for mean estimation the function signature is:

ppi_mean_ci(Y, Yhat, Yhat_unlabeled, alpha=0.1)

All the prediction-powered point estimates and confidence intervals implemented so far can be imported by running from ppi_py import ppi_[ESTIMAND]_pointestimate, ppi_[ESTIMAND]_ci. For the case of the mean, one can also import the p-value as from ppi_py import ppi_mean_pval.
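For intuition, a mean p-value of this kind can be sketched with the same normal approximation that underlies the interval: standardize the PPI point estimate under H0: θ* = θ0. This is a hedged, from-scratch illustration (the function name and internals are mine, not the package's):

```python
import numpy as np
from statistics import NormalDist

def ppi_mean_pval_sketch(Y, Yhat, Yhat_unlabeled, theta_0):
    """Two-sided p-value for H0: E[Y] = theta_0 (classical PPI, no power tuning)."""
    n, N = len(Y), len(Yhat_unlabeled)
    rectifier = Y - Yhat
    theta_pp = Yhat_unlabeled.mean() + rectifier.mean()
    se = np.sqrt(Yhat_unlabeled.var(ddof=1) / N + rectifier.var(ddof=1) / n)
    z = (theta_pp - theta_0) / se
    return 2 * (1 - NormalDist().cdf(abs(z)))

# Synthetic check: the true mean is 0.3
rng = np.random.default_rng(1)
Y = rng.binomial(1, 0.3, 500).astype(float)
Yhat = np.clip(Y + rng.normal(0.05, 0.2, 500), 0, 1)
Yhat_unlabeled = np.clip(rng.binomial(1, 0.3, 10_000) + rng.normal(0.05, 0.2, 10_000), 0, 1)

p_null = ppi_mean_pval_sketch(Y, Yhat, Yhat_unlabeled, theta_0=0.3)  # true null
p_far = ppi_mean_pval_sketch(Y, Yhat, Yhat_unlabeled, theta_0=0.5)   # false null
```

A false null far from the truth yields a vanishingly small p-value, while the true null typically does not.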

Full documentation is available here.

Repository structure

The repository is organized into three main folders:

  • ./ppi_py/
  • ./examples/
  • ./tests/

The first folder, ./ppi_py, contains all the code that gets built into the ppi_py package. Most importantly, the file ./ppi_py/ppi.py implements all the prediction-powered point estimates, confidence intervals, and p-values for the different estimands. The file ./ppi_py/baselines.py implements several baselines, and ./ppi_py/datasets/datasets.py handles the loading of the sample datasets.

The folder ./examples contains notebooks implementing prediction-powered inference on several datasets and estimands. There is also a subfolder, ./examples/baselines, which contains comparisons to certain baseline algorithms, as in the appendix of the original PPI paper.

The folder ./tests contains unit tests for each function implemented in the ppi_py package. The tests are organized by estimand and can be run by executing pytest in the root directory. Some of the tests are stochastic and therefore have some failure probability, even if the functions are all implemented correctly. If a test fails, it may be worth running it again. The tests can be debugged by adding the -s flag and using print statements or pdb. Note that in order to be recognized by pytest, all test function names must be prefixed with test_.
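As a hedged illustration of what such a stochastic test can look like (this is not an actual test from ./tests/, and it uses a from-scratch classical PPI mean interval rather than the package), one might check empirical coverage over repeated synthetic draws:

```python
import numpy as np
from statistics import NormalDist

def empirical_coverage(trials=200, theta=0.3, n=300, N=5_000, alpha=0.1, seed=0):
    """Fraction of repeated draws in which a classical PPI 90% interval
    for the mean covers the true value theta."""
    rng = np.random.default_rng(seed)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    covered = 0
    for _ in range(trials):
        Y = rng.binomial(1, theta, n).astype(float)
        Yhat = np.clip(Y + rng.normal(0.05, 0.2, n), 0, 1)  # noisy, biased model
        Yhat_unl = np.clip(rng.binomial(1, theta, N) + rng.normal(0.05, 0.2, N), 0, 1)
        rect = Y - Yhat
        est = Yhat_unl.mean() + rect.mean()
        se = np.sqrt(Yhat_unl.var(ddof=1) / N + rect.var(ddof=1) / n)
        covered += (est - z * se) <= theta <= (est + z * se)
    return covered / trials

def test_mean_ci_coverage():
    # Loose threshold: a correct 90% interval essentially never dips this low
    assert empirical_coverage() > 0.8

test_mean_ci_coverage()
```

The fixed seed makes the test deterministic in a given environment; tests without a fixed seed retain a small failure probability by construction, which is why re-running a failed stochastic test is reasonable.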

The remainder of the files/folders are boilerplate and not relevant to most users.

Contributing

Thank you so much for considering making a contribution to ppi_py; we deeply value and appreciate it.

The contents of this repository will be pushed to PyPI whenever there are substantial revisions. If there are methods or examples within the PPI framework you'd like to see implemented, feel free to suggest them on the issues page. Community contributions are welcome and encouraged as pull requests directly onto the main branch. The main criteria for accepting such pull requests are:

  • The contribution should align with the repository's scope.
  • All new functionality should be tested for correctness within our existing pytest framework.
  • If the pull request involves a new PPI method, it should have a formal mathematical proof of validity which can be referenced.
  • If the pull request solves a bug, there should be a reproducible bug (within a specific environment) that is solved. Bug reports can be made on the issues page.
  • The contribution should be well documented.
  • The pull request should be of generally high quality, as judged by the repository maintainers, who will approve pull requests at their discretion. Before working on one, it may be helpful to post a question on the issues page to check whether the contribution would be a good candidate for merging into the main branch.

Accepted pull requests will be run through an automated Black formatter, so contributors may want to run Black locally first.

