N-Beats
NBEATS
Neural basis expansion analysis for interpretable time series forecasting
TensorFlow/PyTorch implementation | Paper | Results
[Figure: Outputs of the generic and interpretable layers]
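The core idea of the N-BEATS paper can be illustrated without either backend. In the sketch below, each block runs a small ReLU MLP on its input window and predicts expansion coefficients (thetas) for a backcast and a forecast. It then projects those coefficients onto a basis. The backcast is subtracted from the block's input before that input is passed to the next block, and the block forecasts are summed ("doubly residual stacking"). The polynomial trend basis (rows t**i with t = arange(length) / length) follows the interpretable trend block described in the paper. The sketch uses NumPy only, with untrained random weights. Several choices are simplifications for illustration: the helper names, the two hidden layers (the paper uses four per block), and the use of the trend basis in every block. This is not the code of nbeats-keras or nbeats-pytorch, whose generic blocks learn their basis instead.

```python
import numpy as np

rng = np.random.default_rng(0)


def trend_basis(thetas_dim: int, length: int) -> np.ndarray:
    """Polynomial basis of shape (thetas_dim, length): row i is t**i, t = arange(length) / length."""
    t = np.arange(length) / length
    return np.stack([t ** i for i in range(thetas_dim)])


def make_block(backcast_length: int, thetas_dim: int, units: int) -> dict:
    """Random (untrained) weights for one block: a 2-layer ReLU MLP plus two theta heads."""
    return {
        "w1": rng.normal(scale=0.1, size=(backcast_length, units)),
        "w2": rng.normal(scale=0.1, size=(units, units)),
        "wb": rng.normal(scale=0.1, size=(units, thetas_dim)),  # backcast thetas
        "wf": rng.normal(scale=0.1, size=(units, thetas_dim)),  # forecast thetas
    }


def block_forward(params, x, backcast_basis, forecast_basis):
    """Map an input window to (backcast, forecast) through thetas and fixed bases."""
    h = np.maximum(x @ params["w1"], 0.0)
    h = np.maximum(h @ params["w2"], 0.0)
    theta_b = h @ params["wb"]  # (batch, thetas_dim)
    theta_f = h @ params["wf"]  # (batch, thetas_dim)
    return theta_b @ backcast_basis, theta_f @ forecast_basis


def nbeats_forward(blocks, x, backcast_basis, forecast_basis):
    """Doubly residual stacking: subtract each backcast, accumulate each forecast."""
    residual = x
    forecast = np.zeros((x.shape[0], forecast_basis.shape[1]))
    for params in blocks:
        backcast, block_forecast = block_forward(params, residual, backcast_basis, forecast_basis)
        residual = residual - backcast        # what the next block still has to explain
        forecast = forecast + block_forecast  # final forecast is the sum over blocks
    return residual, forecast


backcast_length, forecast_length, thetas_dim, units = 10, 3, 4, 64
Vb = trend_basis(thetas_dim, backcast_length)
Vf = trend_basis(thetas_dim, forecast_length)
blocks = [make_block(backcast_length, thetas_dim, units) for _ in range(2)]
x = rng.uniform(size=(5, backcast_length))
residual, forecast = nbeats_forward(blocks, x, Vb, Vf)
print(residual.shape, forecast.shape)  # (5, 10) (5, 3)
```

In a trained model, the thetas of an interpretable block can be read as polynomial coefficients, which is what makes its output plottable as a trend component.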
Installation
Both backends can be installed at the same time.
From PyPI
Install the TensorFlow/Keras backend: pip install nbeats-keras
Install the PyTorch backend: pip install nbeats-pytorch
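The two packages can coexist because they provide different top-level modules, nbeats_keras and nbeats_pytorch (the names imported in the example). A minimal sketch for checking which of them is importable in the current environment is shown below. The function name is hypothetical, and the check only uses the standard library.

```python
import importlib.util


def available_backends() -> dict:
    """Report which N-Beats backend modules can be imported in this environment."""
    # Module names as imported in the README example: nbeats_keras and nbeats_pytorch.
    return {
        "keras": importlib.util.find_spec("nbeats_keras") is not None,
        "pytorch": importlib.util.find_spec("nbeats_pytorch") is not None,
    }


print(available_backends())
```

Note that the PyPI distribution names use a hyphen (nbeats-keras), while the importable module names use an underscore (nbeats_keras).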
From the sources
Installation is based on a Makefile.
Command to install N-Beats with Keras: make install-keras
Command to install N-Beats with Pytorch: make install-pytorch
Run on the GPU
This step is likely unnecessary with recent versions of TensorFlow: since TensorFlow 2.1, the standard tensorflow package includes GPU support, and tensorflow-gpu is deprecated. On older versions, to force the use of the GPU with the Keras backend, run: pip uninstall -y tensorflow && pip install tensorflow-gpu
Example
Here is an example to get familiar with both backends. Note that only the Keras backend currently supports input_dim>1.
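Both backends take inputs shaped (num_samples, backcast_length, input_dim) and produce forecasts shaped (num_samples, forecast_length, output_dim). The NumPy-only sketch below reproduces the synthetic task used in the example, where the target is the mean of each window. It uses a smaller sample count so the shapes can be inspected quickly. The helper name make_mean_dataset is hypothetical. The input_dim=3 case shows the multivariate layout that only the Keras backend accepts.

```python
import numpy as np


def make_mean_dataset(num_samples, time_steps, input_dim, seed=0):
    """Synthetic task from the example: the target is the mean of each input window."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(size=(num_samples, time_steps, input_dim))
    y = np.mean(x, axis=1, keepdims=True)  # (num_samples, 1, input_dim)
    return x, y


for input_dim in (1, 3):  # input_dim=3 would only be usable with the Keras backend
    x, y = make_mean_dataset(num_samples=1_000, time_steps=10, input_dim=input_dim)
    print(input_dim, x.shape, y.shape)
# 1 (1000, 10, 1) (1000, 1, 1)
# 3 (1000, 10, 3) (1000, 1, 3)
```

Because keepdims=True keeps a length-1 time axis, y already has the (num_samples, forecast_length, output_dim) layout with forecast_length = 1.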
import warnings

import numpy as np

from nbeats_keras.model import NBeatsNet as NBeatsKeras
from nbeats_pytorch.model import NBeatsNet as NBeatsPytorch

warnings.filterwarnings(action='ignore', message='Setting attributes')


def main():
    # https://keras.io/layers/recurrent/
    # At the moment only Keras supports input_dim > 1. In the original paper, input_dim=1.
    num_samples, time_steps, input_dim, output_dim = 50_000, 10, 1, 1

    # This example is for both Keras and Pytorch. In practice, choose the one you prefer.
    for BackendType in [NBeatsKeras, NBeatsPytorch]:
        # NOTE: If you choose the Keras backend with input_dim>1, you have
        # to set the value here too (in the constructor).
        backend = BackendType(
            backcast_length=time_steps, forecast_length=output_dim,
            stack_types=(NBeatsKeras.GENERIC_BLOCK, NBeatsKeras.GENERIC_BLOCK),
            nb_blocks_per_stack=2, thetas_dim=(4, 4), share_weights_in_stack=True,
            hidden_layer_units=64
        )

        # Definition of the objective function and the optimizer.
        backend.compile(loss='mae', optimizer='adam')

        # Definition of the data. The problem to solve is to find f such that |f(x) - y| -> 0,
        # where f = np.mean.
        x = np.random.uniform(size=(num_samples, time_steps, input_dim))
        y = np.mean(x, axis=1, keepdims=True)

        # Split data into training and testing datasets.
        c = num_samples // 10
        x_train, y_train, x_test, y_test = x[c:], y[c:], x[:c], y[:c]
        test_size = len(x_test)

        # Train the model.
        print('Training...')
        backend.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=20, batch_size=128)

        # Save the model for later.
        backend.save('n_beats_model.h5')

        # Predict on the testing set (forecast).
        predictions_forecast = backend.predict(x_test)
        np.testing.assert_equal(predictions_forecast.shape, (test_size, backend.forecast_length, output_dim))

        # Predict on the testing set (backcast).
        predictions_backcast = backend.predict(x_test, return_backcast=True)
        np.testing.assert_equal(predictions_backcast.shape, (test_size, backend.backcast_length, output_dim))

        # Load the model.
        model_2 = BackendType.load('n_beats_model.h5')
        np.testing.assert_almost_equal(predictions_forecast, model_2.predict(x_test))


if __name__ == '__main__':
    main()
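The example draws independent random windows, so holding out the first 10% of samples is harmless. For a real univariate series, one common approach is to cut the series into overlapping windows and split chronologically, so the test set is the most recent data. This is an assumption on our part; neither package is claimed to provide it. The sketch below builds arrays in the (num_samples, backcast_length, 1) and (num_samples, forecast_length, 1) layout used above. The helper make_windows and the sine signal are hypothetical stand-ins.

```python
import numpy as np


def make_windows(series, backcast_length, forecast_length):
    """Turn a 1-D series into (x, y) pairs shaped for N-Beats:
    x: (num_samples, backcast_length, 1), y: (num_samples, forecast_length, 1)."""
    series = np.asarray(series, dtype=np.float64)
    num_samples = len(series) - backcast_length - forecast_length + 1
    if num_samples <= 0:
        raise ValueError("series is too short for the requested window lengths")
    x = np.stack([series[i:i + backcast_length] for i in range(num_samples)])
    y = np.stack([series[i + backcast_length:i + backcast_length + forecast_length]
                  for i in range(num_samples)])
    return x[..., np.newaxis], y[..., np.newaxis]


series = np.sin(np.arange(200) / 10.0)  # toy signal standing in for real data
x, y = make_windows(series, backcast_length=10, forecast_length=3)

# Chronological split: train on the past, evaluate on the most recent windows.
split = int(0.9 * len(x))
x_train, y_train, x_test, y_test = x[:split], y[:split], x[split:], y[split:]
print(x.shape, y.shape)  # (188, 10, 1) (188, 3, 1)
```

These arrays would correspond to constructing the model with backcast_length=10 and forecast_length=3 instead of the values in the example.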
Browse the examples for more, including Jupyter notebooks.
To run the Jupyter notebook NBeats.ipynb: make run-jupyter
Citation
@misc{NBeatsPRemy,
author = {Philippe Remy},
title = {N-BEATS: Neural basis expansion analysis for interpretable time series forecasting},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/philipperemy/n-beats}},
}
Contributors
Thank you!
File details
Details for the file nbeats-keras-1.8.0.tar.gz.
File metadata
- Download URL: nbeats-keras-1.8.0.tar.gz
- Upload date:
- Size: 7.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | d5039465270f9232621ce29bd90d5b463016175f0b46a5b8c91d92e96cb3e6e2
MD5 | 784cd61b587a44feaaccde9d2a4893f2
BLAKE2b-256 | cb8b312e9eab0b7d3e5a36999c9d56763baccd33ecf782a1099f0e02485d38da
File details
Details for the file nbeats_keras-1.8.0-py3-none-any.whl.
File metadata
- Download URL: nbeats_keras-1.8.0-py3-none-any.whl
- Upload date:
- Size: 7.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 524482d8d29aa35aa53b1117b978cc2c091e50940a7574751bd747948c43af26
MD5 | 0c42cd1b481ca9c5ec0f97966af32775
BLAKE2b-256 | 841b537ec995eb347daea41a0293b8894d768293aca183647d4ec6969d685e8c
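A downloaded file can be checked against the digests listed above with the standard library. The sketch below is a minimal example with hypothetical function names. It reads the file in chunks. SHA256 and MD5 go through hashlib.new. BLAKE2b-256 is special-cased, because it means BLAKE2b with a 32-byte digest, whereas hashlib's default blake2b produces 64 bytes.

```python
import hashlib


def file_digest_hex(path, algorithm="sha256", chunk_size=1 << 16):
    """Compute a file's hex digest in chunks, so large files need not fit in memory."""
    if algorithm.lower() == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)  # BLAKE2b truncated to 256 bits
    else:
        h = hashlib.new(algorithm)  # e.g. "sha256", "md5"
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex, algorithm="sha256"):
    """Return True if the file's digest matches the expected hex string."""
    return file_digest_hex(path, algorithm) == expected_hex.strip().lower()


# SHA256 digest listed above for the wheel.
WHEEL_SHA256 = "524482d8d29aa35aa53b1117b978cc2c091e50940a7574751bd747948c43af26"
# Usage, once the wheel has been downloaded to the current directory:
# print(verify("nbeats_keras-1.8.0-py3-none-any.whl", WHEEL_SHA256))
```

pip can also enforce digests itself in hash-checking mode, using --hash=sha256:... entries in a requirements file together with --require-hashes.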