This library builds on nbeats-pytorch to provide univariate time series forecasting with N-BEATS.
Project description
NBEATS
N-BEATS: Neural basis expansion analysis for interpretable time series forecasting
NBEATS is a PyTorch-based library for deep learning time series forecasting (https://arxiv.org/pdf/1905.10437v3.pdf) built on nbeats-pytorch. It modifies nbeats_forecast (https://github.com/amitesh863/nbeats_forecast) to eliminate errors that might arise at run time on some systems.
Dependencies: Python >=3.6
Installation
$ pip install NBEATS
Import
from NBEATS import NeuralBeats
Mandatory Parameters:
- data
- forecast_length
A basic model with only the mandatory parameters can be used to get forecasted values, as shown below:
import pandas as pd
from NBEATS import NeuralBeats
data = pd.read_csv('test.csv')
data = data.values  # n x 1 NumPy array
model = NeuralBeats(data=data, forecast_length=5)
model.fit()
forecast = model.predict()
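The model expects `data` as an n x 1 array. As a minimal sketch, assuming the CSV holds a single value column (the column name `value` and the in-memory CSV text are made up for illustration and stand in for `test.csv`), the shape can be checked before the model is constructed:

```python
import io

import pandas as pd

# Hypothetical single-column CSV standing in for 'test.csv'.
csv_text = "value\n10.0\n12.5\n11.0\n13.2\n"

frame = pd.read_csv(io.StringIO(csv_text))
data = frame.values  # 2-D NumPy array, one row per observation

# A univariate series should have exactly one column.
assert data.ndim == 2 and data.shape[1] == 1
print(data.shape)  # (4, 1)
```

If the file has more than one column, selecting a single one with `frame[['value']].values` keeps the n x 1 shape.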
Optional parameters to the model object
Parameter | Default value | Notes |
---|---|---|
backcast_length | 3 * forecast_length | |
path | ' ' | Path to save intermediate training checkpoint |
checkpoint_name | 'NBEATS-checkpoint.th' | |
mode | 'cpu' | |
batch_size | len(data)/10 | |
thetas_dims | [4, 8] | |
nb_blocks_per_stack | 3 | |
share_weights_in_stack | False | |
train_percent | 0.8 | |
save_model | False | |
hidden_layer_units | 128 | |
stack | [1, 1] | As per the paper. Mapping: 1 = GENERIC_BLOCK, 2 = TREND_BLOCK, 3 = SEASONALITY_BLOCK |
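To make the derived defaults concrete, here is a small sketch that reproduces the arithmetic from the table. It is illustrative only and is not the library's internal code: `describe_defaults` and `BLOCK_TYPES` are hypothetical names, and rounding the batch size down to a whole number is an assumption.

```python
# Illustrative only: mirrors the defaults listed in the table,
# not the library's internal code.
BLOCK_TYPES = {1: "GENERIC_BLOCK", 2: "TREND_BLOCK", 3: "SEASONALITY_BLOCK"}


def describe_defaults(n_samples, forecast_length, stack=(1, 1)):
    """Return the derived default settings for a series of n_samples points."""
    return {
        "backcast_length": 3 * forecast_length,
        # Assumption: the batch size is truncated to a whole number.
        "batch_size": n_samples // 10,
        # Translate the numeric stack codes into block names.
        "stack": [BLOCK_TYPES[code] for code in stack],
    }


settings = describe_defaults(n_samples=200, forecast_length=5, stack=[1, 3])
print(settings)
```

For a 200-point series with `forecast_length=5`, this gives a backcast length of 15 and a batch size of 20, with one generic and one seasonality stack.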
Functions
fit()
Trains the model. The default arguments are epoch=25, optimiser=Adam, plot=True, verbose=True.
Example (assumes torch has been imported):
model.fit(epoch=25, optimiser=torch.optim.AdamW(model.parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False), plot=False, verbose=True)
predict(predict_data)
The predict_data argument can be omitted or given as a NumPy array of shape backcast_length x 1. If no argument is passed and the training data runs to month m, the prediction covers months m+1, m+2 and m+3 when forecast_length=3.
To forecast from month m+3 onwards, provide a NumPy array holding the previous backcast_length months (3 x forecast_length by default). With forecast_length=3 that is the 9 (3 x 3) months m-6 to m+2, and the prediction then covers months m+3, m+4 and m+5.
Important note: backcast_length can be passed as a model argument along with forecast_length, e.g. backcast_length=6, 9, 12, ... up to 21 for forecast_length=3, as the paper suggests values between 2 x forecast_length and 7 x forecast_length. The default is 3 x forecast_length.
Returns forecasted values.
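The window arithmetic and the backcast-length range described above can be sketched as follows. This is a minimal illustration, not part of the library: `allowed_backcast_lengths` and `last_window` are hypothetical helper names, and the monthly values are made up.

```python
import numpy as np


def allowed_backcast_lengths(forecast_length):
    """Multiples of forecast_length from 2x to 7x, the range the note above mentions."""
    return [k * forecast_length for k in range(2, 8)]


def last_window(series, backcast_length):
    """Take the most recent backcast_length values as a (backcast_length, 1) array."""
    values = np.asarray(series, dtype=float).reshape(-1)
    if len(values) < backcast_length:
        raise ValueError("series is shorter than backcast_length")
    return values[-backcast_length:].reshape(-1, 1)


forecast_length = 3
backcast_length = 3 * forecast_length  # the documented default: 9

# Hypothetical monthly values; the last 9 play the role of months m-6 .. m+2.
series = np.arange(1.0, 13.0)
pred_data = last_window(series, backcast_length)

print(allowed_backcast_lengths(forecast_length))  # [6, 9, 12, 15, 18, 21]
print(pred_data.shape)  # (9, 1)
```

An array shaped like `pred_data` is what `model.predict(predict_data=pred_data)` expects.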
save(file) & load(file,optimizer):
Save the trained model to a file and load it back, respectively.
Example: model.save('NBEATS.th') or model.load('NBEATS.th')
DEMO
GENERIC_BLOCK (1) and SEASONALITY_BLOCK (3) stacks are used below (stack=[1,3]). See the paper for more details. Experimenting with the three block types (GENERIC, SEASONALITY and TREND) might improve accuracy.
import pandas as pd
from NBEATS import NeuralBeats
from torch import optim
data = pd.read_csv('test.csv')
data = data.values  # n x 1 NumPy array
model = NeuralBeats(data=data, forecast_length=5, stack=[1, 3], nb_blocks_per_stack=3, thetas_dims=[3, 7])
# or use a prebuilt model
# model.load(file='NBEATS.th')
# use a customised optimiser with parameters
model.fit(epoch=35, optimiser=optim.AdamW(model.parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False))
# or
# model.fit()
forecast = model.predict()
# or
# model.predict(predict_data=pred_data), where pred_data is a NumPy array of shape backcast_length x 1
File details
Details for the file NBEATS-1.3.10.tar.gz.
File metadata
- Download URL: NBEATS-1.3.10.tar.gz
- Upload date:
- Size: 5.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/46.1.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 64cfc2438602c4e699cf87dbee95136fa4effa738989deb2de46b9a8a3a17874 |
MD5 | 34678359b492d109530aca4a8c39ad7c |
BLAKE2b-256 | 47d31af5625b52cb9a30b6d1a3671a563630c1c9e20ab66dccda83b04f49824b |
File details
Details for the file NBEATS-1.3.10-py3-none-any.whl.
File metadata
- Download URL: NBEATS-1.3.10-py3-none-any.whl
- Upload date:
- Size: 5.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/46.1.1 requests-toolbelt/0.9.1 tqdm/4.28.1 CPython/3.7.1
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 74be0a84467192c2e6abde89ce5eaa915f6a773e401a9cb79014657fc0f2d906 |
MD5 | 187dc022df696994423ff5a4df479d2a |
BLAKE2b-256 | 502222c7162b772cada097fa07d4e4ee5a8a1d9920a1aa0dc22b2f93937ead00 |