A TensorFlow 2.0 implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
Tensorflow 2 DA-RNN
A Tensorflow 2 (Keras) implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
Paper: https://arxiv.org/abs/1704.02971
Install
For Tensorflow 2
pip install da-rnn[keras]
For PyTorch
pip install da-rnn[torch]
Usage
For Tensorflow 2
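from da_rnn.keras import DARNN
model = DARNN(T=10, m=128)
# Keras requires compile() before fit(); 'adam' and 'mse' are example choices
model.compile(optimizer='adam', loss='mse')
# Train: train_ds and val_ds yield (features, labels) batches shaped as described in Data Processing
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)
# Predict
y_hat = model(inputs)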
For PyTorch (with poutyne)
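import torch
from poutyne import Model
from da_rnn.torch import DARNN
darnn = DARNN(n=50, T=10, m=128)
# poutyne's Model also needs an optimizer and a loss function ('adam' and 'mse_loss' are example choices)
model = Model(darnn, 'adam', 'mse_loss')
# Train: train_ds and val_ds are torch Datasets of (features, labels) pairs
model.fit_dataset(
    train_ds,
    val_ds,
    epochs=100,
    verbose=True
)
# Predict with the underlying network in evaluation mode
darnn.eval()
with torch.no_grad():
    y_hat = darnn(inputs)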
Python Docstring Notations
In the docstrings of this project's methods, we use the following notation convention:
variable_{subscript}__{superscript}
For example:
y_T__i
means the i-th prediction value at time T.
alpha_t__k
means the attention weight measuring the importance of the k-th input feature (driving series) at time t.
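To make alpha_t__k concrete, here is a NumPy sketch of the input attention step as the paper formulates it: each driving series x^k over the window is scored against the previous encoder hidden state h_{t-1} and cell state s_{t-1}, and a softmax over the n series gives alpha_t__k. This is an illustration under that assumption, not this package's implementation; `input_attention_weights` is a hypothetical name, and the randomly drawn W_e, U_e and v_e stand in for learned parameters.

```python
import numpy as np

def input_attention_weights(X, h_prev, s_prev, W_e, U_e, v_e):
    """Hypothetical helper: alpha_t__k for every driving series k.

    X:      (T, n) window; column k is the k-th driving series x^k
    h_prev: (m,) encoder hidden state h_{t-1}
    s_prev: (m,) encoder cell state s_{t-1}
    W_e:    (T, 2m), U_e: (T, T), v_e: (T,) stand-ins for learned parameters
    """
    hs = np.concatenate([h_prev, s_prev])       # [h_{t-1}; s_{t-1}], shape (2m,)
    # Row k of X.T @ U_e.T equals U_e @ x^k, so all n series are scored at once
    e = np.tanh(W_e @ hs + X.T @ U_e.T) @ v_e   # e_t__k for each k, shape (n,)
    e = np.exp(e - e.max())                     # numerically stable softmax
    return e / e.sum()                          # alpha_t__k, sums to 1 over k

rng = np.random.default_rng(0)
T, n, m, t = 10, 5, 16, 3
X = rng.normal(size=(T, n))
alpha_t = input_attention_weights(
    X, rng.normal(size=m), rng.normal(size=m),
    rng.normal(size=(T, 2 * m)), rng.normal(size=(T, T)), rng.normal(size=T),
)
# Weighted input at time t: each x_t__k is scaled by its weight alpha_t__k
x_tilde_t = alpha_t * X[t]
```

Because the weights come from a softmax over the n driving series, they are positive and sum to 1 at each time step, which is what lets them be read as relative importance.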
DARNN(T, m, p, y_dim=1)  (TensorFlow 2)
DARNN(n, T, m, p, y_dim=1)  (PyTorch)
The naming of the following (hyper)parameters is consistent with the paper, except for y_dim, which is not mentioned in the paper.
- n (int, torch only): the input size, i.e. the number of driving series (features per time step)
- T (int): the length (time steps) of the window
- m (int): the size of the encoder hidden state
- p (int): the size of the decoder hidden state
- y_dim (int=1): the prediction dimension. Defaults to 1.
Return the DA-RNN model instance.
Data Processing
Each feature item of the dataset should have shape (batch_size, T, n + y_dim), where n is the number of driving series.
Each label item of the dataset should have shape (batch_size, y_dim).
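A minimal NumPy sketch of how such samples could be cut from a raw multivariate series. It assumes the driving series and the target are aligned arrays over the same time steps, and it assumes the label for each window is the target value at the step right after that window; `make_windows` is a hypothetical helper, not part of da-rnn.

```python
import numpy as np

def make_windows(driving, target, T):
    """Hypothetical helper: slice aligned series into DA-RNN samples.

    driving: (length, n) array holding the n driving series
    target:  (length, y_dim) array holding the target series
    Returns features (N, T, n + y_dim) and labels (N, y_dim), N = length - T.
    """
    series = np.concatenate([driving, target], axis=1)  # (length, n + y_dim)
    n_samples = len(series) - T
    # Each feature is T consecutive steps of driving series plus target history
    features = np.stack([series[i:i + T] for i in range(n_samples)])
    # Assumption: the label is the target value right after each window
    labels = target[T:T + n_samples]
    return features, labels

length, n, y_dim, T = 100, 5, 1, 10
rng = np.random.default_rng(0)
driving = rng.normal(size=(length, n))
target = rng.normal(size=(length, y_dim))
X, y = make_windows(driving, target, T)
print(X.shape, y.shape)  # (90, 10, 6) (90, 1)
```

Batching these arrays, for example with `tf.data.Dataset.from_tensor_slices((X, y)).batch(32)` in TensorFlow, yields items with the (batch_size, T, n + y_dim) and (batch_size, y_dim) shapes listed above.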
Development
Install dependencies:
make install
Run notebook:
cd notebook
jupyter lab
TODO
- Avoid hardcoding the prediction dimensionality (currently fixed to 1)
License