N-Beats
N-BEATS: Neural basis expansion analysis for interpretable time series forecasting
- Implementation in PyTorch
- Implementation in Keras by @eljdos
- https://arxiv.org/abs/1905.10437
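The linked paper describes each block as predicting expansion coefficients (theta) and projecting them onto a fixed basis. Trend blocks use a low-degree polynomial in t = [0, 1, ..., H-1] / H, where H is the forecast horizon. Seasonality blocks use a Fourier series. The pure-Python sketch below shows that projection for the forecast side only. The helper names (`time_grid`, `trend_forecast`, `seasonality_forecast`) are hypothetical. The way theta is split between cosine and sine terms, with the number of harmonics taken from len(thetas), is an assumption and not necessarily what nbeats-pytorch or nbeats-keras does internally.

```python
import math


def time_grid(horizon):
    # t = [0, 1, ..., H-1] / H, the time vector used for the basis functions
    return [i / horizon for i in range(horizon)]


def trend_forecast(thetas, horizon):
    # Polynomial basis [1, t, t^2, ...]; len(thetas) = polynomial degree + 1,
    # so thetas_dim=2 gives a linear trend
    t = time_grid(horizon)
    return [sum(theta * (ti ** p) for p, theta in enumerate(thetas)) for ti in t]


def seasonality_forecast(thetas, horizon):
    # Fourier basis (assumed split): the first half of thetas weights
    # cos(2*pi*k*t) for k = 0, 1, ...; the second half weights
    # sin(2*pi*k*t) for k = 1, 2, ...
    n_cos = (len(thetas) + 1) // 2
    cos_thetas, sin_thetas = thetas[:n_cos], thetas[n_cos:]
    out = []
    for ti in time_grid(horizon):
        value = sum(c * math.cos(2 * math.pi * k * ti) for k, c in enumerate(cos_thetas))
        value += sum(s * math.sin(2 * math.pi * (k + 1) * ti) for k, s in enumerate(sin_thetas))
        out.append(value)
    return out


# A linear trend 1.0 + 0.5 * t over a horizon of 4 steps
print(trend_forecast([1.0, 0.5], horizon=4))  # [1.0, 1.125, 1.25, 1.375]
```

In the real model the thetas come from fully connected layers rather than being chosen by hand. The basis itself is fixed, which is what makes the trend and seasonality stacks interpretable.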
[Figure: N-Beats at the beginning of the training]
Trust me, after a few more steps, the green curve (predictions) matches the ground truth exactly :-)
Installation
From PyPI
Install the Keras implementation: pip install nbeats-keras
Install the PyTorch implementation: pip install nbeats-pytorch
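After either pip install, you can check which implementation is importable in the current environment. The check asks the import system without importing the packages, so Keras or PyTorch are not loaded. This is a minimal standard-library sketch. The module names `nbeats_keras` and `nbeats_pytorch` are assumed from the package names, and `installed_backends` is a hypothetical helper.

```python
import importlib.util


def installed_backends():
    # Assumed import names for the nbeats-keras and nbeats-pytorch distributions
    modules = {"keras": "nbeats_keras", "pytorch": "nbeats_pytorch"}
    # find_spec returns None for a top-level module that cannot be found,
    # without executing the module itself
    return {backend: importlib.util.find_spec(name) is not None
            for backend, name in modules.items()}


print(installed_backends())
```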
From the sources
Installation is based on a Makefile. Make sure you are in a virtualenv and have python3 installed.
Command to install N-Beats with Keras: make install-keras
Command to install N-Beats with PyTorch: make install-pytorch
Example
Jupyter notebook: NBeats.ipynb. Launch it with: make run-jupyter
Model
The PyTorch and Keras implementations take the same model arguments:
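# Block-type identifiers. In the packages these are exposed as class
# attributes (e.g. NBeatsNet.TREND_BLOCK); the string values here are assumed.
TREND_BLOCK = 'trend'
SEASONALITY_BLOCK = 'seasonality'

class NBeatsNet:
    def __init__(self,
                 stack_types=[TREND_BLOCK, SEASONALITY_BLOCK],  # one stack per entry, in order
                 nb_blocks_per_stack=3,         # number of blocks in each stack
                 forecast_length=2,             # number of future points to predict
                 backcast_length=10,            # length of the input (lookback) window
                 thetas_dim=[2, 8],             # basis coefficients per block, one entry per stack
                 share_weights_in_stack=False,  # if True, the blocks of a stack share their weights
                 hidden_layer_units=128):       # width of the fully connected layers in each block
        pass  # signature only; the packages provide the actual implementation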
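The model reads backcast_length past points and predicts the next forecast_length points. Training data therefore has to be cut into (backcast, forecast) pairs. The sketch below shows one way to do this with a sliding window, assuming a single univariate series held in a Python list. `make_windows` is a hypothetical helper and is not part of either package.

```python
def make_windows(series, backcast_length=10, forecast_length=2):
    """Slide over a univariate series and return (backcast, forecast) pairs.

    Each backcast holds backcast_length consecutive points, and the matching
    forecast holds the forecast_length points that immediately follow it.
    """
    total = backcast_length + forecast_length
    pairs = []
    for start in range(len(series) - total + 1):
        window = series[start:start + total]
        pairs.append((window[:backcast_length], window[backcast_length:]))
    return pairs


series = list(range(15))
pairs = make_windows(series)  # defaults mirror backcast_length=10, forecast_length=2
print(len(pairs))  # 4
print(pairs[0])    # ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11])
```

A step of 1 uses every possible window. A larger step would give fewer, less correlated samples, at the cost of discarding data.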
With backcast_length=50 and forecast_length=10 (the other arguments as above), this configuration translates into the following model:
--- Model ---
| N-Beats
| -- Stack Trend (#0) (share_weights_in_stack=False)
     | -- TrendBlock(units=128, thetas_dim=2, backcast_length=50, forecast_length=10, share_thetas=True) at @4500902576
     | -- TrendBlock(units=128, thetas_dim=2, backcast_length=50, forecast_length=10, share_thetas=True) at @4779951944
     | -- TrendBlock(units=128, thetas_dim=2, backcast_length=50, forecast_length=10, share_thetas=True) at @4779952280
| -- Stack Seasonality (#1) (share_weights_in_stack=False)
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=50, forecast_length=10, share_thetas=True) at @4779952616
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=50, forecast_length=10, share_thetas=True) at @4779952952
     | -- SeasonalityBlock(units=128, thetas_dim=8, backcast_length=50, forecast_length=10, share_thetas=True) at @4779953288
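In the N-BEATS paper, the blocks listed above are chained by doubly residual stacking. Each block receives the residual left after the previous block's backcast is subtracted. The model forecast is the sum of all block forecasts. The sketch below shows only that data flow, in pure Python. Toy blocks stand in for TrendBlock and SeasonalityBlock. `forward` and `mean_block` are hypothetical names, and the real blocks are neural networks, not fixed functions.

```python
def forward(blocks, x, forecast_length):
    """Doubly residual stacking: blocks are applied in order, stack by stack.

    Each block maps its input to (backcast, forecast). The backcast is
    subtracted from the block input to form the next block's input, and
    every block's forecast is added to the running model forecast.
    """
    residual = list(x)
    forecast = [0.0] * forecast_length
    for block in blocks:
        b, f = block(residual)
        residual = [r - bi for r, bi in zip(residual, b)]
        forecast = [acc + fi for acc, fi in zip(forecast, f)]
    return residual, forecast


def mean_block(forecast_length):
    # Toy block: explains its input by the mean and forecasts that mean
    def block(x):
        m = sum(x) / len(x)
        return [m] * len(x), [m] * forecast_length
    return block


# Two stacks of three blocks each, matching the layout above (toy blocks only)
blocks = [mean_block(2) for _ in range(2 * 3)]
residual, forecast = forward(blocks, [1.0, 2.0, 3.0, 4.0], forecast_length=2)
print(residual)  # [-1.5, -0.5, 0.5, 1.5]
print(forecast)  # [2.5, 2.5]
```

The first toy block absorbs the level of the input. The later ones see a zero-mean residual and contribute nothing. This mirrors how, in the real model, later blocks only model what earlier blocks left unexplained.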
Download files
Source Distribution
nbeats-pytorch-1.3.1.tar.gz (3.8 kB)
Built Distribution
nbeats_pytorch-1.3.1-py3-none-any.whl
Hashes for nbeats_pytorch-1.3.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | fdc5ced78c2be27540303221216279148d891541537e458029094bbd376d7a98
MD5 | 8c2a225ce8a77593e10b7e44e93e78c5
BLAKE2b-256 | 876acc613effd73a55d89a44e6d239f95cca6d5a0e0c5ab98464399f1a1406b4