Temporal Fusion Transformer for time series forecasting

Project description

Time series forecasting with PyTorch.

Install with

pip install pytorch-forecasting

Available models

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer

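# load data: a pandas DataFrame in long format (one row per series and time step) with a `date` column
data = ...

# define dataset
max_encode_length = 36  # maximum number of past time steps fed to the encoder
max_prediction_length = 6  # number of future time steps to forecast
training_cutoff = "YYYY-MM-DD"  # rows dated on or after this day are excluded from training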

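training = TimeSeriesDataSet(
    data[lambda x: x.date < training_cutoff],  # keep only rows before the cutoff
    time_idx=...,  # integer column giving the time step of each row
    target=...,  # column to forecast
    # weight="weight",
    group_ids=[...],  # columns that together identify one time series
    max_encode_length=max_encode_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[...],  # categorical variables that are constant within each series
    static_reals=[...],  # continuous variables that are constant within each series
    time_varying_known_categoricals=[...],  # categorical variables known in advance, e.g. holidays
    time_varying_known_reals=[...],  # continuous variables known in advance, e.g. planned prices
    time_varying_unknown_categoricals=[...],  # categorical variables only observed up to the present
    time_varying_unknown_reals=[...],  # continuous variables only observed up to the present, e.g. the target
)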
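
# validation set: same variable setup as `training`, with predictions starting at `min_prediction_idx`
validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training.index.time.max() + 1
)

# train=True shuffles the training samples; validation batches keep a fixed order
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)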
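
# stop training once the validation loss stops improving (older PyTorch Lightning Trainer API)
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
)
trainer = pl.Trainer(
    max_epochs=10,
    gpus=0,  # train on CPU
    gradient_clip_val=0.1,  # clip gradients to stabilize training
    early_stop_callback=early_stop_callback,
)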
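
# build the network from the variable definitions stored in the training dataset
tft = TemporalFusionTransformer.from_dataset(training)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# fit the network, evaluating on the validation set after each epoch
trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)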
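The `static_*` and `time_varying_*` arguments sort the columns by what is known when: static variables are meant to be constant within each series identified by `group_ids`, time-varying known variables are available for future steps too (e.g. calendar features), and time-varying unknown variables are only observed up to the present. The standard-library sketch below checks the first of these assumptions on rows given as dictionaries; `non_constant_static_columns` is a hypothetical helper, not part of pytorch-forecasting, and the example rows are made up.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Set, Tuple


def non_constant_static_columns(
    rows: Sequence[Dict[str, Hashable]],
    group_ids: Sequence[str],
    static_columns: Sequence[str],
) -> List[str]:
    """Hypothetical check: list static columns whose value changes within a series."""
    # for each static column, collect the distinct values seen per series
    seen: Dict[str, Dict[Tuple, Set]] = {c: defaultdict(set) for c in static_columns}
    for row in rows:
        key = tuple(row[g] for g in group_ids)  # identifies one time series
        for c in static_columns:
            seen[c][key].add(row[c])
    # a column is not static if any series shows more than one value
    return [c for c in static_columns if any(len(v) > 1 for v in seen[c].values())]


rows = [
    {"store": "A", "region": "north", "size": 10, "time_idx": 0},
    {"store": "A", "region": "north", "size": 12, "time_idx": 1},
    {"store": "B", "region": "south", "size": 20, "time_idx": 0},
]
print(non_constant_static_columns(rows, ["store"], ["region", "size"]))  # ['size']
```

A column flagged here, like `size` in the example, would belong among the time-varying variables instead.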
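Conceptually, `TimeSeriesDataSet` turns each series into samples made of an encoder window (up to `max_encode_length` past steps) followed by a decoder window (the `max_prediction_length` steps to forecast), and `min_prediction_idx` limits which samples are used for validation. The sketch below illustrates that idea for a single series with the standard library only. `forecast_windows` is a hypothetical helper; it assumes fixed-length windows and a contiguous time index, and it is not how the library builds its sample index internally.

```python
from typing import List, Tuple

Window = Tuple[range, range]  # (encoder time steps, decoder time steps)


def forecast_windows(
    time_idx: List[int],
    max_encode_length: int,
    max_prediction_length: int,
    min_prediction_idx: int = 0,
) -> List[Window]:
    """Hypothetical helper: enumerate fixed-length encoder/decoder windows.

    Assumes `time_idx` is sorted and contiguous (no missing time steps).
    Only windows whose first predicted step is >= `min_prediction_idx` are kept,
    so the encoder may still look back into earlier (e.g. training) data.
    """
    windows: List[Window] = []
    first, last = time_idx[0], time_idx[-1]
    total = max_encode_length + max_prediction_length
    # the last window must end at or before `last`
    for start in range(first, last - total + 2):
        split = start + max_encode_length  # first predicted time step
        if split < min_prediction_idx:
            continue
        windows.append((range(start, split), range(split, split + max_prediction_length)))
    return windows


# a series observed at time steps 0..9, suppose validation forecasts start at step 8
print(len(forecast_windows(list(range(10)), 3, 2)))  # 6
val = forecast_windows(list(range(10)), 3, 2, min_prediction_idx=8)
print([(list(e), list(d)) for e, d in val])  # [([5, 6, 7], [8, 9])]
```

The validation window still draws its encoder input from steps 5 to 7, which lie before the prediction start.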
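The `EarlyStopping` callback above, with `monitor="val_loss"`, `mode="min"`, `min_delta=1e-4` and `patience=1`, ends training once the validation loss fails to improve by more than `min_delta`. Below is a simplified standard-library sketch of that stopping rule, not PyTorch Lightning's implementation; `stopping_epoch` is a hypothetical name and the loss values are invented for illustration.

```python
from typing import Optional, Sequence


def stopping_epoch(
    val_losses: Sequence[float], min_delta: float = 1e-4, patience: int = 1
) -> Optional[int]:
    """Return the 0-based epoch after which training stops, or None if it never stops.

    Simplified rule for mode="min": an epoch counts as an improvement only if its
    loss is lower than the best loss so far by more than `min_delta`.
    """
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


print(stopping_epoch([0.50, 0.40, 0.39995, 0.30]))  # 2: improvement of 5e-5 is below min_delta
print(stopping_epoch([0.50, 0.40, 0.30]))  # None
```

With `patience=1`, a single epoch without sufficient improvement is enough to stop, so the later 0.30 in the first example is never reached.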
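`gradient_clip_val=0.1` limits the size of each optimizer step. Assuming norm-based clipping (the PyTorch Lightning default, similar in spirit to `torch.nn.utils.clip_grad_norm_`), all gradients are rescaled by one common factor whenever their joint L2 norm exceeds the threshold. The sketch below shows that rule on plain lists of floats; `clip_by_total_norm` is a hypothetical helper, and unlike PyTorch it omits the small epsilon added to the norm.

```python
import math
from typing import List


def clip_by_total_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Scale all gradients by one factor so their joint L2 norm is at most `max_norm`."""
    # L2 norm over every gradient entry of every parameter
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm <= max_norm:
        return [list(t) for t in grads]  # already small enough, return unchanged copies
    scale = max_norm / total_norm
    return [[g * scale for g in t] for t in grads]


# joint norm of 0.3 and 0.4 is 0.5, so everything is scaled by 0.1 / 0.5 = 0.2
print(clip_by_total_norm([[0.3], [0.4]], 0.1))  # approximately [[0.06], [0.08]]
```

Because every gradient shares the same scale factor, the update direction is preserved and only its length shrinks.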

Download files

Download the file for your platform.

Source Distribution

pytorch_forecasting-0.1.1.tar.gz (34.4 kB)

Uploaded Source

Built Distribution

pytorch_forecasting-0.1.1-py3-none-any.whl (38.3 kB)

Uploaded Python 3

File details

Details for the file pytorch_forecasting-0.1.1.tar.gz.

File metadata

  • Download URL: pytorch_forecasting-0.1.1.tar.gz
  • Upload date:
  • Size: 34.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.9 CPython/3.7.7 Darwin/19.5.0

File hashes

Hashes for pytorch_forecasting-0.1.1.tar.gz
Algorithm    Hash digest
SHA256       b9ad53efc6a6d787c3e08b98b94a98de9a56a87039184a13f0f6282b4253c037
MD5          ff09b38a23f55a847a5631341a09ad2b
BLAKE2b-256  aab08c8049f361ef040bb6c5fd246f518a9eca395458d12c6926db97972aa9c4

File details

Details for the file pytorch_forecasting-0.1.1-py3-none-any.whl.

File hashes

Hashes for pytorch_forecasting-0.1.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       e112312b1305c49efbdc459169a21fd7a1a1f05fc4039150abef51629bb4c127
MD5          89494c1ac110ec2f9cc67d5abb3a8b49
BLAKE2b-256  b2831ffd8f653122b72fc819c0f01d6370b20300fe6d4679adae3eac41efd4b4
