Temporal fusion transformer for timeseries forecasting

Project description

PyTorch Forecasting aims to ease timeseries forecasting with neural networks. Specifically, it provides a class to wrap timeseries datasets and a number of PyTorch models.
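
As a rough illustration of the kind of table such a dataset class wraps, the sketch below builds a small long-format pandas DataFrame with one row per series and time step, then applies the same date-cutoff filter used in the Usage section. The column names (`series`, `date`, `time_idx`, `value`) and all values are hypothetical and only meant to show the layout; they are not required by the library.

```python
import pandas as pd

# Hypothetical toy data: two series ("a" and "b"), one row per series and month.
dates = pd.date_range("2020-01-01", periods=6, freq="MS")
data = pd.DataFrame(
    {
        "series": ["a"] * 6 + ["b"] * 6,          # would be listed in group_ids
        "date": list(dates) * 2,                  # calendar date, used for the cutoff
        "time_idx": list(range(6)) * 2,           # integer time step within each series
        "value": [float(v) for v in range(12)],   # would be the target
    }
)

# Same filtering idiom as in the Usage section: keep rows up to the cutoff date.
training_cutoff = "2020-04-01"
train_part = data[lambda x: x.date <= training_cutoff]

# Number of training rows per series
print(train_part.groupby("series").size())
```

With a layout like this, `time_idx="time_idx"`, `target="value"` and `group_ids=["series"]` would be the natural arguments to fill in for the placeholders in the Usage section.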

Installation

Install with

pip install pytorch-forecasting

Visit the documentation at https://pytorch-forecasting.readthedocs.io.

Available models

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# load data
data = ...

# define dataset
max_encode_length = 36
max_prediction_length = 6
training_cutoff = "YYYY-MM-DD"  # day for cutoff

training = TimeSeriesDataSet(
    data[lambda x: x.date <= training_cutoff],
    time_idx= ...,
    target= ...,
    group_ids=[ ... ],
    max_encode_length=max_encode_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[ ... ],
    static_reals=[ ... ],
    time_varying_known_categoricals=[ ... ],
    time_varying_known_reals=[ ... ],
    time_varying_unknown_categoricals=[ ... ],
    time_varying_unknown_reals=[ ... ],
)


# create the validation dataset with the same settings as the training dataset
validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True
)

# convert datasets to dataloaders
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)


# define the trainer with early stopping and learning rate logging
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateLogger()
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,  # train on CPU
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,  # use only 30 batches per training epoch
    callbacks=[lr_logger],
)


# create the model from the dataset definition
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,  # number of quantiles to predict
    loss=QuantileLoss(),
    log_interval=2,
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate
res = trainer.lr_find(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

# fit the model
trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
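
The model above is trained with `loss=QuantileLoss()` and `output_size=7`, which suggests it predicts several quantiles per time step rather than a single value. As an illustration of the idea only, and not the library's actual implementation, the standard quantile (pinball) loss can be sketched in plain Python. The function names `pinball_loss` and `mean_pinball_loss` and the example numbers are made up for this sketch.

```python
def pinball_loss(y_true: float, y_pred: float, quantile: float) -> float:
    """Quantile (pinball) loss for a single observation and a single quantile."""
    error = y_true - y_pred
    # Under-prediction (error > 0) is weighted by q, over-prediction by (1 - q).
    return max(quantile * error, (quantile - 1.0) * error)


def mean_pinball_loss(y_true, y_preds, quantiles):
    """Average pinball loss over several quantile predictions for one target value."""
    losses = [pinball_loss(y_true, p, q) for p, q in zip(y_preds, quantiles)]
    return sum(losses) / len(losses)


# Hypothetical example: true value 10, predictions for the 10%, 50% and 90% quantiles.
quantiles = [0.1, 0.5, 0.9]
predictions = [8.0, 10.0, 13.0]
loss = mean_pinball_loss(10.0, predictions, quantiles)
print(f"mean pinball loss: {loss:.3f}")
```

A high quantile penalises under-prediction more heavily than over-prediction, which pushes the corresponding output upward; a low quantile does the opposite. Together the outputs form a prediction interval around the median.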
