
Time series forecasting with Gaussian Processes


Related Publication

The theoretical description of the algorithm implemented in this software and empirical results can be found in:

"Time series forecasting with Gaussian Processes needs priors"
ECML PKDD 2021: Proc. Machine Learning and Knowledge Discovery in Databases. Applied Data Science Track pp 103–117
Giorgio Corani, Alessio Benavoli, Marco Zaffalon
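For background, in standard GP regression the hyperparameters θ (kernel parameters plus the noise variance σ_n²) are usually fitted by maximizing the log marginal likelihood of the n training observations y at inputs X. Placing a prior p(θ) on them turns this into maximum a posteriori (MAP) estimation. This is the generic textbook formulation, not a statement of the specific kernel or priors used in the paper:

```latex
\hat{\theta}_{\mathrm{MAP}} = \arg\max_{\theta} \Big[ \log p(\mathbf{y} \mid X, \theta) + \log p(\theta) \Big],
\qquad
\log p(\mathbf{y} \mid X, \theta) = -\tfrac{1}{2}\,\mathbf{y}^{\top} \big(K_{\theta} + \sigma_n^{2} I\big)^{-1} \mathbf{y}
  - \tfrac{1}{2} \log \big\lvert K_{\theta} + \sigma_n^{2} I \big\rvert - \tfrac{n}{2} \log 2\pi
```

Here K_θ is the n × n kernel matrix of the training inputs; dropping the log p(θ) term recovers ordinary type-II maximum likelihood.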

gpforecast package

The software includes a small package that builds the Gaussian process and uses it to produce predictions; it relies heavily on GPy. A convenience script can be used to run the GP over collections of time series.
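To illustrate in general terms what building a GP and producing predictions involves, here is a minimal, dependency-free sketch of GP regression with a single squared-exponential (RBF) kernel and fixed hyperparameters. It is not gpforecast's implementation: the package delegates to GPy, and neither the paper's kernel nor its hyperparameter priors are reproduced here. All function names are hypothetical.

```python
import math


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential (RBF) covariance between two scalar inputs."""
    return variance * math.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)


def cholesky(a):
    """Lower-triangular Cholesky factor L (a = L L^T) of an SPD matrix given as lists."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def forward_sub(L, b):
    """Solve L x = b for lower-triangular L."""
    x = []
    for i in range(len(b)):
        x.append((b[i] - sum(L[i][k] * x[k] for k in range(i))) / L[i][i])
    return x


def backward_sub(L, b):
    """Solve L^T x = b for lower-triangular L."""
    n = len(b)
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (b[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


def gp_predict(x_train, y_train, x_test, noise=1e-2, **kern):
    """Posterior mean and variance of the latent function of a zero-mean GP."""
    n = len(x_train)
    # K + noise * I over the training inputs
    K = [[rbf_kernel(x_train[i], x_train[j], **kern) + (noise if i == j else 0.0)
          for j in range(n)] for i in range(n)]
    L = cholesky(K)
    alpha = backward_sub(L, forward_sub(L, y_train))  # alpha = (K + noise I)^{-1} y
    means, variances = [], []
    for xs in x_test:
        k_star = [rbf_kernel(xt, xs, **kern) for xt in x_train]
        means.append(sum(k * a for k, a in zip(k_star, alpha)))
        v = forward_sub(L, k_star)  # v = L^{-1} k_*, so v.v = k_*^T (K + noise I)^{-1} k_*
        variances.append(rbf_kernel(xs, xs, **kern) - sum(vi * vi for vi in v))
    return means, variances
```

For instance, `gp_predict(list(range(12)), series, [12, 13, 14])` would return means and variances three steps ahead. Far from the data an RBF-only GP reverts to its zero prior mean, which is why a single RBF kernel on its own extrapolates poorly.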

Installation

gpforecast requires Python 3.8 or later. Create a new conda environment named ENVNAME with Python 3.8:

conda create --name ENVNAME python=3.8

Activate the new environment:

conda activate ENVNAME

Install gpforecast and its dependencies:

pip install gpforecast
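To confirm that the active interpreter meets the Python 3.8+ requirement and that the install succeeded, a small check can use the standard-library `importlib.metadata` module (available since Python 3.8). The helper name `check_environment` is hypothetical.

```python
import sys
from importlib import metadata


def check_environment(package="gpforecast", min_python=(3, 8)):
    """Return (python_ok, installed_version_or_None) for the current interpreter."""
    python_ok = sys.version_info[:2] >= min_python
    try:
        version = metadata.version(package)  # version string of the installed distribution
    except metadata.PackageNotFoundError:
        version = None  # package not installed in this environment
    return python_ok, version


if __name__ == "__main__":
    ok, version = check_environment()
    print(f"Python >= 3.8: {ok}; gpforecast version: {version or 'not installed'}")
```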

Tutorial [Python]

from gpforecast import GP, get_sample_ts

# STEP 0. Get a sample time series (Monthly or Quarterly time series from M3)
f = 'quarterly'  # 'monthly' or 'quarterly'
sample = get_sample_ts(frequency=f)

# STEP 1. Construct our model
g = GP(frequency=f, priors=True)

# STEP 2. Build our model
g.build(sample['Y'])

# STEP 3. Forecast
res = g.forecast(len(sample['YY']), level=[80,95])

# STEP 4. Evaluate forecast accuracy (MAE, CRPS, LL); you can also implement your own metrics
mean = res.PointForecast
u95  = res.Hi95
acc = g.compute_indicators(sample['YY'], mean, u95, level=95)

# STEP 5. Plot the forecasts
import matplotlib.pyplot as plt

def forecast_plot(train, res):
    """Plot the training series, the point forecast and the prediction bands."""
    train = train.reshape(-1)                          # flatten to 1-D
    n, h = len(train), len(res.PointForecast)
    x_train = list(range(1, n + 1))                    # time index of the observed data
    x_fc = list(range(n + 1, n + h + 1))               # time index of the forecast horizon
    plt.style.use('ggplot')
    plt.figure(figsize=(16, 8))
    plt.plot(x_train, train, color='black')
    plt.plot(x_fc, res.PointForecast, color='blue')
    # Interval levels are encoded in column names such as 'Lo80', 'Hi95'
    levels = sorted(int(c[2:]) for c in res.columns if c.startswith('Lo'))
    for i, l in enumerate(levels):
        # the narrowest (lowest-level) band gets the highest alpha
        plt.fill_between(x_fc, res['Lo' + str(l)], res['Hi' + str(l)],
                         color='blue', alpha=0.3 - 0.1 * i)
    plt.title('Forecasts from GP', loc='left', fontsize=16)
    plt.xlabel("Time", fontsize=14)
    plt.gca().tick_params(axis='both', which='major', labelsize=12)
    plt.show()

forecast_plot(sample['Y'], res)
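STEP 4 invites you to implement your own metrics. Below is a minimal plain-Python sketch, assuming the predictive distribution at each horizon is Gaussian, so its standard deviation can be recovered from the point forecast and the upper bound of a symmetric 95% interval (Hi95). This is not necessarily how `compute_indicators` computes its values, and all function names are hypothetical.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal N(0, 1)


def sigma_from_upper(mean, upper, level=95):
    """Recover a Gaussian std from the upper bound of a central level% interval."""
    z = STD_NORMAL.inv_cdf(0.5 + level / 200)  # about 1.96 for level=95
    return (upper - mean) / z


def gaussian_crps(y, mu, sigma):
    """Closed-form CRPS of N(mu, sigma^2) at observation y (lower is better)."""
    z = (y - mu) / sigma
    return sigma * (z * (2 * STD_NORMAL.cdf(z) - 1) + 2 * STD_NORMAL.pdf(z)
                    - 1 / math.sqrt(math.pi))


def gaussian_loglik(y, mu, sigma):
    """Log density of N(mu, sigma^2) at y (higher is better)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (y - mu) ** 2 / (2 * sigma ** 2)


def indicators(actual, mean, upper, level=95):
    """Average MAE, CRPS and log-likelihood over the forecast horizon."""
    sigmas = [sigma_from_upper(m, u, level) for m, u in zip(mean, upper)]
    n = len(actual)
    return {
        "MAE": sum(abs(y - m) for y, m in zip(actual, mean)) / n,
        "CRPS": sum(gaussian_crps(y, m, s) for y, m, s in zip(actual, mean, sigmas)) / n,
        "LL": sum(gaussian_loglik(y, m, s) for y, m, s in zip(actual, mean, sigmas)) / n,
    }
```

Assuming `sample['YY']` is a NumPy array (as the `reshape` call in the plotting code suggests for `sample['Y']`), a call might look like `indicators(sample['YY'].reshape(-1).tolist(), res.PointForecast.tolist(), res.Hi95.tolist())`.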

Tutorial [R]

gpforecast can be used from R through the reticulate package:

install.packages("reticulate")
library(reticulate)
use_condaenv("ENVNAME")             # conda environment with gpforecast installed
gpforecast <- import("gpforecast")  # load the Python module

# STEP 0. Get a sample time series (monthly or quarterly time series from M3)
f <- 'monthly'  # 'monthly' or 'quarterly'
sample <- gpforecast$get_sample_ts(frequency = f)

# STEP 1. Construct our model
g <- gpforecast$GP(frequency=f, priors=TRUE)

# STEP 2. Build our model
g$build(sample$Y)

# STEP 3. Forecast
res <- g$forecast(length(sample$YY), level=as.integer(c(80,95)))

# STEP 4. Evaluate forecast accuracy (MAE, CRPS, LL); you can also implement your own metrics
mean <- as.array(as.vector(t(res['PointForecast'])))
u95  <- as.array(as.vector(t(res['Hi95'])))
acc <- g$compute_indicators(sample$YY, mean, u95, level=as.integer(95))

# STEP 5. Plot the forecasts
library(ggplot2)

forecast_plot <- function(train, res) {
    options(repr.plot.width = 16, repr.plot.height = 8)  # figure size in Jupyter notebooks
    train_df <- data.frame(
        X = seq_along(train),
        Y = as.vector(t(train)))
    # time index of the forecast horizon, right after the training data
    res$h <- seq_len(nrow(res)) + length(train)
    p <- ggplot() +
        geom_line(data = train_df, aes(X, Y), color = "black", size = 0.8) +
        geom_line(data = res, aes(h, PointForecast), color = "blue", size = 0.8)
    # interval levels are encoded in column names such as 'Lo80', 'Hi95'
    levels <- sort(as.numeric(sub('Lo', '', grep('^Lo', colnames(res), value = TRUE))))
    for (i in seq_along(levels)) {
        lo <- paste0('Lo', levels[i])
        hi <- paste0('Hi', levels[i])
        # !! injects the current column names so each ribbon keeps its own level
        p <- p +
            geom_ribbon(data = res,
                        aes(x = h, ymin = .data[[!!lo]], ymax = .data[[!!hi]]),
                        alpha = 0.4 - 0.1 * i, fill = "blue")
    }
    p <- p + ggtitle('Forecasts from GP') + xlab('Time') +
        theme(
            axis.title.y = element_blank(),
            plot.title = element_text(size = 20),
            axis.title = element_text(size = 18),
            axis.text = element_text(size = 14))
    return(p)
}

forecast_plot(sample$Y, res)
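Both plotting functions above find the interval levels by parsing column names such as Lo80 and Hi95. As a sketch of how such columns relate to a Gaussian predictive distribution (an assumption: the source does not state how gpforecast computes its bands), a central level% interval is mean ± z·sd with z = Φ⁻¹(0.5 + level/200). The helper names are hypothetical.

```python
from statistics import NormalDist


def interval_columns(mean, sigma, levels=(80, 95)):
    """Build PointForecast / Lo<level> / Hi<level> columns for Gaussian forecasts."""
    cols = {"PointForecast": list(mean)}
    for level in levels:
        # standard-normal quantile: about 1.28 for 80%, 1.96 for 95%
        z = NormalDist().inv_cdf(0.5 + level / 200)
        cols[f"Lo{level}"] = [m - z * s for m, s in zip(mean, sigma)]
        cols[f"Hi{level}"] = [m + z * s for m, s in zip(mean, sigma)]
    return cols


def levels_from_columns(columns):
    """Recover the sorted interval levels from names such as 'Lo80' and 'Hi95'."""
    return sorted(int(c[2:]) for c in columns if c.startswith("Lo"))
```

Passing `interval_columns(...)` to `pandas.DataFrame` would give a table with the same column layout that the plotting functions expect.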

IDSIA - Istituto Dalle Molle di Studi sull'Intelligenza Artificiale

Lugano, Switzerland idsia.ch



Download files


Source Distribution

gpforecast-1.0.2.tar.gz (6.8 kB)

Built Distribution

gpforecast-1.0.2-py3-none-any.whl (7.0 kB)

File details

Details for the file gpforecast-1.0.2.tar.gz.

File metadata

  • Download URL: gpforecast-1.0.2.tar.gz
  • Size: 6.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for gpforecast-1.0.2.tar.gz
Algorithm Hash digest
SHA256 528aaf03473dd85c083c545b362f7569cfa98a1b97cd0fe0ec4e16451801dbf2
MD5 32792dfa3be57e4abf97d3bb6bf7785d
BLAKE2b-256 c53ca14f9d47e37cf389eab92a577bd66a210501b48a7fe0414306885e930a81


File details

Details for the file gpforecast-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: gpforecast-1.0.2-py3-none-any.whl
  • Size: 7.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for gpforecast-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 b761697037b051c2091f84f3936cbda75c00a6715fbefb91dcae08e68b1bf4df
MD5 d1d6253fee2ba1699aec70e304cf485f
BLAKE2b-256 3dade18a06d9ef9b859c045b378445f0344d1ee9f6e7eb830e0734934f402cfd
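A downloaded file can be checked against the published SHA256 digests with Python's standard `hashlib` module. This is a minimal sketch; the helper names `sha256_of` and `verify` are hypothetical.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in 64 KiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_sha256):
    """Return True if the file's SHA256 digest matches the published one."""
    return sha256_of(path) == expected_sha256.lower()
```

For example, after downloading the source distribution into the working directory: `verify("gpforecast-1.0.2.tar.gz", "528aaf03473dd85c083c545b362f7569cfa98a1b97cd0fe0ec4e16451801dbf2")`. pip can enforce the same check at install time in hash-checking mode, using a requirements file entry with `--hash=sha256:...`.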

