Tensorflow 2 DA-RNN
A Tensorflow 2 (Keras) implementation of the Dual-Stage Attention-Based Recurrent Neural Network (DA-RNN) for Time Series Prediction, with a PyTorch variant.
Paper: https://arxiv.org/abs/1704.02971
Install
For Tensorflow 2
pip install da-rnn[keras]
For PyTorch
pip install da-rnn[torch]
Usage
For Tensorflow 2
from da_rnn.keras import DARNN
model = DARNN(T=10, m=128)
# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)
# Predict
y_hat = model(inputs)
For PyTorch (with poutyne)
import torch
from poutyne import Model
from da_rnn.torch import DARNN
darnn = DARNN(n=50, T=10, m=128)
model = Model(darnn)
# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)
# Predict
with torch.no_grad():
    # Gradients are not needed for inference
    y_hat = model(inputs)
Python Docstring Notations
In docstrings of the methods of this project, we have the following notation convention:
variable_{subscript}__{superscript}
For example:
- y_T__i means $y_T^i$, the i-th prediction value at time T.
- alpha_t__k means $\alpha_t^k$, the attention weight measuring the importance of the k-th input feature (driving series) at time t.
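To make the convention concrete, here is a minimal sketch that turns a docstring name into the LaTeX it stands for. The helper `docstring_name_to_latex` and the `GREEK` set are hypothetical names introduced for illustration only; they are not part of da-rnn.

```python
# Hypothetical helper, not part of the da-rnn package.
# It only illustrates the variable_{subscript}__{superscript} convention.

# Names that are assumed to map to LaTeX Greek-letter commands.
GREEK = {"alpha", "beta", "gamma", "delta", "epsilon", "theta", "lambda", "sigma"}


def docstring_name_to_latex(name: str) -> str:
    """Convert e.g. 'alpha_t__k' to '\\alpha_{t}^{k}'."""
    # A double underscore separates the superscript from the rest.
    base, _, superscript = name.partition("__")
    # A single underscore separates the variable from its subscript.
    variable, _, subscript = base.partition("_")

    latex = "\\" + variable if variable in GREEK else variable
    if subscript:
        latex += "_{" + subscript + "}"
    if superscript:
        latex += "^{" + superscript + "}"
    return latex


if __name__ == "__main__":
    for name in ("y_T__i", "alpha_t__k"):
        print(name, "->", docstring_name_to_latex(name))
```

Splitting on the double underscore first keeps the single-underscore subscript intact, which mirrors the order in which the convention is written.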
DARNN(T, m, p, y_dim=1)
DARNN(n, T, m, p, y_dim=1)
The naming of the following (hyper)parameters is consistent with the paper, except y_dim, which is not mentioned in the paper.
- n (torch only) int: input size, the number of features of a single driving series
- T int: the length (time steps) of the window
- m int: the number of the encoder hidden states
- p int: the number of the decoder hidden states
- y_dim int=1: the prediction dimension. Defaults to 1.
Returns the DA-RNN model instance.
Data Processing
Each feature item of the dataset should have shape (batch_size, T, length_of_driving_series + y_dim), so every time step in a window holds the driving-series values followed by the target values.
Each label item of the dataset should have shape (batch_size, y_dim).
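The shapes above can be illustrated with a small windowing sketch in plain Python. The function `make_windows` is a hypothetical helper, not part of da-rnn, and it assumes that the label for a window is the target value at the step right after the window; the alignment this project actually expects is not stated here, so check the notebook before relying on it. Lists are used only to keep the example dependency-free; in practice the result would be converted to arrays or tensors and batched.

```python
from typing import List, Tuple

Row = List[float]
Window = List[Row]  # T rows, each of length n + y_dim


def make_windows(
    driving: List[Row],  # one row per time step, n values each
    target: List[Row],   # one row per time step, y_dim values each
    T: int,
) -> Tuple[List[Window], List[Row]]:
    """Slice aligned series into (feature window, label) pairs."""
    if len(driving) != len(target):
        raise ValueError("driving and target must cover the same time steps")

    features: List[Window] = []
    labels: List[Row] = []
    for start in range(len(driving) - T):
        # Each time step concatenates driving values with target values,
        # giving rows of length n + y_dim.
        window = [driving[t] + target[t] for t in range(start, start + T)]
        features.append(window)
        # ASSUMPTION: the label is the target one step after the window.
        labels.append(target[start + T])
    return features, labels


if __name__ == "__main__":
    # Toy data: 20 time steps, n = 3 driving series, y_dim = 1.
    driving = [[float(t), 2.0 * t, 3.0 * t] for t in range(20)]
    target = [[10.0 * t] for t in range(20)]
    X, y = make_windows(driving, target, T=10)
    # Per-sample shapes: (T, n + y_dim) for features, (y_dim,) for labels.
    print(len(X), len(X[0]), len(X[0][0]), len(y[0]))
```

Batching these samples adds the leading batch_size dimension, which gives the (batch_size, T, length_of_driving_series + y_dim) and (batch_size, y_dim) shapes described above.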
Development
Install dependencies:
make install
Run notebook:
cd notebook
jupyter lab
TODO
- no hardcoding (1 for now) for prediction dimensionality
License