
N-Beats


N-BEATS: Neural basis expansion analysis for interpretable time series forecasting


N-Beats at the beginning of the training

Trust me, after a few more steps, the green curve (predictions) closely matches the ground truth :-)

Installation

From PyPI

Keras version: pip install nbeats-keras

PyTorch version: pip install nbeats-pytorch

From the sources

Installation is based on a Makefile. Make sure you are in a virtualenv and have Python 3 installed.

Command to install N-Beats with Keras: make install-keras

Command to install N-Beats with PyTorch: make install-pytorch
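
After either installation method, a quick way to see which backend is importable is the standard-library check below. It is a small sketch, not part of the package: it only looks up the module names nbeats_keras and nbeats_pytorch used in the toy example, and does not check versions. The helper name installed_backends is hypothetical.

```python
import importlib.util


def installed_backends():
    """Map each N-Beats backend module name to whether it can be imported."""
    # find_spec returns None (without raising) when a top-level module is missing.
    return {
        name: importlib.util.find_spec(name) is not None
        for name in ("nbeats_keras", "nbeats_pytorch")
    }


if __name__ == "__main__":
    print(installed_backends())
```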

Run on the GPU

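To make TensorFlow (used by the Keras version) run on the GPU, run: pip uninstall -y tensorflow && pip install tensorflow-gpu. This only applies to older TensorFlow releases: since TensorFlow 2.1, the regular tensorflow package includes GPU support, and the separate tensorflow-gpu package has since been discontinued.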

Example

Jupyter notebook: NBeats.ipynb (start it with make run-jupyter).

Here is a toy example of how to use the model (training and prediction) with both backends:

import warnings

import numpy as np

from nbeats_keras.model import NBeatsNet as NBeatsKeras
from nbeats_pytorch.model import NBeatsNet as NBeatsPytorch

warnings.filterwarnings(action='ignore', message='Setting attributes')


def main():
    # Inputs follow the Keras recurrent-layer shape convention (num_samples, time_steps, input_dim).
    # https://keras.io/layers/recurrent/
    num_samples, time_steps, input_dim, output_dim = 50_000, 10, 1, 1

    # Definition of the model.
    model_keras = NBeatsKeras(backcast_length=time_steps, forecast_length=output_dim,
                              stack_types=(NBeatsKeras.GENERIC_BLOCK, NBeatsKeras.GENERIC_BLOCK),
                              nb_blocks_per_stack=2, thetas_dim=(4, 4), share_weights_in_stack=True,
                              hidden_layer_units=64)

    model_pytorch = NBeatsPytorch(backcast_length=time_steps, forecast_length=output_dim,
                                  stack_types=(NBeatsPytorch.GENERIC_BLOCK, NBeatsPytorch.GENERIC_BLOCK),
                                  nb_blocks_per_stack=2, thetas_dim=(4, 4), share_weights_in_stack=True,
                                  hidden_layer_units=64)

    # Definition of the objective function and the optimizer.
    model_keras.compile(loss='mae', optimizer='adam')
    model_pytorch.compile(loss='mae', optimizer='adam')

    # Definition of the data. The problem to solve is to find f such that |f(x) - y| -> 0,
    # where f = np.mean (the mean over the time axis).
    x = np.random.uniform(size=(num_samples, time_steps, input_dim))
    y = np.mean(x, axis=1, keepdims=True)

    # Split data into training and testing datasets.
    c = num_samples // 10
    x_train, y_train, x_test, y_test = x[c:], y[c:], x[:c], y[:c]
    test_size = len(x_test)

    # Train the model.
    print('Keras training...')
    model_keras.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=20, batch_size=128)
    print('Pytorch training...')
    model_pytorch.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=20, batch_size=128)

    # Save the model for later.
    model_keras.save('n_beats_model.h5')
    model_pytorch.save('n_beats_pytorch.th')

    # Predict on the testing set (forecast).
    predictions_keras_forecast = model_keras.predict(x_test)
    predictions_pytorch_forecast = model_pytorch.predict(x_test)
    np.testing.assert_equal(predictions_keras_forecast.shape, (test_size, model_keras.forecast_length, output_dim))
    np.testing.assert_equal(predictions_pytorch_forecast.shape, (test_size, model_pytorch.forecast_length, output_dim))

    # Predict on the testing set (backcast).
    predictions_keras_backcast = model_keras.predict(x_test, return_backcast=True)
    predictions_pytorch_backcast = model_pytorch.predict(x_test, return_backcast=True)
    np.testing.assert_equal(predictions_keras_backcast.shape, (test_size, model_keras.backcast_length, output_dim))
    np.testing.assert_equal(predictions_pytorch_backcast.shape, (test_size, model_pytorch.backcast_length, output_dim))

    # Load the model.
    model_keras_2 = NBeatsKeras.load('n_beats_model.h5')
    model_pytorch_2 = NBeatsPytorch.load('n_beats_pytorch.th')

    np.testing.assert_almost_equal(predictions_keras_forecast, model_keras_2.predict(x_test))
    np.testing.assert_almost_equal(predictions_pytorch_forecast, model_pytorch_2.predict(x_test))


if __name__ == '__main__':
    main()
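
To judge whether the loss printed by fit is meaningful, it helps to know what a trivial predictor scores on this toy problem. Each target y is the mean of 10 independent U(0, 1) values, so its standard deviation is sqrt(1 / (12 * 10)) ≈ 0.091, and always predicting 0.5 gives an MAE of roughly 0.091 * sqrt(2 / pi) ≈ 0.073 (normal approximation). The plain NumPy sketch below, which is not part of the package, measures that baseline on data shaped like the example's; one would expect the validation MAE of a working model to end up well below it. The helper name constant_baseline_mae is hypothetical.

```python
import numpy as np


def constant_baseline_mae(num_samples=50_000, time_steps=10, seed=0):
    """MAE of always predicting 0.5 on the toy problem y = mean of x over time."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(size=(num_samples, time_steps, 1))  # same shape as x in the toy example
    y = np.mean(x, axis=1, keepdims=True)               # shape (num_samples, 1, 1)
    return float(np.mean(np.abs(y - 0.5)))


if __name__ == "__main__":
    print(f"constant 0.5 baseline MAE: {constant_baseline_mae():.4f}")
```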

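The constructor arguments above (stack_types, nb_blocks_per_stack, thetas_dim, share_weights_in_stack, hidden_layer_units) correspond to the architecture described in the N-BEATS paper. The following is a minimal NumPy sketch of the forward pass with generic blocks and random, untrained weights. It is written from the paper's description, not taken from nbeats_keras or nbeats_pytorch, so the library's internals may differ; the helper names (dense, make_generic_block, block_forward, nbeats_forward) are hypothetical, and the input is treated as univariate (input_dim=1 squeezed away).

```python
import numpy as np


def dense(n_in, n_out, rng):
    """Random weight matrix and zero bias for one fully connected layer."""
    w = rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out))
    return w, np.zeros(n_out)


def make_generic_block(backcast_length, forecast_length, theta_dim, units, rng):
    """Untrained weights for one generic block: 4 ReLU FC layers, 2 theta heads, 2 linear bases."""
    return {
        "fc": [dense(backcast_length, units, rng)] + [dense(units, units, rng) for _ in range(3)],
        "theta_b": dense(units, theta_dim, rng),
        "theta_f": dense(units, theta_dim, rng),
        # Generic basis: a learned linear map from theta to each horizon.
        "basis_b": dense(theta_dim, backcast_length, rng),
        "basis_f": dense(theta_dim, forecast_length, rng),
    }


def linear(params, h):
    w, b = params
    return h @ w + b


def block_forward(block, x):
    """Return (backcast, forecast) for one block; x has shape (batch, backcast_length)."""
    h = x
    for params in block["fc"]:
        h = np.maximum(linear(params, h), 0.0)  # ReLU
    theta_b = linear(block["theta_b"], h)
    theta_f = linear(block["theta_f"], h)
    return linear(block["basis_b"], theta_b), linear(block["basis_f"], theta_f)


def nbeats_forward(stacks, x, forecast_length):
    """Doubly residual stacking over a list of stacks, each a list of blocks."""
    residual = x
    forecast = np.zeros((x.shape[0], forecast_length))
    for blocks in stacks:
        for block in blocks:
            backcast, block_forecast = block_forward(block, residual)
            residual = residual - backcast        # next block sees what is still unexplained
            forecast = forecast + block_forecast  # forecasts are summed across all blocks
    return residual, forecast


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    backcast_length, forecast_length = 10, 1  # time_steps and output_dim in the toy example
    stacks = []
    for theta_dim in (4, 4):  # mirrors thetas_dim=(4, 4): one entry per stack
        shared = make_generic_block(backcast_length, forecast_length, theta_dim, 64, rng)
        stacks.append([shared, shared])  # nb_blocks_per_stack=2 with shared weights
    x = rng.uniform(size=(5, backcast_length))
    residual, forecast = nbeats_forward(stacks, x, forecast_length)
    print(residual.shape, forecast.shape)  # (5, 10) (5, 1)
```

The central design choice is the doubly residual stacking: each block only has to explain what earlier blocks left in the residual, and the final forecast is the sum of every block's partial forecast. Sharing weights within a stack simply reuses the same block parameters for each position in that stack.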
Citation

@misc{NBeatsPRemy,
  author = {Philippe Remy},
  title = {N-BEATS: Neural basis expansion analysis for interpretable time series forecasting},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/philipperemy/n-beats}},
}


Download files


Source Distribution

nbeats-pytorch-1.4.0.tar.gz (5.9 kB)

Uploaded Source

Built Distribution


nbeats_pytorch-1.4.0-py3-none-any.whl (6.8 kB)

Uploaded Python 3

File details

Details for the file nbeats-pytorch-1.4.0.tar.gz.

File metadata

  • Download URL: nbeats-pytorch-1.4.0.tar.gz
  • Upload date:
  • Size: 5.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/54.0.0 requests-toolbelt/0.9.1 tqdm/4.58.0 CPython/3.7.9

File hashes

Hashes for nbeats-pytorch-1.4.0.tar.gz
Algorithm Hash digest
SHA256 e0e2ef9a5ac6a9ea60fe031bd39339ac8aa7b42edcea2bd03a35285c7c7ee74b
MD5 49d06e5daa287c92b0c8cad59f9999c7
BLAKE2b-256 251880a676d0054f4368d505f83f71b24b89f763bf00cc16316c4f23308d9e35
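
To check a downloaded file against the digests above, a short standard-library sketch is enough. It assumes the sdist was saved as nbeats-pytorch-1.4.0.tar.gz in the current directory; swap in the wheel's file name and SHA256 from the next section to check that file instead. The helper name sha256_of is hypothetical.

```python
import hashlib
from pathlib import Path

# SHA256 published above for nbeats-pytorch-1.4.0.tar.gz.
EXPECTED_SHA256 = "e0e2ef9a5ac6a9ea60fe031bd39339ac8aa7b42edcea2bd03a35285c7c7ee74b"


def sha256_of(path, chunk_size=1 << 16):
    """Hash the file in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    path = Path("nbeats-pytorch-1.4.0.tar.gz")  # assumed download location
    if path.exists():
        print("hash OK" if sha256_of(path) == EXPECTED_SHA256 else "hash MISMATCH")
    else:
        print(f"{path} not found")
```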


File details

Details for the file nbeats_pytorch-1.4.0-py3-none-any.whl.

File metadata

  • Download URL: nbeats_pytorch-1.4.0-py3-none-any.whl
  • Upload date:
  • Size: 6.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/54.0.0 requests-toolbelt/0.9.1 tqdm/4.58.0 CPython/3.7.9

File hashes

Hashes for nbeats_pytorch-1.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1d69c07c8508a54d1b27b48d4e64a55fc308d26dd07943c7890e724c3e0a312a
MD5 1da2b482f063ea10f64c160d64fc9db7
BLAKE2b-256 3f1b2e7bb4b7ebb1d3d153f0551262da1372b0aa9ccdb91519341f49e1f28153

