Probabilistic Time Series Modeling Using the Trans-MAF Model
transmaf
transmaf is a PyTorch-based library for probabilistic time series forecasting with the Trans-MAF model. It uses GluonTS as the back-end API for loading, transforming, and backtesting datasets.
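Judging from the Quick start (`TransformerTempFlowEstimator` with `flow_type='MAF'`), Trans-MAF models the joint distribution of all series with a masked autoregressive flow (MAF). The sketch below shows, in plain Python, how one MAF affine layer scores a vector: each dimension's shift `mu_i` and log-scale `alpha_i` depend only on earlier dimensions, so the Jacobian is triangular and the log-determinant is just `-sum(alpha_i)`. The linear conditioner (`w_mu`, `w_alpha`, `b_mu`, `b_alpha`) is a hypothetical stand-in; in the library the parameters come from a masked network conditioned on the Transformer state.

```python
import math


def standard_normal_logpdf(u):
    """Log-density of a standard normal at u."""
    return -0.5 * (u * u + math.log(2.0 * math.pi))


def maf_log_prob(x, w_mu, b_mu, w_alpha, b_alpha):
    """Log-density of x under one masked autoregressive affine layer.

    Hypothetical linear conditioner: mu_i and alpha_i are linear functions
    of x[0..i-1] only, mimicking the autoregressive mask of a MAF layer.
    """
    log_prob = 0.0
    for i in range(len(x)):
        # Autoregressive mask: only earlier dimensions condition dimension i.
        mu = b_mu[i] + sum(w_mu[i][j] * x[j] for j in range(i))
        alpha = b_alpha[i] + sum(w_alpha[i][j] * x[j] for j in range(i))
        # Map x_i back to the base (noise) variable u_i.
        u = (x[i] - mu) * math.exp(-alpha)
        # Triangular Jacobian: log|det du/dx| is the sum of -alpha_i.
        log_prob += standard_normal_logpdf(u) - alpha
    return log_prob


# Usage: a 2-D example where the mean of x[1] is 0.5 * x[0].
zeros = [[0.0, 0.0], [0.0, 0.0]]
w_mu = [[0.0, 0.0], [0.5, 0.0]]
print(maf_log_prob([1.0, 0.5], w_mu, [0.0, 0.0], zeros, [0.0, 0.0]))
```

Stacking several such layers (with the dimension order permuted between them) gives a full MAF; the sketch covers a single layer only.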
Installation
$ pip install transmaf
Quick start
Imports
import numpy as np
import torch
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.evaluation import MultivariateEvaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from transmaf import Trainer
from transmaf.model.transformer_tempflow import TransformerTempFlowEstimator
Read data
electricity = get_dataset("electricity_nips", regenerate=False)
# create train/test groupers that combine the univariate series into one multivariate series
electricity_train_grouper = MultivariateGrouper(
    max_target_dim=min(2000, int(electricity.metadata.feat_static_cat[0].cardinality))
)
electricity_test_grouper = MultivariateGrouper(
    # the test split holds one copy of every series per rolling evaluation date
    num_test_dates=int(len(electricity.test) / len(electricity.train)),
    max_target_dim=min(2000, int(electricity.metadata.feat_static_cat[0].cardinality))
)
# create train/test datasets
electricity_dataset_train = list(electricity_train_grouper(electricity.train))
# repeat the grouped training entry 100 times (as in the original example)
electricity_dataset_train *= 100
electricity_dataset_test = electricity_test_grouper(electricity.test)
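The groupers turn many univariate entries into a single entry whose target has shape (target_dim, time), which is what a multivariate model consumes. The sketch below illustrates that idea with a hypothetical `group_series` helper; it is not GluonTS's implementation, which also aligns and pads series with different start dates. It keeps the first `max_target_dim` entries, an assumption made here for simplicity. `rolling_test_dates` mirrors the `num_test_dates` arithmetic above.

```python
def group_series(entries, max_target_dim):
    """Stack aligned univariate entries into one multivariate entry.

    Hypothetical, simplified stand-in for MultivariateGrouper: keeps the
    first max_target_dim entries and assumes they share a start and length.
    """
    selected = entries[:max_target_dim]
    start, length = selected[0]["start"], len(selected[0]["target"])
    for e in selected:
        if e["start"] != start or len(e["target"]) != length:
            raise ValueError("this sketch only handles aligned series")
    # Row k is univariate series k, giving a (target_dim, time) target.
    return {"start": start, "target": [list(e["target"]) for e in selected]}


def rolling_test_dates(n_test_entries, n_train_entries):
    """Rolling evaluation windows, as in int(len(test) / len(train))."""
    return n_test_entries // n_train_entries


# Usage with made-up data: three aligned series, capped at two dimensions.
train = [
    {"start": "2014-01-01 00:00", "target": [1.0, 2.0, 3.0]},
    {"start": "2014-01-01 00:00", "target": [4.0, 5.0, 6.0]},
    {"start": "2014-01-01 00:00", "target": [7.0, 8.0, 9.0]},
]
grouped = group_series(train, max_target_dim=2)
print(grouped["target"])  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(rolling_test_dates(6, 3))  # 2
```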
Train estimator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
estimator = TransformerTempFlowEstimator(
    input_size=744,  # model input width used in the original example
    target_dim=int(electricity.metadata.feat_static_cat[0].cardinality),
    prediction_length=electricity.metadata.prediction_length,
    context_length=electricity.metadata.prediction_length * 4,  # condition on 4x the forecast horizon
    flow_type='MAF',
    dequantize=True,
    freq=electricity.metadata.freq,
    trainer=Trainer(
        device=device,  # use the GPU when available instead of always the CPU
        epochs=14,
        learning_rate=1e-3,
        num_batches_per_epoch=100,
        batch_size=64,
    )
)
predictor = estimator.train(
    electricity_dataset_train,
    num_workers=4
)
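The estimator is built with `dequantize=True`. In flow-based time series models of this family, dequantization usually means adding uniform noise to the training targets so the flow does not fit a continuous density to repeated discrete values. The sketch below assumes that scheme (noise drawn from U[0, 1)); the exact behaviour in transmaf is an assumption here and should be checked against its source.

```python
import math
import random


def dequantize(values, rng):
    """Add U[0, 1) noise to each value (assumed dequantization scheme)."""
    return [v + rng.random() for v in values]


rng = random.Random(0)
counts = [3.0, 3.0, 5.0]
noisy = dequantize(counts, rng)
# For these values each noisy value stays in [v, v + 1), so flooring recovers the count.
print([math.floor(n) for n in noisy])  # [3, 3, 5]
```

The two identical counts become distinct continuous values, which is the point of the trick; the noise is only applied during training.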
Prediction
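# init evaluator; the quantile levels 0.05, 0.10, ..., 0.95 are used to approximate CRPS
evaluator = MultivariateEvaluator(
    quantiles=(np.arange(20)/20.0)[1:],
    target_agg_funcs={'sum': np.sum}  # also evaluate the sum across all series
)
# prediction: draw 20 sample paths for each test entry
forecast_it, ts_it = make_evaluation_predictions(
    dataset=electricity_dataset_test,
    predictor=predictor,
    num_samples=20
)
forecasts = list(forecast_it)
targets = list(ts_it)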
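The forecasts are sample-based: with `num_samples=20`, each forecast holds 20 sample paths, and the evaluator reads quantiles off those samples at the levels given by `(np.arange(20)/20.0)[1:]`. The sketch below reproduces those levels and shows one common empirical-quantile rule (linear interpolation between sorted samples). GluonTS's forecast objects compute quantiles themselves and may use a different interpolation rule, so treat `empirical_quantile` as illustrative only.

```python
def quantile_levels(n=20):
    """Same levels as (np.arange(20) / 20.0)[1:] for n=20: 0.05, 0.10, ..., 0.95."""
    return [i / n for i in range(1, n)]


def empirical_quantile(samples, q):
    """Linearly interpolated q-quantile of a list of samples."""
    s = sorted(samples)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


# Usage: 20 made-up sample values for one series at one time step.
samples = [float(v) for v in range(20)]
levels = quantile_levels()
print(len(levels), levels[0], levels[-1])  # 19 0.05 0.95
print(empirical_quantile(samples, 0.5))  # 9.5
```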
Calculate metrics
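agg_metric, _ = evaluator(
    targets, forecasts, num_series=len(electricity_dataset_test)
)
# CRPS approximations: averaged over dimensions, and for the series summed across dimensions
print("CRPS:", agg_metric["mean_wQuantileLoss"])
print("CRPS-Sum:", agg_metric["m_sum_mean_wQuantileLoss"])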
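GluonTS evaluators approximate CRPS by the mean weighted quantile loss over the chosen quantile levels; with `target_agg_funcs={'sum': np.sum}`, `MultivariateEvaluator` also scores the series summed across dimensions and reports those metrics under keys prefixed with `m_sum_`. The sketch below assumes the GluonTS definitions of the quantile loss (`2 * sum |(y - f) * (1{y <= f} - q)|`) and its normalization by `sum |y|`; the helper names are hypothetical.

```python
def quantile_loss(target, forecast, q):
    """2 * sum |(y - f) * (1{y <= f} - q)| over the time steps."""
    return 2.0 * sum(
        abs((y - f) * ((1.0 if y <= f else 0.0) - q))
        for y, f in zip(target, forecast)
    )


def mean_weighted_quantile_loss(target, quantile_forecasts):
    """Average over levels of quantile_loss / sum |y|.

    quantile_forecasts maps each level q to that level's forecast path.
    """
    denom = sum(abs(y) for y in target)
    losses = [quantile_loss(target, f, q) / denom for q, f in quantile_forecasts.items()]
    return sum(losses) / len(losses)


# Usage with made-up numbers: one level (the median) and two time steps.
print(mean_weighted_quantile_loss([1.0, 3.0], {0.5: [2.0, 2.0]}))  # 0.5
```

For the summed variant, the quantiles must come from sample paths that were summed across dimensions first; summing per-dimension quantiles would give a different (and generally wrong) result.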
Download files
Source Distribution
transmaf-0.1.4.tar.gz (34.0 kB)
File details
Details for the file transmaf-0.1.4.tar.gz.
File metadata
- Download URL: transmaf-0.1.4.tar.gz
- Upload date:
- Size: 34.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 382af8bba5fee4f0fd0630a828fb3e6d058bcce9d89aa15d57d6ab42833edada |
| MD5 | 3433b7c1ebe698bb25c1e7481b782e99 |
| BLAKE2b-256 | 47638886c647d41d032bc8b9d52835820877ccaffed5f290d18cc4a6df6c6969 |