
Temporal fusion transformer for timeseries forecasting

Project description

Timeseries forecasting with PyTorch

Install with

pip install pytorch_forecasting
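
To install exactly the release documented on this page, the version can be pinned (0.1.0 is the version of the distribution files listed below):

pip install pytorch_forecasting==0.1.0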

Available models

  • TemporalFusionTransformer

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer

# load data: a pandas DataFrame with a time index, target, group ids and covariates
data = ...

# define dataset
max_encode_length = 36
max_prediction_length = 6
training_cutoff = "YYYY-MM-DD"  # day up to which data is used for training

training = TimeSeriesDataSet(
    data[lambda x: x.date < training_cutoff],
    time_idx=...,
    target=...,
    # weight="weight",
    group_ids=[...],
    max_encode_length=max_encode_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[...],
    static_reals=[],
    time_varying_known_categoricals=[...],
    time_varying_known_reals=[
        "time_idx",
        "price_regular",
        "price_actual",
        "discount",
        "avg_population_2017",
        "avg_yearly_household_income_2017",
        "discount_in_percent",
    ],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
    constant_fill_strategy={"volume": 0},
    dropout_categoricals=["sku"],
)


# create a validation set that predicts time steps after the last training time index
validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training.data_index.time.max() + 1)

# convert the datasets into dataloaders for training and validation
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)


# stop training when the validation loss has stopped improving
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
trainer = pl.Trainer(
    max_epochs=10,
    gpus=0,
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
)


# initialize the Temporal Fusion Transformer directly from the dataset definition
tft = TemporalFusionTransformer.from_dataset(training)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
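
After training, the weights with the best validation loss can be restored from the checkpoint that PyTorch Lightning writes by default. The sketch below is only an illustration on top of the example above: it assumes Lightning's default checkpoint callback is active and that the dataloader yields (x, y) batches where x is the dictionary of model inputs; the exact batch and output structure may differ between versions.

# restore the best model seen during training (assumes Lightning's default
# checkpointing is enabled)
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# run a forward pass on a single validation batch; the (x, y) batch layout
# is an assumption and may differ between library versions
x, y = next(iter(val_dataloader))
out = best_tft(x)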

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_forecasting-0.1.0.tar.gz (23.7 kB)

Uploaded Source

Built Distribution

pytorch_forecasting-0.1.0-py3-none-any.whl (25.1 kB)

Uploaded Python 3

File details

Details for the file pytorch_forecasting-0.1.0.tar.gz.

File metadata

  • Download URL: pytorch_forecasting-0.1.0.tar.gz
  • Upload date:
  • Size: 23.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.9 CPython/3.7.7 Darwin/19.4.0

File hashes

Hashes for pytorch_forecasting-0.1.0.tar.gz

  • SHA256: d571531efe647f21423b7c2c4117efc61df58d27d225683f8d60a60618eca6bb
  • MD5: e5940d159dd89eeb6f908f4bc53f0393
  • BLAKE2b-256: 47a2df2a22fe995f64494ee4bc157fb1f9c254adb1d90182bfbc50cdc8fb08ea

See more details on using hashes here.
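
To confirm that a downloaded archive matches the SHA256 digest listed above, the hash can be recomputed locally. The snippet below is just an illustration; the file path is an assumption and should point to wherever the archive was downloaded.

import hashlib

# path is an assumption: adjust it to the local location of the archive
path = "pytorch_forecasting-0.1.0.tar.gz"
expected = "d571531efe647f21423b7c2c4117efc61df58d27d225683f8d60a60618eca6bb"

with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("hash matches:", digest == expected)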

File details

Details for the file pytorch_forecasting-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for pytorch_forecasting-0.1.0-py3-none-any.whl

  • SHA256: 512f167c75a7a903d22d39303dda73f8e74fd72513d6774e9e579913954ba8a0
  • MD5: e6b6e83a4e089df49eb3716605c01b91
  • BLAKE2b-256: a2abc0049cfb4edb09cd251560caee1cb78331959dc4d646f421a1eff6ac8d24

See more details on using hashes here.
