Tensorflow 2 DA-RNN
A Tensorflow 2 (Keras) implementation of the Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
Paper: https://arxiv.org/abs/1704.02971
Install
For Tensorflow 2
pip install da-rnn[keras]
For PyTorch
pip install da-rnn[torch]
Usage
For Tensorflow 2
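from da_rnn.keras import DARNN

model = DARNN(T=10, m=128)

# A Keras model must be compiled before training; the optimizer and loss
# below are assumptions, pick the ones that suit your task
model.compile(optimizer='adam', loss='mse')

# Train (train_ds and val_ds are datasets you prepare, see Data Processing)
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict
y_hat = model(inputs)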
For PyTorch (with poutyne)
import torch
from poutyne import Model
from da_rnn.torch import DARNN

darnn = DARNN(n=50, T=10, m=128)
model = Model(darnn)

# Train
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    verbose=1
)

# Predict without tracking gradients
with torch.no_grad():
    y_hat = model(inputs)
Python Docstring Notations
In docstrings of the methods of this project, we have the following notation convention:
variable_{subscript}__{superscript}
For example:
- y_T__i means the i-th prediction value at time T.
- alpha_t__k means the attention weight measuring the importance of the k-th input feature (driving series) at time t.
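To show how the notation reads inside a docstring, here is a minimal sketch. The function `attention_weights` is hypothetical and is not part of this package; it only applies the softmax normalisation that the paper uses to turn attention scores into weights.

```python
import numpy as np


def attention_weights(e_t):
    """Normalise attention scores with a softmax.

    Args:
        e_t: array of shape (n,), where e_t[k] is e_t__k, the attention
            score of the k-th driving series at time t

    Returns:
        Array of shape (n,), where element k is alpha_t__k
    """
    shifted = e_t - np.max(e_t)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


# Three hypothetical scores; the largest score gets the largest weight
alpha_t = attention_weights(np.array([0.5, 1.5, -1.0]))
```

The weights in `alpha_t` are positive and sum to 1, so each alpha_t__k can be read as the relative importance of the k-th driving series at time t.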
DARNN(T, m, p, y_dim=1) (Keras)
DARNN(n, T, m, p, y_dim=1) (PyTorch)
The naming of the following (hyper)parameters is consistent with the paper, except y_dim, which is not mentioned in the paper.
- n int (torch only): the input size, the number of features of a single driving series
- T int: the length (time steps) of the window
- m int: the size of the encoder hidden state
- p int: the size of the decoder hidden state
- y_dim int=1: the prediction dimension. Defaults to 1.
Returns the DA-RNN model instance.
Data Processing
Each feature item of the dataset should be of shape (batch_size, T, length_of_driving_series + y_dim), that is, the driving series and the target series concatenated along the last axis.
Each label item of the dataset should be of shape (batch_size, y_dim).
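The shapes above can be produced with a sliding window. The following is a minimal NumPy sketch, not code from this package: the helper `make_windows` and the random data are hypothetical. It also assumes the label of a window is the target value at the step right after the window, which this README does not state.

```python
import numpy as np


def make_windows(driving, target, T):
    """Slice a multivariate series into (features, labels) windows.

    driving: array of shape (steps, n), the n driving series
    target:  array of shape (steps, y_dim), the series to predict
    """
    # Put the driving series and the target side by side: (steps, n + y_dim)
    data = np.concatenate([driving, target], axis=1)

    # Only keep windows that still have a following step to use as label
    count = len(data) - T
    features = np.stack([data[i:i + T] for i in range(count)])

    # Assumption: the label of window i is the target at step i + T
    labels = target[T:T + count]
    return features.astype(np.float32), labels.astype(np.float32)


# Hypothetical data: 100 time steps, 50 driving series, 1 target series
rng = np.random.default_rng(0)
driving = rng.normal(size=(100, 50))
target = rng.normal(size=(100, 1))

features, labels = make_windows(driving, target, T=10)
print(features.shape, labels.shape)  # (90, 10, 51) (90, 1)
```

Batching these arrays (for example with a tf.data or torch dataset) then yields items of shape (batch_size, T, 51) and (batch_size, 1), matching n=50, T=10 and y_dim=1.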
Development
Install dependencies:
make install
Run notebook:
cd notebook
jupyter lab
TODO
- no hardcoding (1 for now) for prediction dimensionality
License