
This library builds on nbeats-pytorch and performs univariate time series forecasting with N-BEATS.

Project description

NBEATS

N-BEATS: Neural basis expansion analysis for interpretable time series forecasting

NBEATS is a PyTorch-based library for deep-learning time series forecasting (https://arxiv.org/pdf/1905.10437v3.pdf). It builds on nbeats-pytorch and (https://github.com/amitesh863/nbeats_forecast), with modifications. It takes the architecture provided in (https://github.com/philipperemy/n-beats) and modifies it so that a 'No Residual' version of the architecture can be chosen in addition to the 'DRESS' version.

Dependencies: Python >=3.6

Installation

$ pip install NBEATS

Import

from NBEATS import NeuralBeats

Mandatory Parameters:

  • data
  • forecast_length

A basic model that uses only the mandatory parameters produces forecasts as shown below:

import pandas as pd
from NBEATS import NeuralBeats

data = pd.read_csv('test.csv')   
data = data.values        # (nx1 array)

model = NeuralBeats(data=data, forecast_length=5)
model.fit()
forecast = model.predict()
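The example expects data to be an n x 1 NumPy array. If the CSV holds more than one column (for instance a date column next to the values), data.values returns an n x k array instead. The helper below is a minimal sketch for coercing a single series into the n x 1 shape; to_column_array is a hypothetical name, not part of NBEATS, and the column name 'value' in the usage comment is likewise an assumption.

```python
import numpy as np


def to_column_array(values):
    """Return `values` as a float array of shape (n, 1).

    Accepts a list, a 1-D NumPy array, an (n, 1) array, or anything
    np.asarray understands (for example a single pandas column).
    """
    arr = np.asarray(values, dtype=float)
    if arr.ndim == 2 and arr.shape[1] == 1:
        return arr  # already n x 1
    if arr.ndim != 1:
        raise ValueError(f"expected a single series, got shape {arr.shape}")
    return arr.reshape(-1, 1)  # 1-D series -> n x 1 column


# Hypothetical usage, assuming the series lives in a column named 'value':
#   data = to_column_array(pd.read_csv('test.csv')['value'])
#   model = NeuralBeats(data=data, forecast_length=5)
series = to_column_array([112.0, 118.0, 132.0, 129.0])
print(series.shape)  # (4, 1)
```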

Optional parameters to the model object

Parameter               Default value
backcast_length         3 * forecast_length
Architecture            'DRESS' (the alternative is 'No Residual')
path                    ' ' (path where the intermediate training checkpoint is saved)
checkpoint_name         'NBEATS-checkpoint.th'
mode                    'cpu'
batch_size              len(data)/10
thetas_dims             [4, 8]
nb_blocks_per_stack     3
share_weights_in_stack  False
train_percent           0.8
save_model              False
hidden_layer_units      128
stack                   [1,1] (block types as in the paper: 1 = GENERIC_BLOCK, 2 = TREND_BLOCK, 3 = SEASONALITY_BLOCK)
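Several defaults in the table scale with the length of the data. The sketch below works them out for a given series length, assuming that batch_size = len(data)/10 is rounded down to an integer and that train_percent = 0.8 keeps 80% of the rows for training; neither rounding nor split behaviour is confirmed by the description, and documented_defaults is a hypothetical helper, not part of NBEATS.

```python
def documented_defaults(n_rows, forecast_length, train_percent=0.8):
    """Data-dependent defaults from the parameter table (a sketch)."""
    # backcast_length default: 3 * forecast_length (from the table).
    backcast_length = 3 * forecast_length
    # batch_size default: len(data)/10 (from the table). Integer division
    # and the minimum of 1 for very short series are assumptions.
    batch_size = max(1, n_rows // 10)
    # Rows used for training under train_percent; the exact split the
    # library performs is an assumption.
    train_rows = int(train_percent * n_rows)
    return {
        "backcast_length": backcast_length,
        "batch_size": batch_size,
        "train_rows": train_rows,
    }


print(documented_defaults(n_rows=120, forecast_length=5))
```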

Functions

fit()

Trains the model. The default arguments are epoch=25, optimiser=Adam, plot=True and verbose=True.

Example:

import torch

model.fit(epoch=25, optimiser=torch.optim.AdamW(model.parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False), plot=False, verbose=True)

predict(predict_data)

The predict_data argument is optional. If it is omitted, the forecast starts right after the training data: with training data up to month m and forecast_length=3, the prediction covers months m+1, m+2 and m+3. To forecast from a later starting point, pass a numpy array of shape (backcast_length, 1) holding the backcast_length observations immediately before the first month to be forecast. For example, with forecast_length=3 and the default backcast_length of 3 x forecast_length = 9, forecasting months m+3, m+4 and m+5 requires the 9 previous months, m-6 to m+2.
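The window described above can be cut from a longer history with plain NumPy slicing. This is a minimal sketch; make_predict_window is a hypothetical helper, not part of NBEATS, and it only builds the (backcast_length, 1) array that is then passed as predict_data.

```python
import numpy as np


def make_predict_window(series, backcast_length, first_forecast_idx):
    """Return the backcast window that precedes `first_forecast_idx`.

    The window holds the observations at positions
    first_forecast_idx - backcast_length ... first_forecast_idx - 1,
    reshaped to (backcast_length, 1).
    """
    arr = np.asarray(series, dtype=float).reshape(-1)
    start = first_forecast_idx - backcast_length
    if start < 0 or first_forecast_idx > len(arr):
        raise ValueError("not enough history before first_forecast_idx")
    return arr[start:first_forecast_idx].reshape(-1, 1)


# Worked version of the example: forecast_length = 3, default backcast_length = 9.
# Index i holds month m + (i - 9), so index 9 is month m and index 12 is m+3.
months = np.arange(-9, 6)        # month offsets m-9 ... m+5
history = months.astype(float)   # use the offset itself as a stand-in value
window = make_predict_window(history, backcast_length=9, first_forecast_idx=12)
print(window.ravel())            # offsets -6 ... 2, i.e. months m-6 to m+2
# forecast = model.predict(predict_data=window)
```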

Important note: backcast_length can be passed as a model argument together with forecast_length. The paper suggests values between 2 x forecast_length and 7 x forecast_length, so for forecast_length=3 the choices are backcast_length=6, 9, 12, ... up to 21. The default is 3 x forecast_length.
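The recommended range can be enumerated directly. A small sketch, with candidate_backcast_lengths as a hypothetical helper that steps through whole multiples of forecast_length, matching the 6, 9, ..., 21 list for forecast_length=3; the commented loop shows one possible sweep and is not a library feature.

```python
def candidate_backcast_lengths(forecast_length, low=2, high=7):
    """Backcast lengths from low * forecast_length to high * forecast_length."""
    return [k * forecast_length for k in range(low, high + 1)]


print(candidate_backcast_lengths(3))  # [6, 9, 12, 15, 18, 21]

# Hypothetical sweep, one model per backcast length:
# for bl in candidate_backcast_lengths(3):
#     model = NeuralBeats(data=data, forecast_length=3, backcast_length=bl)
```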

Returns forecasted values.

save(file) & load(file,optimizer):

save() writes the trained model to a file and load() restores a saved model.

Example: model.save('NBEATS.th') or model.load('NBEATS.th')

DEMO

The demo below uses a GENERIC_BLOCK stack (1) followed by a SEASONALITY_BLOCK stack (3), i.e. stack=[1,3]. See the paper for details. Experimenting with combinations of the three block types (GENERIC, TREND and SEASONALITY) may improve accuracy.
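One way to experiment with the block types is to enumerate every stack list of a given length and decode the numeric codes from the parameter table. A sketch; BLOCK_TYPES, describe_stack and candidate_stacks are hypothetical names, not part of NBEATS. The demo pairs each stack with one thetas_dims entry, so thetas_dims presumably needs the same length as stack (an assumption).

```python
from itertools import product

# Stack codes as listed in the parameter table.
BLOCK_TYPES = {1: "GENERIC_BLOCK", 2: "TREND_BLOCK", 3: "SEASONALITY_BLOCK"}


def describe_stack(stack):
    """Translate a stack list such as [1, 3] into block-type names."""
    return [BLOCK_TYPES[code] for code in stack]


def candidate_stacks(n_stacks=2):
    """Every stack list of length n_stacks built from the three codes."""
    return [list(combo) for combo in product(BLOCK_TYPES, repeat=n_stacks)]


print(describe_stack([1, 3]))   # ['GENERIC_BLOCK', 'SEASONALITY_BLOCK']
print(len(candidate_stacks()))  # 9
```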

import pandas as pd
from NBEATS import NeuralBeats
from torch import optim

data = pd.read_csv('test.csv')
data = data.values  # n x 1 numpy array

model = NeuralBeats(data=data, forecast_length=5, stack=[1, 3], nb_blocks_per_stack=3, thetas_dims=[3, 7])

# or load a previously saved model
# model.load(file='NBEATS.th')


# use a customised optimiser with its own parameters
model.fit(epoch=35, optimiser=optim.AdamW(model.parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False))
# or
# model.fit()

forecast = model.predict()
# or
# forecast = model.predict(predict_data=pred_data), where pred_data is a numpy array of shape (backcast_length, 1)
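To tell whether a different stack combination actually improves accuracy, the forecast has to be scored against data the model has not seen. A minimal sketch using sMAPE, one of the metrics reported in the N-BEATS paper; the smape function is written here for illustration and is not part of NBEATS, and treating 0/0 terms as zero error is an assumed convention. The commented hold-out lines are hypothetical usage.

```python
import numpy as np


def smape(actual, forecast):
    """Symmetric MAPE in percent: 200/H * sum(|y - y_hat| / (|y| + |y_hat|))."""
    y = np.asarray(actual, dtype=float).reshape(-1)
    y_hat = np.asarray(forecast, dtype=float).reshape(-1)
    denom = np.abs(y) + np.abs(y_hat)
    # Terms where both values are 0 count as zero error (assumed convention).
    ratio = np.divide(np.abs(y - y_hat), denom,
                      out=np.zeros_like(denom), where=denom != 0)
    return 200.0 * ratio.mean()


# Hypothetical hold-out check: train on all but the last forecast_length
# points, then compare the forecast with the held-out tail.
# model = NeuralBeats(data=data[:-5], forecast_length=5, stack=[1, 3])
# model.fit()
# score = smape(data[-5:], model.predict())
print(smape([100.0, 200.0], [110.0, 200.0]))
```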



Download files


Source Distribution

NBEATS-1.3.11.tar.gz (7.1 kB)

Uploaded Source

Built Distribution

NBEATS-1.3.11-py3-none-any.whl (7.6 kB)

Uploaded Python 3

File details

Details for the file NBEATS-1.3.11.tar.gz.

File metadata

  • Download URL: NBEATS-1.3.11.tar.gz
  • Upload date:
  • Size: 7.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/46.1.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for NBEATS-1.3.11.tar.gz
Algorithm Hash digest
SHA256 d78fed64b27916e85a3095a2f7dd1a74aedfb17c8b7fcd39dd05859db5d69bca
MD5 5ad562baeec384f0487f84082e3f6ea7
BLAKE2b-256 bb1f45045592aa0007544ae8d17005d155383f5fd0509ad9009adf1db96e3f0e


File details

Details for the file NBEATS-1.3.11-py3-none-any.whl.

File metadata

  • Download URL: NBEATS-1.3.11-py3-none-any.whl
  • Upload date:
  • Size: 7.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/46.1.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for NBEATS-1.3.11-py3-none-any.whl
Algorithm Hash digest
SHA256 0d5db68f6d4d05e1560ddc04638f73a21ce916ec89316e7f257ca64b8baf1c7d
MD5 d520dc8b21538a4286ff508e4dd8f436
BLAKE2b-256 9e7afcd20b30745d3e5fb9906cb0d43cded58abed635f1f380fc8cdb4302ec58

