A tensorflow 2.0 implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction

Tensorflow 2 DA-RNN

A Tensorflow 2 (Keras) implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction

Paper: https://arxiv.org/abs/1704.02971

Install

For Tensorflow 2

pip install "da-rnn[keras]"

For PyTorch

pip install "da-rnn[torch]"

Usage

For Tensorflow 2

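from da_rnn.keras import DARNN

# train_ds and val_ds are tf.data.Dataset objects yielding (features, labels)
# batches; inputs is a batch of features (see Data Processing)
model = DARNN(T=10, m=128, p=128)

# A Keras model must be compiled before fit();
# the optimizer and loss here are example choices
model.compile(optimizer='adam', loss='mse')

# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict
y_hat = model(inputs)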

For PyTorch (with poutyne)

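import torch
from poutyne import Model
from da_rnn.torch import DARNN

# train_loader and val_loader are torch DataLoaders yielding
# (features, labels) batches; inputs is a tensor batch of features
darnn = DARNN(n=50, T=10, m=128, p=128)

# Poutyne's Model takes the network, an optimizer and a loss function;
# Adam and MSE are example choices
model = Model(darnn, torch.optim.Adam(darnn.parameters()), torch.nn.MSELoss())

# Train
model.fit_generator(
    train_loader,
    val_loader,
    epochs=100,
    verbose=True
)

# Predict
darnn.eval()
with torch.no_grad():
    y_hat = darnn(inputs)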

Python Docstring Notations

The docstrings in this project use the following notation convention:

variable_{subscript}__{superscript}

For example:

  • y_T__i means $y_T^i$, the i-th prediction value at time T.
  • alpha_t__k means $\alpha_t^k$, the attention weight measuring the importance of the k-th input feature (driving series) at time t.
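To make alpha_t__k concrete, here is a minimal NumPy sketch of the input attention step as the paper describes it. For each of the n driving series, a score e_t__k = v_e^T tanh(W_e [h_{t-1}; s_{t-1}] + U_e x^k) is computed from the previous encoder hidden state h_{t-1}, the encoder cell state s_{t-1} (both of size m) and the whole window x^k of that series (length T); a softmax over k then gives the weights. The function name `input_attention` and the random weights in the test are hypothetical illustrations, not this package's API.

```python
import numpy as np


def input_attention(x, h_prev, s_prev, W_e, U_e, v_e):
    """Hypothetical sketch of the paper's input attention weights alpha_t__k.

    x:      (T, n)   one window of the n driving series
    h_prev: (m,)     encoder hidden state h_{t-1}
    s_prev: (m,)     encoder LSTM cell state s_{t-1}
    W_e:    (T, 2m)
    U_e:    (T, T)
    v_e:    (T,)
    Returns alpha of shape (n,), one weight per driving series.
    """
    hs = np.concatenate([h_prev, s_prev])       # [h_{t-1}; s_{t-1}], shape (2m,)
    # Row k of (U_e @ x).T is U_e x^k, where x^k = x[:, k] has length T.
    pre = np.tanh(W_e @ hs + (U_e @ x).T)       # (n, T)
    e = pre @ v_e                               # scores e_t__k, shape (n,)
    e = e - e.max()                             # shift for a numerically stable softmax
    return np.exp(e) / np.exp(e).sum()          # softmax over the n series
```

In the paper, the encoder then feeds the element-wise product of alpha_t and x_t (in the sketch, `alpha * x[t]` with 0-based rows) to its LSTM instead of the raw x_t.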

For Tensorflow 2: DARNN(T, m, p, y_dim=1)

For PyTorch: DARNN(n, T, m, p, y_dim=1)

The naming of the following (hyper)parameters is consistent with the paper, except y_dim, which the paper does not mention.

  • n (torch only, int): input size, i.e. the number of driving series (input features per time step)
  • T (int): the length (number of time steps) of the window
  • m (int): the size of the encoder hidden state
  • p (int): the size of the decoder hidden state
  • y_dim (int, default 1): the prediction dimension

Returns the DA-RNN model instance.
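The decoder hidden size p enters in the paper's second (temporal) attention stage. As a minimal NumPy sketch of that stage: scores l_t__i = v_d^T tanh(W_d [d_{t-1}; s'_{t-1}] + U_d h_i) are computed for each of the T encoder hidden states h_i (size m) from the previous decoder hidden state d_{t-1} and decoder cell state s'_{t-1} (both of size p). A softmax turns them into weights beta_t__i, and the context vector is c_t = sum_i beta_t__i h_i. The name `temporal_attention`, the symbol beta_t__i written in this docstring notation, and the random weights in the test are illustrative assumptions, not part of this package.

```python
import numpy as np


def temporal_attention(H, d_prev, s_prev, W_d, U_d, v_d):
    """Hypothetical sketch of the paper's temporal attention.

    H:      (T, m)   encoder hidden states h_1..h_T
    d_prev: (p,)     decoder hidden state d_{t-1}
    s_prev: (p,)     decoder LSTM cell state s'_{t-1}
    W_d:    (m, 2p)
    U_d:    (m, m)
    v_d:    (m,)
    Returns (beta, context) with shapes (T,) and (m,).
    """
    ds = np.concatenate([d_prev, s_prev])     # [d_{t-1}; s'_{t-1}], shape (2p,)
    # Row i of H @ U_d.T is U_d h_i; W_d @ ds is broadcast over all T rows.
    l = np.tanh(W_d @ ds + H @ U_d.T) @ v_d   # scores l_t__i, shape (T,)
    l = l - l.max()                           # shift for a numerically stable softmax
    beta = np.exp(l) / np.exp(l).sum()        # softmax over the T encoder steps
    context = beta @ H                        # c_t = sum_i beta_t__i h_i, shape (m,)
    return beta, context
```

Because the context vector is a convex combination of encoder states, its size is m regardless of p; p only sizes the decoder LSTM state that drives the attention scores.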

Data Processing

Each feature item of the dataset should have shape (batch_size, T, length_of_driving_series + y_dim); that is, its last axis holds the driving series together with the y_dim target series.

Each label item of the dataset should have shape (batch_size, y_dim).
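A minimal sketch of turning two aligned NumPy arrays into items of these shapes. It assumes, since this README does not say, that the target columns come after the driving series on the last axis and that each label is the target value at the step immediately after its window. `make_windows` is a hypothetical helper, not part of the package; its first axis indexes samples, and batching later turns that into batch_size.

```python
import numpy as np


def make_windows(driving: np.ndarray, target: np.ndarray, T: int):
    """Hypothetical helper: slice aligned series into DA-RNN style items.

    driving: (num_steps, n)      the n driving series
    target:  (num_steps, y_dim)  the series to predict
    Returns features of shape (num_samples, T, n + y_dim) and labels of
    shape (num_samples, y_dim), where num_samples = num_steps - T.
    """
    if len(driving) != len(target):
        raise ValueError("driving and target must have the same number of steps")
    num_samples = len(target) - T
    if num_samples < 1:
        raise ValueError("need more than T time steps")

    # Assumption: target columns are appended after the driving series.
    series = np.concatenate([driving, target], axis=-1)
    # Each feature item is a window of T consecutive rows.
    features = np.stack([series[i:i + T] for i in range(num_samples)])
    # Assumption: the label is the target one step after the window ends.
    labels = target[T:]
    return features.astype(np.float32), labels.astype(np.float32)
```

With TensorFlow, `tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)` yields batched (features, labels) pairs for `train_ds`; with PyTorch, converting the arrays with `torch.from_numpy` and wrapping them in `torch.utils.data.TensorDataset` plus a `DataLoader` plays the same role.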

Development

Install dependencies:

make install

Run notebook:

cd notebook
jupyter lab

TODO

  • Remove the hardcoded prediction dimensionality (currently 1)

License

MIT

Download files

Source Distribution

da-rnn-1.0.2.tar.gz (8.5 kB)

Built Distribution

da_rnn-1.0.2-py3-none-any.whl (16.4 kB)

File details

Details for the file da-rnn-1.0.2.tar.gz.

File metadata

  • Download URL: da-rnn-1.0.2.tar.gz
  • Size: 8.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.5.0.1 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for da-rnn-1.0.2.tar.gz
Algorithm Hash digest
SHA256 e6607d3f6612fa53d0dcecb8fdc78909b7fc7901c2c6eb6936269c2208940a61
MD5 dbf1989502627906b768798580464c01
BLAKE2b-256 ac06c5259e0eb430a0d86fb84d0c1ee8fc0fa0ad8004d6b901908b9216ff6373

File details

Details for the file da_rnn-1.0.2-py3-none-any.whl.

File metadata

  • Download URL: da_rnn-1.0.2-py3-none-any.whl
  • Size: 16.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.7.3 pkginfo/1.5.0.1 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for da_rnn-1.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 2cc714e6b72116402f15fa65d4df3527101b0232de303489010129d72d7864e2
MD5 1cbb2e1c5ff3b160f28d22b107b55980
BLAKE2b-256 e577c1dd701f85737cd780e7cb653bb96e500562eea83c003ba4a5a7e609fe01
