
Unified Training of Universal Time Series Forecasting Transformers

Paper | Blog Post

Uni2TS is a PyTorch-based library for research and applications related to Time Series Transformers. This library aims to provide a unified solution for large-scale pre-training of Universal Time Series Transformers. Uni2TS also provides tools for fine-tuning, inference, and evaluation for time series forecasting.

🎉 What's New

  • May 2024: The Uni2TS paper has been accepted to ICML 2024 as an Oral presentation!

  • Mar 2024: Release of Uni2TS library, along with Moirai-1.0-R and LOTSA data!

✅ TODO

  • Improve docstrings and documentation

⚙️ Installation

  1. Clone the repository:
git clone https://github.com/SalesforceAIResearch/uni2ts.git
cd uni2ts
  2. Create a virtual environment:
virtualenv venv
. venv/bin/activate
  3. Build from source:
pip install -e '.[notebook]'
  4. Create a .env file (a sketch of typical contents follows this list):
touch .env
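
The .env file holds environment variables used by later steps, such as dataset paths; the echo commands in the sections below append them one at a time. As a rough sketch (the paths are placeholders to replace with your own directories), a populated .env might look like:

# Placeholder paths; replace with your own directories
CUSTOM_DATA_PATH=/path/to/processed/datasets   # used for custom fine-tuning data
LSF_PATH=/path/to/TSLib/dataset                # used for Long Sequence Forecasting evaluation
LOTSA_V1_PATH=/path/to/lotsa_data              # used for pre-training on LOTSA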

🏃 Getting Started

Let's see a simple example of how to use Uni2TS to make zero-shot forecasts with a pre-trained model. We first load our data using pandas, in the form of a wide DataFrame. Uni2TS relies on GluonTS for inference, as it provides many convenience functions for time series forecasting, such as splitting a dataset into a train/test split and performing rolling evaluations, as demonstrated below.

import torch
import matplotlib.pyplot as plt
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from huggingface_hub import hf_hub_download

from uni2ts.eval_util.plot import plot_single
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule


SIZE = "small"  # model size: choose from {'small', 'base', 'large'}
PDT = 20  # prediction length: any positive integer
CTX = 200  # context length: any positive integer
PSZ = "auto"  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
BSZ = 32  # batch size: any positive integer
TEST = 100  # test set length: any positive integer

# Read data into pandas DataFrame
url = (
    "https://gist.githubusercontent.com/rsnirwan/c8c8654a98350fadd229b00167174ec4"
    "/raw/a42101c7786d4bc7695228a0f2c8cea41340e18f/ts_wide.csv"
)
df = pd.read_csv(url, index_col=0, parse_dates=True)

# Convert into GluonTS dataset
ds = PandasDataset(dict(df))

# Split into train/test set
train, test_template = split(
    ds, offset=-TEST
)  # assign last TEST time steps as test set

# Construct rolling window evaluation
test_data = test_template.generate_instances(
    prediction_length=PDT,  # number of time steps for each prediction
    windows=TEST // PDT,  # number of windows in rolling window evaluation
    distance=PDT,  # number of time steps between each window - distance=PDT for non-overlapping windows
)

# Prepare pre-trained model by downloading model weights from huggingface hub
model = MoiraiForecast(
    module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{SIZE}"),
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
)

predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)

input_it = iter(test_data.input)
label_it = iter(test_data.label)
forecast_it = iter(forecasts)

inp = next(input_it)
label = next(label_it)
forecast = next(forecast_it)

plot_single(
    inp, 
    label, 
    forecast, 
    context_length=CTX,
    name="pred",
    show_label=True,
)
plt.show()
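
Beyond plotting, you may want a point forecast or prediction intervals. The predictor yields sample-based forecasts; assuming the standard GluonTS SampleForecast interface (num_samples trajectories of length PDT), a minimal sketch continuing the example above:

import numpy as np

# Reduce the sampled trajectories to a median and an 80% prediction interval.
median = forecast.quantile(0.5)   # point forecast, shape (PDT,)
lower = forecast.quantile(0.1)
upper = forecast.quantile(0.9)
print("median forecast:", median)
print("mean 80% interval width:", np.mean(upper - lower))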

📔 Jupyter Notebook Examples

See the example folder for more examples of common tasks, e.g. visualizing forecasts, predicting from a pandas DataFrame, etc.

💻 Command Line Interface

We provide several scripts which act as a command line interface to easily run fine-tuning, evaluation, and even pre-training jobs. Configurations are managed with the Hydra framework.

Fine-tuning

First, let's see how to use Uni2TS to fine-tune a pre-trained model on your custom dataset. Uni2TS uses the Hugging Face datasets library to handle data loading, so we first need to convert your dataset into the Uni2TS format. If your dataset is a simple pandas DataFrame, it can easily be processed with the following script. We'll use the ETTh1 dataset from the popular Long Sequence Forecasting benchmark for this example. For more complex use cases, see this notebook for more in-depth examples of how to use your custom dataset with Uni2TS.

  1. To begin, add the path of the directory where you want to save the processed dataset to the .env file.
echo "CUSTOM_DATA_PATH=PATH_TO_SAVE" >> .env
  2. Run the following script to process the dataset into the required format. For the dataset_type option, we support wide, long, and wide_multivariate (a sketch of the expected wide layout follows this list).
python -m uni2ts.data.builder.simple ETTh1 dataset/ETT-small/ETTh1.csv --dataset_type wide

However, we may want a validation set during fine-tuning for hyperparameter tuning or early stopping. To additionally split the dataset into train and validation splits, we can use the mutually exclusive date_offset (datetime string) or offset (integer) options, which determine the last time step of the train set. The validation set will be saved as DATASET_NAME_eval.

python -m uni2ts.data.builder.simple ETTh1 dataset/ETT-small/ETTh1.csv --date_offset '2017-10-23 23:00:00'
  3. Finally, run the fine-tuning script with the appropriate training and validation data configuration files.
python -m cli.train \
  -cp conf/finetune \
  run_name=example_run \
  model=moirai_1.0_R_small \
  data=etth1 \
  val_data=etth1
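
For reference, the wide format expected by the simple builder mirrors the DataFrame used in Getting Started: a datetime index with one column per variate. A minimal sketch of writing such a CSV (the column names and file path below are purely illustrative):

import numpy as np
import pandas as pd

# Wide layout: one row per timestamp, one column per variate.
index = pd.date_range("2016-07-01", periods=1000, freq="h")
df = pd.DataFrame(
    np.random.randn(1000, 3),
    index=index,
    columns=["var_a", "var_b", "var_c"],  # illustrative variate names
)
df.to_csv("dataset/my_dataset.csv")
# Then process it with, e.g.:
#   python -m uni2ts.data.builder.simple my_dataset dataset/my_dataset.csv --dataset_type wide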

Evaluation

The evaluation script can be used to calculate evaluation metrics such as MSE, MASE, CRPS, and so on (see the configuration file).
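
For reference, these are the standard definitions (not Uni2TS-specific notation) for a forecast over a horizon of H steps, where y is the target, \hat{y} the point forecast, T the length of the training series, m the seasonal period, and F the predictive distribution with independent draws X, X':

\mathrm{MSE} = \frac{1}{H}\sum_{t=1}^{H} (y_t - \hat{y}_t)^2
\mathrm{MASE} = \frac{\frac{1}{H}\sum_{t=1}^{H} |y_t - \hat{y}_t|}{\frac{1}{T-m}\sum_{t=m+1}^{T} |y_t - y_{t-m}|}
\mathrm{CRPS}(F, y) = \mathbb{E}_F|X - y| - \tfrac{1}{2}\,\mathbb{E}_F|X - X'|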

Given a test split (see previous section on processing datasets), we can run the following command to evaluate it:

python -m cli.eval \
  run_name=example_eval_1 \
  model=moirai_1.0_R_small \
  model.patch_size=32 \
  model.context_length=1000 \
  data=etth1_test

Alternatively, we provide access to popular datasets, which can be selected via the data configurations. As an example, say we want to perform evaluation again on the ETTh1 dataset from the popular Long Sequence Forecasting benchmark. We first need to download the pre-processed datasets and put them in the correct directory by setting up the TSLib repository and following its instructions. Then, assign the dataset directory to the LSF_PATH environment variable:

echo "LSF_PATH=PATH_TO_TSLIB/dataset" >> .env

Thereafter, simply run the following script with the predefined Hydra config file:

python -m cli.eval \
  run_name=example_eval_2 \
  model=moirai_1.0_R_small \
  model.patch_size=32 \
  model.context_length=1000 \
  data=lsf_test \
  data.dataset_name=ETTh1 \
  data.prediction_length=96

Pre-training

Now, let's see how you can pre-train your own model. We'll start by preparing the data for pre-training: download the Large-scale Open Time Series Archive (LOTSA data). Assuming you've already created a .env file, run the following commands.

huggingface-cli download Salesforce/lotsa_data --repo-type=dataset --local-dir PATH_TO_SAVE
echo "LOTSA_V1_PATH=PATH_TO_SAVE" >> .env

Then, we can simply run the following script to start a pre-training job. See the relevant configuration files for how to further customize the settings.

python -m cli.train \
  -cp conf/pretrain \
  run_name=first_run \
  model=moirai_small \
  data=lotsa_v1_unweighted

👀 Citing Uni2TS

If you're using Uni2TS in your research or applications, please cite it using this BibTeX:

@article{woo2024unified,
  title={Unified Training of Universal Time Series Forecasting Transformers},
  author={Woo, Gerald and Liu, Chenghao and Kumar, Akshat and Xiong, Caiming and Savarese, Silvio and Sahoo, Doyen},
  journal={arXiv preprint arXiv:2402.02592},
  year={2024}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

uni2ts-1.1.0.tar.gz (74.5 kB)

Built Distribution

uni2ts-1.1.0-py3-none-any.whl (13.2 kB)

File details

Details for the file uni2ts-1.1.0.tar.gz.

File metadata

  • Download URL: uni2ts-1.1.0.tar.gz
  • Upload date:
  • Size: 74.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.6

File hashes

Hashes for uni2ts-1.1.0.tar.gz
  • SHA256: ca702eda617c31a1ced6a80aa976994eeaf2e148594a586479c672141342aa47
  • MD5: b2c7284f6a777c5ee95a32e1f9af7ea4
  • BLAKE2b-256: 773e9dd2e7fa947926bde8a8f7952931570eca40f0699372e1a9485f006dee13


File details

Details for the file uni2ts-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: uni2ts-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 13.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.6

File hashes

Hashes for uni2ts-1.1.0-py3-none-any.whl
  • SHA256: 0f945378c829abeab30e9fb3a0555c267e84661780e62f48a001d1270e85040e
  • MD5: 06db69b1c4ee5a18965fc683d7485292
  • BLAKE2b-256: 8a8c7addaa4331f0b74318b54dcff2b3cc89d093bc460af60f160c2eda3dce9c

