
Probabilistic Time Series Modeling using the Trans-MAF model.

Project description

transmaf

transmaf is a PyTorch-based library for probabilistic time series forecasting with the Trans-MAF model. It uses GluonTS as the back-end API for loading, transforming, and backtesting datasets.
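Trans-MAF pairs a Transformer with a masked autoregressive flow (MAF). As an intuition for the flow part, the pure-Python sketch below implements a single masked autoregressive affine layer: each coordinate is shifted and scaled using only the coordinates before it, so the density of a point can be evaluated in one pass, while sampling proceeds one coordinate at a time. This is a toy illustration, not transmaf's code. The class `LinearMAFLayer` and its random linear conditioner are hypothetical; in the actual model the shift and scale presumably come from a masked neural network conditioned on the Transformer's hidden state.

```python
import math
import random


class LinearMAFLayer:
    """One masked autoregressive affine layer with a toy linear conditioner.

    Hypothetical illustration: the shift mu_i and log-scale alpha_i depend only
    on x[0..i-1], enforced by summing over j < i only (the "mask").
    """

    def __init__(self, dim, seed=0):
        rng = random.Random(seed)
        self.dim = dim
        # Only entries with j < i are ever read, which makes the map autoregressive.
        self.w_mu = [[rng.uniform(-0.5, 0.5) for _ in range(dim)] for _ in range(dim)]
        self.w_alpha = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(dim)]
        self.b_mu = [rng.uniform(-0.5, 0.5) for _ in range(dim)]
        self.b_alpha = [rng.uniform(-0.1, 0.1) for _ in range(dim)]

    def _params(self, x, i):
        # Shift and log-scale for coordinate i, computed from x[:i] only.
        mu = self.b_mu[i] + sum(self.w_mu[i][j] * x[j] for j in range(i))
        alpha = self.b_alpha[i] + sum(self.w_alpha[i][j] * x[j] for j in range(i))
        return mu, alpha

    def forward(self, x):
        """Data -> noise. Returns z and log|det dz/dx| (= -sum of alphas)."""
        z, log_det = [], 0.0
        for i in range(self.dim):
            mu, alpha = self._params(x, i)
            z.append((x[i] - mu) * math.exp(-alpha))
            log_det -= alpha
        return z, log_det

    def inverse(self, z):
        """Noise -> data. Sequential, because mu_i needs the already generated x[:i]."""
        x = []
        for i in range(self.dim):
            mu, alpha = self._params(x, i)
            x.append(z[i] * math.exp(alpha) + mu)
        return x

    def log_prob(self, x):
        """Change of variables with a standard normal base density."""
        z, log_det = self.forward(x)
        log_base = sum(-0.5 * (zi * zi + math.log(2 * math.pi)) for zi in z)
        return log_base + log_det


layer = LinearMAFLayer(dim=3)
x = layer.inverse([0.5, -1.0, 0.2])  # draw a point by pushing noise through the layer
print(layer.log_prob(x))
```

Stacking several such layers (typically permuting the coordinate order between them) gives a MAF; the total log-density is the base log-density plus the sum of every layer's log-determinant.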

Installation

$ pip install transmaf

Quick start

Imports

import numpy as np
import torch

from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.evaluation import MultivariateEvaluator
from gluonts.evaluation.backtest import make_evaluation_predictions

from transmaf import Trainer
from transmaf.model.transformer_tempflow import TransformerTempFlowEstimator

Read data

electricity = get_dataset("electricity_nips", regenerate=False)

# create train/test groupers that stack the univariate series into one multivariate series
electricity_train_grouper = MultivariateGrouper(
    max_target_dim=min(2000, int(electricity.metadata.feat_static_cat[0].cardinality))
)
# the test split holds every series once per rolling test date
electricity_test_grouper = MultivariateGrouper(
    num_test_dates=int(len(electricity.test) / len(electricity.train)),
    max_target_dim=min(2000, int(electricity.metadata.feat_static_cat[0].cardinality))
)

# create train/test datasets
electricity_dataset_train = list(electricity_train_grouper(electricity.train))
electricity_dataset_train *= 100  # repeat the grouped training entry 100 times
electricity_dataset_test = electricity_test_grouper(electricity.test)
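`MultivariateGrouper` turns the dataset's many univariate series into a single multivariate series (one row per series, one column per time step). The sketch below mimics that idea for the training split under the simplifying assumption that all series share the same start time; the helper `group_univariate` is hypothetical, and the real grouper also deals with series that start at different dates.

```python
def group_univariate(series_list, fill_value=0.0):
    """Stack univariate series into one multivariate series (rows = series).

    Hypothetical helper: assumes every series starts at the same time step and
    right-pads shorter series with fill_value.
    """
    length = max(len(s) for s in series_list)
    return [list(s) + [fill_value] * (length - len(s)) for s in series_list]


train_series = [[1.0, 2.0, 3.0], [10.0, 20.0]]
grouped = group_univariate(train_series)  # [[1.0, 2.0, 3.0], [10.0, 20.0, 0.0]]

# Like `electricity_dataset_train *= 100`: 100 references to the same entry.
train_entries = [grouped]
train_entries *= 100
```

Note that list repetition stores references to the same entry rather than copies, which is harmless as long as the entries are only read. The `num_test_dates` argument is computed as `len(electricity.test) / len(electricity.train)` presumably because the test split contains every series once per rolling evaluation date: for example, 2 series evaluated at 3 dates give 6 test entries, hence 3 dates.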

Train estimator

# use a GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

estimator = TransformerTempFlowEstimator(
    input_size=744,  # per-time-step input size of the network; 744 is the value for this dataset
    target_dim=int(electricity.metadata.feat_static_cat[0].cardinality),  # number of series
    prediction_length=electricity.metadata.prediction_length,
    context_length=electricity.metadata.prediction_length * 4,  # history window: 4x the horizon
    flow_type='MAF',  # masked autoregressive flow
    dequantize=True,
    freq=electricity.metadata.freq,
    trainer=Trainer(
        device=device,  # train on the device selected above
        epochs=14,
        learning_rate=1e-3,
        num_batches_per_epoch=100,
        batch_size=64,
    )
)
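The estimator learns from windows made of `context_length` observed steps followed by `prediction_length` steps to predict; here the context is four times the horizon (a 24-step horizon, for instance, would give a 96-step context). The sketch below cuts one such window at a random offset from a (series × time) matrix and applies uniform dequantization, i.e. adding U[0, 1) noise so that discrete-valued data gets a continuous density. Both helpers are hypothetical: GluonTS's instance splitting does the windowing with additional time features, and that `dequantize=True` works exactly like this is an assumption worth checking in the transmaf source.

```python
import random


def sample_training_window(matrix, context_length, prediction_length, rng):
    """Cut one (context, future) pair at a random offset from a series x time matrix."""
    total = context_length + prediction_length
    num_steps = len(matrix[0])
    if num_steps < total:
        raise ValueError("series too short for a single window")
    start = rng.randrange(num_steps - total + 1)  # window fits entirely inside the series
    context = [row[start:start + context_length] for row in matrix]
    future = [row[start + context_length:start + total] for row in matrix]
    return context, future


def dequantize(values, rng):
    """Add U[0, 1) noise to every value (uniform dequantization)."""
    return [[v + rng.random() for v in row] for row in values]


rng = random.Random(0)
matrix = [[float(t) for t in range(200)] for _ in range(3)]  # 3 toy series, 200 steps
context, future = sample_training_window(matrix, context_length=4 * 24, prediction_length=24, rng=rng)
noisy_context = dequantize(context, rng)
```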

predictor = estimator.train(
    electricity_dataset_train,
    num_workers=4
)
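A quick way to gauge the training budget set by the `Trainer` above, assuming each epoch runs `num_batches_per_epoch` optimizer steps on batches of `batch_size` windows (the helper name is hypothetical):

```python
def training_budget(epochs, num_batches_per_epoch, batch_size):
    """Return (optimizer steps, training windows seen), assuming one step per batch."""
    steps = epochs * num_batches_per_epoch
    return steps, steps * batch_size


print(training_budget(14, 100, 64))  # (1400, 89600)
```

Raising `epochs` or `num_batches_per_epoch` scales both numbers linearly; raising `batch_size` only increases the number of windows per step.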

Prediction

# init evaluator
evaluator = MultivariateEvaluator(
    quantiles=(np.arange(20)/20.0)[1:],
    target_agg_funcs={'sum': np.sum}
)
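The evaluator scores the forecast samples with weighted quantile losses at the 19 levels 0.05, 0.10, ..., 0.95. Their mean (`mean_wQuantileLoss`) is a common approximation of the CRPS, and `target_agg_funcs={'sum': np.sum}` additionally scores the sum over all series (CRPS-Sum). The pure-Python sketch below reproduces that computation, assuming quantiles are read off the samples with linear interpolation; the function names are hypothetical, and GluonTS's exact quantile estimator may differ.

```python
QUANTILES = [i / 20 for i in range(1, 20)]  # same levels as (np.arange(20)/20.0)[1:]


def empirical_quantile(values, q):
    """Quantile of sample values, interpolating linearly between order statistics."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)  # floor, since pos >= 0
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def mean_weighted_quantile_loss(samples, target, quantiles):
    """samples: list of sample paths (each a list over time); target: list over time."""
    abs_target_sum = sum(abs(y) for y in target)  # assumed non-zero
    losses = []
    for q in quantiles:
        loss = 0.0
        for t, y in enumerate(target):
            pred = empirical_quantile([path[t] for path in samples], q)
            indicator = 1.0 if y <= pred else 0.0
            loss += abs((pred - y) * (indicator - q))
        losses.append(2.0 * loss / abs_target_sum)
    return sum(losses) / len(losses)


def crps_sum(samples, target, quantiles):
    """samples: S x T x N nested lists; target: T x N. Sum over series, then score."""
    summed_samples = [[sum(step) for step in path] for path in samples]
    summed_target = [sum(step) for step in target]
    return mean_weighted_quantile_loss(summed_samples, summed_target, quantiles)
```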

# prediction
forecast_it, ts_it = make_evaluation_predictions(
    dataset=electricity_dataset_test,
    predictor=predictor,
    num_samples=20
)
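`make_evaluation_predictions` hides the last `prediction_length` steps of each test entry, asks the predictor for `num_samples` sample paths over that horizon, and returns those forecasts together with the full targets. The sketch below shows the split and the shape of the samples, with a hypothetical noisy last-value forecaster standing in for the trained Trans-MAF predictor; the (samples × horizon × series) layout is an assumption about how multivariate sample forecasts are arranged.

```python
import random


def split_for_evaluation(matrix, prediction_length):
    """Split a series x time matrix into visible history and held-out horizon.

    prediction_length must be positive (row[:-0] would be empty).
    """
    history = [row[:-prediction_length] for row in matrix]
    held_out = [row[-prediction_length:] for row in matrix]
    return history, held_out


def naive_sample_forecast(history, prediction_length, num_samples, rng, noise_std=1.0):
    """Hypothetical stand-in predictor: last observed value plus Gaussian noise.

    Returns nested lists laid out as samples x horizon x series.
    """
    last = [row[-1] for row in history]
    return [
        [[value + rng.gauss(0.0, noise_std) for value in last] for _ in range(prediction_length)]
        for _ in range(num_samples)
    ]


matrix = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
history, held_out = split_for_evaluation(matrix, prediction_length=2)
samples = naive_sample_forecast(history, prediction_length=2, num_samples=20, rng=random.Random(0))
```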

forecasts = list(forecast_it)
targets = list(ts_it)

Calculate metrics

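agg_metric, _ = evaluator(
    targets, forecasts, num_series=len(electricity_dataset_test)
)

print("CRPS:", agg_metric["mean_wQuantileLoss"])
print("CRPS-Sum:", agg_metric["m_sum_mean_wQuantileLoss"])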

Download files


Source Distribution

transmaf-0.1.4.tar.gz (34.0 kB)

File details

Details for the file transmaf-0.1.4.tar.gz.

File metadata

  • Download URL: transmaf-0.1.4.tar.gz
  • Size: 34.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.6

File hashes

Hashes for transmaf-0.1.4.tar.gz
Algorithm Hash digest
SHA256 382af8bba5fee4f0fd0630a828fb3e6d058bcce9d89aa15d57d6ab42833edada
MD5 3433b7c1ebe698bb25c1e7481b782e99
BLAKE2b-256 47638886c647d41d032bc8b9d52835820877ccaffed5f290d18cc4a6df6c6969

