N-Beats
N-BEATS: Neural basis expansion analysis for interpretable time series forecasting
- Implementation in PyTorch
- Implementation in Keras by @eljdos
- https://arxiv.org/abs/1905.10437
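The paper linked above describes a doubly residual architecture. Each block outputs a backcast (its reconstruction of the input window) and a forecast. The next block receives the input minus that backcast, and the model's forecast is the sum of all block forecasts. The following is a minimal pure-Python sketch of that wiring, assuming blocks are plain functions. `doubly_residual_forward` and `toy_block` are hypothetical helpers for illustration only; they are not part of nbeats-pytorch or nbeats-keras, whose blocks are trained neural networks.

```python
from typing import Callable, List, Tuple

# A block maps an input window to (backcast, forecast).
Block = Callable[[List[float]], Tuple[List[float], List[float]]]


def doubly_residual_forward(x: List[float], blocks: List[Block],
                            forecast_length: int) -> List[float]:
    """Run blocks in sequence with N-BEATS-style residual connections (sketch)."""
    residual = list(x)
    forecast = [0.0] * forecast_length
    for block in blocks:
        backcast, block_forecast = block(residual)
        # Remove what this block explained before passing the input on.
        residual = [r - b for r, b in zip(residual, backcast)]
        # Accumulate the partial forecasts into the final prediction.
        forecast = [f + bf for f, bf in zip(forecast, block_forecast)]
    return forecast


def toy_block(scale: float, forecast_length: int) -> Block:
    """Hypothetical block: explains `scale` times the input mean as a constant level."""
    def block(window: List[float]) -> Tuple[List[float], List[float]]:
        level = scale * sum(window) / len(window)
        return [level] * len(window), [level] * forecast_length
    return block


history = [2.0, 2.0, 2.0, 2.0]
blocks = [toy_block(0.5, 2), toy_block(1.0, 2)]
print(doubly_residual_forward(history, blocks, forecast_length=2))  # → [2.0, 2.0]
```

Here the first block explains half of the constant level, the second block explains the rest from the residual, and the summed forecast recovers the full level of 2.0.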
[Figure: N-Beats at the beginning of the training]
Trust me, after a few more training steps the green curve (predictions) closely matches the ground truth :-)
Installation
From PyPI
Install the Keras version: pip install nbeats-keras
Install the PyTorch version: pip install nbeats-pytorch
From the sources
Installation is based on a Makefile. Make sure you are in a virtualenv and have Python 3 installed.
Command to install N-Beats with Keras: make install-keras
Command to install N-Beats with PyTorch: make install-pytorch
Model
The PyTorch and Keras implementations take the same model arguments:
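class NBeatsNet:
    # Block-type identifiers accepted in stack_types (string values assumed for illustration).
    TREND_BLOCK = 'trend'
    SEASONALITY_BLOCK = 'seasonality'

    def __init__(self,
                 stack_types=(TREND_BLOCK, SEASONALITY_BLOCK),  # one stack per entry
                 nb_blocks_per_stack=3,         # blocks inside each stack
                 forecast_length=2,             # number of future steps to predict
                 backcast_length=10,            # number of past steps used as input
                 thetas_dim=(2, 8),             # one theta size per stack
                 share_weights_in_stack=False,  # reuse one block's weights across its stack
                 hidden_layer_units=128):       # width of each block's fully connected layers
        pass  # signature only; see the PyTorch or Keras package for the model body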
With backcast_length=50 and forecast_length=10 (all other arguments as above), this translates into the following model:
--- Model ---
| N-Beats
| -- Stack Trend (#0) (share_weights_in_stack=False)
     | -- TrendBlock(units=128, thetas_dim=2, backcast_length=50, forecast_length=10, share_thetas=True) at @4500902576
     | -- TrendBlock(units=128, thetas_dim=2, backcast_length=50, forecast_length=10, share_thetas=True) at @4779951944
     | -- TrendBlock(units=128, thetas_dim=2, backcast_length=50, forecast_length=10, share_thetas=True) at @4779952280
| -- Stack Seasonality (#1) (share_weights_in_stack=False)
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=50, forecast_length=10, share_thetas=True) at @4779952616
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=50, forecast_length=10, share_thetas=True) at @4779952952
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=50, forecast_length=10, share_thetas=True) at @4779953288
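In the paper, a trend block's forecast is a low-degree polynomial in time whose coefficients are the block's thetas (produced by its fully connected layers), and a seasonality block's forecast is a Fourier series. Under that reading, thetas_dim=2 means "level plus slope" and thetas_dim=8 means eight Fourier coefficients. The sketch below assumes the paper's normalised time grid t = [0, 1, ..., H-1]/H and splits the seasonality thetas evenly between cosine and sine terms. Only the forecast side is shown; the backcast applies the same kind of basis over the input window. The library's exact time grid and harmonic layout may differ, and all function names here are hypothetical.

```python
import math
from typing import List


def time_grid(length: int) -> List[float]:
    # Normalised time steps t = [0, 1, ..., H-1] / H, following the paper.
    return [i / length for i in range(length)]


def trend_forecast(thetas: List[float], forecast_length: int) -> List[float]:
    """Polynomial basis: y(t) = sum_i thetas[i] * t**i, so len(thetas) = thetas_dim."""
    return [sum(theta * t ** i for i, theta in enumerate(thetas))
            for t in time_grid(forecast_length)]


def seasonality_forecast(thetas: List[float], forecast_length: int) -> List[float]:
    """Fourier basis (assumed layout): the first half of thetas weight cos(2*pi*k*t)
    for k = 0, 1, ..., the second half weight sin(2*pi*k*t) for k = 1, 2, ..."""
    n_cos = len(thetas) // 2
    cos_w, sin_w = thetas[:n_cos], thetas[n_cos:]
    out = []
    for t in time_grid(forecast_length):
        value = sum(w * math.cos(2 * math.pi * k * t) for k, w in enumerate(cos_w))
        value += sum(w * math.sin(2 * math.pi * (k + 1) * t) for k, w in enumerate(sin_w))
        out.append(value)
    return out


# thetas_dim=2 trend: level 1.0 plus slope 2.0 over the normalised horizon.
trend = trend_forecast([1.0, 2.0], forecast_length=10)
print(trend[0], trend[5])  # → 1.0 2.0

# thetas_dim=8 seasonality: only the first cosine harmonic is active.
season = seasonality_forecast([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], forecast_length=10)
print(round(season[0], 6), round(season[5], 6))  # → 1.0 -1.0
```

Constraining each block to a fixed basis is what makes the stack outputs interpretable: the trend stack can only produce smooth polynomial curves, and the seasonality stack can only produce periodic ones.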
Download files
Source Distribution
nbeats-pytorch-1.1.0.tar.gz (3.6 kB)
Built Distribution
Hashes for nbeats_pytorch-1.1.0-py3-none-any.whl:
Algorithm | Hash digest
---|---
SHA256 | 88471e4a586269481cebe27c82a6c23ee34dc0d0d3c4c88ea40d6cd2ff27519b
MD5 | f87e5b666537cc09650c06e6d5e62aab
BLAKE2b-256 | c233c71146bd1b9863586fc33a3cf79f35e00a35ca613fe177344fda88819b5e
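The digests listed for the wheel can be used to check that a downloaded copy is intact. The following is a minimal sketch using only the standard library `hashlib` module; the helper names `sha256_of` and `matches_expected` are hypothetical. pip's hash-checking mode in a requirements file is an alternative way to enforce the same check at install time.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for nbeats_pytorch-1.1.0-py3-none-any.whl.
EXPECTED_SHA256 = "88471e4a586269481cebe27c82a6c23ee34dc0d0d3c4c88ea40d6cd2ff27519b"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large archives are never fully loaded into memory."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Compare case-insensitively, since digests are sometimes published in upper case."""
    return sha256_of(path) == expected.lower()
```

For example, `matches_expected(Path("nbeats_pytorch-1.1.0-py3-none-any.whl"))` should return True for an unmodified download of that wheel.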