
Temporal Fusion Transformer for timeseries forecasting

Project description

PyTorch Forecasting aims to ease timeseries forecasting with neural networks. It specifically provides a class to wrap timeseries datasets and a number of PyTorch models.

Installation

If you are working on Windows, you need to first install PyTorch with

pip install torch -f https://download.pytorch.org/whl/torch_stable.html

Otherwise, you can proceed with

pip install pytorch-forecasting
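
To verify the installation you can, for example, import the package and print its version (assuming the usual __version__ attribute is exposed):

import pytorch_forecasting

print(pytorch_forecasting.__version__)  # expected: "0.2.2"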

Visit the documentation at https://pytorch-forecasting.readthedocs.io.

Available models
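
  • Temporal Fusion Transformer (TemporalFusionTransformer), demonstrated in the usage example below

See the documentation for the complete list of available models.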

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# load data
data = ...
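# `data` is expected to be a pandas DataFrame in long format, i.e. one row per
# time step per series, with columns for a time index, the target, group ids,
# and any covariates; the placeholders below are to be filled accordingly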

# define dataset
max_encoder_length = 36
max_prediction_length = 6
training_cutoff = "YYYY-MM-DD"  # last day included in the training data

training = TimeSeriesDataSet(
    data[lambda x: x.date <= training_cutoff],
    time_idx= ...,
    target= ...,
    group_ids=[ ... ],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[ ... ],
    static_reals=[ ... ],
    time_varying_known_categoricals=[ ... ],
    time_varying_known_reals=[ ... ],
    time_varying_unknown_categoricals=[ ... ],
    time_varying_unknown_reals=[ ... ],
)


validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True
)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)
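# the dataloaders yield (x, y) batches assembled by TimeSeriesDataSet, where x
# bundles the encoder and decoder inputs and y holds the prediction target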


early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateLogger()
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,
    callbacks=[lr_logger],
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,
    loss=QuantileLoss(),
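    # an output_size of 7 matches the 7 quantiles QuantileLoss predicts by default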
    log_interval=2,
    reduce_on_plateau_patience=4
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate
res = trainer.lr_find(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
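
After training, you can generate forecasts from the fitted network. A minimal sketch, assuming PyTorch Lightning's default checkpointing and the predict() method that pytorch-forecasting models provide:

# load the best model according to validation loss and forecast on the
# validation set; with QuantileLoss, predict() returns the median quantile
# (a point forecast) by default
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
predictions = best_tft.predict(val_dataloader)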

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_forecasting-0.2.2.tar.gz (48.6 kB)


Built Distribution

pytorch_forecasting-0.2.2-py3-none-any.whl (52.4 kB)


File details

Details for the file pytorch_forecasting-0.2.2.tar.gz.

File metadata

  • Download URL: pytorch_forecasting-0.2.2.tar.gz
  • Size: 48.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.4 CPython/2.7.16 Darwin/19.6.0

File hashes

Hashes for pytorch_forecasting-0.2.2.tar.gz
Algorithm Hash digest
SHA256 c1865372e500451b0ce59f4464e9d62309fe783073af98f35102375b2d00bd9e
MD5 dfc4ee0a9d7d0877e06fa181ec8b140b
BLAKE2b-256 9559f175eab544ecebcc1aa0a8b4ff4f54290618b26eed9066021bf3667461f0


File details

Details for the file pytorch_forecasting-0.2.2-py3-none-any.whl.

File hashes

Hashes for pytorch_forecasting-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 267076b3bad3e9a924b938d9e89896bd28fadd8b54d85b0ef4d895d88f381348
MD5 3fdaee91f86bd38db28ab0ad60f76704
BLAKE2b-256 9127209cea7b25a8c3f6dfb198b978dbcc943b5281dc9371f32951f7502f0cb2

