A library for time series conformal prediction

Improved Online Conformal Prediction via Strongly Adaptive Online Learning

This library implements several algorithms that perform conformal prediction on data with arbitrary distribution shifts over time. It is the official implementation of the paper by Bhatnagar et al., "Improved Online Conformal Prediction via Strongly Adaptive Online Learning," 2023. We include reference implementations of the proposed methods, Strongly Adaptive Online Conformal Prediction (SAOCP) and Scale-Free Online Gradient Descent (SF-OGD), as well as Split Conformal Prediction (Vovk et al., 1999), Non-Exchangeable Conformal Prediction (Barber et al., 2022), and Fully Adaptive Conformal Inference (FACI, Gibbs & Candes, 2022).
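
To give a rough feel for what an online method of this kind does, the following is a minimal, self-contained sketch of tracking a score threshold with a scale-free gradient step on the pinball (quantile) loss. It is an illustration under stated assumptions, not the package's implementation: the function name sf_ogd_thresholds and the parameters lr and init are hypothetical, and the update rule is a simplified reading of the scale-free gradient descent idea.

```python
import math


def sf_ogd_thresholds(scores, coverage=0.9, lr=1.0, init=0.0):
    """Track a running score threshold (illustrative sketch, hypothetical helper).

    Returns the threshold that was in effect *before* each score was observed.
    """
    alpha = 1.0 - coverage
    threshold = init
    grad_sq_sum = 0.0
    thresholds = []
    for s in scores:
        thresholds.append(threshold)
        # Subgradient of the pinball loss with respect to the threshold:
        # alpha if the score was covered (s <= threshold), alpha - 1 if it was missed.
        grad = alpha - (1.0 if s > threshold else 0.0)
        grad_sq_sum += grad ** 2
        # Scale-free step: normalise by the root of the accumulated squared gradients.
        threshold -= lr * grad / math.sqrt(grad_sq_sum)
    return thresholds
```

A miss pushes the threshold up by a comparatively large amount and a covered point pulls it down by a small one, so the long-run miss rate drifts towards 1 - coverage. The classes in this package carry additional state and logic beyond this sketch.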

Replicating Our Experiments

First install the online_conformal package by cloning this repo and running pip install . from the repository root. To run our time series forecasting experiments, also clone the Merlion repo and install its ts_datasets package. Then, you can call

python time_series.py --model <model> --dataset <dataset> --njobs <njobs>

where <model> is one of LGBM, ARIMA, or Prophet; <dataset> is one of M4_Hourly, M4_Daily, M4_Weekly, or NN5_Daily; and <njobs> is the number of cores to use in parallel. The results are written to the results sub-directory.

To run our experiments on image classification under distribution shift, first install PyTorch. Then, you can call

python vision.py --dataset <dataset>

where <dataset> is one of ImageNet or TinyImageNet. Intermediate results are written to sub-folders, and checkpointing (e.g. for model training) is automatic.

Using Our Code

To use our code, first install the online_conformal package by calling pip install online_conformal. Alternatively, install the package from source by cloning this repo and running pip install . from the repository root.

Each online conformal prediction method is implemented as its own class in the package, and all methods share a common API. For time series forecasting, we leverage models implemented in Merlion. Below, we demonstrate how to use SAOCP to create prediction intervals for multi-horizon time series forecasting. The update loop is a simplified version of calling saocp.forecast(time_series=test_data.iloc[:horizon], time_series_prev=train_data), which is implemented in the package source.

import pandas as pd
from merlion.models.factory import ModelFactory
from merlion.utils import TimeSeries
from online_conformal.dataset import M4
from online_conformal.saocp import SAOCP

# Get some time series data as pandas.DataFrames
data = M4("Hourly")[0]
train_data, test_data = data["train_data"], data["test_data"]
# Initialize a Merlion model for time series forecasting
model = ModelFactory.create(name="LGBMForecaster")
# Initialize the SAOCP wrapper on top of the model. This splits the data 
# into train/calibration splits, trains the model on the train split, 
# and initializes SAOCP's internal state on the calibration split.
# The target coverage is 90% here, but you can adjust this freely.
# We also do 24-step-ahead forecasting by setting horizon=24.
horizon = 24
saocp = SAOCP(model=model, train_data=train_data, coverage=0.9,
              calib_frac=0.2, horizon=horizon)

# Get the model's 24-step-ahead prediction, and convert it to prediction intervals
yhat, _ = saocp.model.forecast(horizon, time_series_prev=TimeSeries.from_pd(train_data))
delta_lb, delta_ub = zip(*[saocp.predict(horizon=h + 1) for h in range(horizon)])
yhat = yhat.to_pd().iloc[:, 0]
lb, ub = yhat + delta_lb, yhat + delta_ub

# Update SAOCP's internal state based on the next 24 observations
prev = train_data.iloc[:-horizon + 1]
time_series = pd.concat((train_data.iloc[-horizon + 1:], test_data.iloc[:horizon]))
for i in range(len(time_series)):
    # Predict yhat_{t-H+i+1}, ..., yhat_{t-H+i+H} = f(y_1, ..., y_{t-H+i}) 
    y = time_series.iloc[i:i + horizon, 0]
    yhat, _ = saocp.model.forecast(y.index, time_series_prev=TimeSeries.from_pd(prev))
    yhat = yhat.to_pd().iloc[:, 0]
    # Use the (h+1)-step-ahead forecast (position h of yhat) to update SAOCP's (h+1)-step prediction interval
    for h in range(len(y)):
        if i >= h:
            saocp.update(ground_truth=y.iloc[h:h + 1], forecast=yhat.iloc[h:h + 1], horizon=h + 1)
    prev = pd.concat((prev, time_series.iloc[i:i+1]))
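
Once you have lower and upper bounds such as lb and ub above, a quick sanity check is to compare them against the observed values. The helper below is a hypothetical sketch (interval_report is not part of online_conformal) that computes the empirical coverage and mean interval width for any three equal-length sequences.

```python
def interval_report(y_true, lower, upper):
    """Return (empirical coverage, mean width) of prediction intervals.

    Hypothetical helper for illustration; accepts lists, tuples, or pandas Series.
    """
    y_true, lower, upper = list(y_true), list(lower), list(upper)
    if not y_true or not (len(y_true) == len(lower) == len(upper)):
        raise ValueError("inputs must be non-empty and of equal length")
    # A point is covered when it lies inside the closed interval [lower, upper].
    hits = [lo <= y <= hi for y, lo, hi in zip(y_true, lower, upper)]
    coverage = sum(hits) / len(hits)
    mean_width = sum(hi - lo for lo, hi in zip(lower, upper)) / len(lower)
    return coverage, mean_width
```

For the example above, a call might look like interval_report(test_data.iloc[:horizon, 0], lb, ub), assuming lb and ub line up with the first horizon rows of test_data. With only 24 points the empirical coverage is noisy, so it should be read as a rough check rather than a measurement of the 90% target.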

For other use cases, you can initialize saocp = SAOCP(model=None, train_data=None, max_scale=max_scale, coverage=0.9). Here, max_scale indicates the largest value you expect the conformal score to take. Then, you can obtain the conformal score corresponding to 90% (or your desired level of coverage) by calling score = saocp.predict(horizon=1)[1], and you can use this value to compute the prediction set {y: S(X_t, y) < score} using your own custom code. Finally, after you observe the true conformal score new_score = S(X_t, Y_t), you can update the conformal predictor by calling saocp.update(ground_truth=pd.Series([new_score]), forecast=pd.Series([0]), horizon=1).
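
As an illustration of the "custom code" step, here is a minimal sketch for a classification setting. It assumes the conformal score is S(x, y) = 1 - p(y | x), where p is a class probability from your own model. This choice of score and the helper names conformal_score and prediction_set are assumptions made for the example, not part of the package.

```python
def conformal_score(probs, label):
    """Score S(x, y) = 1 - p(y | x) for a dict mapping labels to probabilities."""
    return 1.0 - probs[label]


def prediction_set(probs, score_threshold):
    """Return the labels y with S(x, y) < score_threshold, i.e. {y: S(X_t, y) < score}."""
    return [label for label in probs if conformal_score(probs, label) < score_threshold]


# Example with made-up probabilities and a made-up threshold.
probs = {"cat": 0.7, "dog": 0.25, "fox": 0.05}
labels = prediction_set(probs, score_threshold=0.8)
```

In practice, score_threshold would be the value returned by saocp.predict(horizon=1)[1], and once the true label is known, conformal_score(probs, true_label) would be the new_score passed to saocp.update as described above.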
