
ARNet: Nonlinear autoregression with feed-forward neural networks (or any estimator really)

A Python port of nnetar.R: nonlinear autoregression with feed-forward neural networks. We made it a bit more generic, so that any estimator following scikit-learn's interface can be used, although the default is a single-hidden-layer MLPRegressor from scikit-learn. As in nnetar.R, the numbers of non-seasonal (p) and seasonal (P) lags to look back are chosen automatically if left unspecified. Likewise, the number of hidden neurons in the default estimator is chosen automatically, following the heuristic in nnetar.R.
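Conceptually, nnetar-style autoregression reduces to fitting an ordinary regressor on lagged copies of the series. The sketch below illustrates that reduction with a hypothetical `embed` helper (not part of arnet's API): with p lags, each row holds the p previous values, newest first, and the target is the current value. That is the kind of table any scikit-learn-style estimator can fit.

```python
from typing import List, Sequence, Tuple


def embed(y: Sequence[float], p: int) -> Tuple[List[List[float]], List[float]]:
    """Turn a series into (inputs, targets) pairs using p lagged values.

    Row t holds [y[t-1], y[t-2], ..., y[t-p]] and its target is y[t],
    so the first p observations only ever serve as inputs.
    """
    if p < 1 or p >= len(y):
        raise ValueError("need 1 <= p < len(y)")
    X = [[y[t - k] for k in range(1, p + 1)] for t in range(p, len(y))]
    targets = [y[t] for t in range(p, len(y))]
    return X, targets


X, targets = embed([1.0, 2.0, 3.0, 4.0, 5.0], p=2)
# X == [[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]], targets == [3.0, 4.0, 5.0]
```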

Installation

pip install arnet

Here are some snippets to illustrate the usage:

Fit-predict flow

from arnet import ARNet

y_train, y_test = ...

model = ARNet()
model.fit(y_train)
predictions = model.predict(n_steps=y_test.size)

Instantiate the model, fit it to the data and predict; that's all.
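Forecasting `n_steps` ahead requires inputs beyond the end of the training data. nnetar.R handles this recursively, feeding each one-step prediction back in as a lag for the next step. Assuming ARNet follows it here, the mechanism looks roughly like this; `recursive_forecast` and `predict_one` are hypothetical names, with `predict_one` standing in for a fitted estimator.

```python
from typing import Callable, List, Sequence


def recursive_forecast(
    history: Sequence[float],
    predict_one: Callable[[List[float]], float],
    p: int,
    n_steps: int,
) -> List[float]:
    """Forecast n_steps ahead by feeding each prediction back in as a lag.

    predict_one receives the p most recent values, newest first
    (matching the [y[t-1], ..., y[t-p]] row layout), and returns y[t].
    """
    window = list(history)
    forecasts = []
    for _ in range(n_steps):
        lags = [window[-k] for k in range(1, p + 1)]
        y_hat = predict_one(lags)
        forecasts.append(y_hat)
        window.append(y_hat)  # the forecast becomes an input for the next step
    return forecasts


# Toy one-step model: "next value = previous value + 1"
print(recursive_forecast([1.0, 2.0, 3.0], lambda lags: lags[0] + 1.0, p=1, n_steps=3))
# [4.0, 5.0, 6.0]
```

Because later steps consume earlier forecasts, errors can compound as the horizon grows.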

If you have side information, i.e., exogenous regressors to help in prediction, you can supply them like so:

X_train, X_test, y_train, y_test = ...

model = ARNet()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
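A plausible way exogenous regressors enter the model, mirroring `xreg` in nnetar.R, is that their values at time t are appended to the lagged targets on each row. This is an assumption about ARNet's internals, and `embed_with_exog` is a hypothetical helper. It would also explain why `predict(X_test)` takes no `n_steps`: presumably one future row of `X_test` is consumed per forecast step.

```python
from typing import List, Sequence, Tuple


def embed_with_exog(
    y: Sequence[float], X: Sequence[Sequence[float]], p: int
) -> Tuple[List[List[float]], List[float]]:
    """Build rows of [p lags of y] + [exogenous values observed at time t]."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    rows = [
        [y[t - k] for k in range(1, p + 1)] + list(X[t])
        for t in range(p, len(y))
    ]
    targets = [y[t] for t in range(p, len(y))]
    return rows, targets


rows, targets = embed_with_exog([10.0, 20.0, 30.0], [[1.0], [2.0], [3.0]], p=1)
# rows == [[10.0, 2.0], [20.0, 3.0]], targets == [20.0, 30.0]
```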

If you have seasonality in the series:

# `P` (the number of seasonal lags) is left unspecified, so it is chosen automatically
model = ARNet(seasonality=12)
model.fit(...)
predictions = model.predict(...)
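With a seasonal period m (e.g. `seasonality=12` for monthly data), nnetar.R feeds the model the last p values together with the values at the same season in the previous P cycles. In nnetar.R, P defaults to 1. The lag set can be sketched as follows, assuming ARNet uses the same scheme; `ar_lags` is a hypothetical helper.

```python
from typing import List


def ar_lags(p: int, P: int = 0, m: int = 1) -> List[int]:
    """Lag indices used as inputs: 1..p plus the seasonal lags m, 2m, ..., P*m."""
    seasonal = {m * k for k in range(1, P + 1)} if m > 1 else set()
    # A set union avoids duplicates when p >= m (e.g. p=13, m=12)
    return sorted(set(range(1, p + 1)) | seasonal)


print(ar_lags(p=2, P=1, m=12))  # [1, 2, 12]
print(ar_lags(p=3, P=2, m=4))   # [1, 2, 3, 4, 8]
```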

The default base model is an MLPRegressor; any other regressor that follows scikit-learn's interface can be passed instead:

from sklearn.ensemble import RandomForestRegressor

model = ARNet(RandomForestRegressor(max_features=0.9))
model.fit(...).predict(...)
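For illustration only, here is a hypothetical, dependency-free regressor that implements scikit-learn's estimator contract by hand: `fit(X, y)` returning `self`, `predict(X)`, plus `get_params`/`set_params`. The latter pair matters if you tune the estimator through `.validate` with `estimator__...` keys, assuming ARNet relies on scikit-learn's parameter handling.

```python
import math
from typing import Any, Dict, List, Sequence


class KNNMeanRegressor:
    """Toy regressor: predicts the mean target of the k nearest training rows.

    It implements the duck-typed scikit-learn contract only:
    fit(X, y) -> self, predict(X), get_params(), set_params(**params).
    """

    def __init__(self, k: int = 3):
        self.k = k

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        return {"k": self.k}

    def set_params(self, **params: Any) -> "KNNMeanRegressor":
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X: Sequence[Sequence[float]], y: Sequence[float]) -> "KNNMeanRegressor":
        # Memorise the training data; fitted attributes end with "_" by convention
        self.X_ = [list(row) for row in X]
        self.y_ = list(y)
        return self

    def predict(self, X: Sequence[Sequence[float]]) -> List[float]:
        preds = []
        for row in X:
            # Sort training indices by Euclidean distance to this row
            order = sorted(range(len(self.X_)), key=lambda i: math.dist(row, self.X_[i]))
            nearest = order[: self.k]
            preds.append(sum(self.y_[i] for i in nearest) / len(nearest))
        return preds
```

In practice you would rather subclass `sklearn.base.BaseEstimator` and `RegressorMixin`, which provide the parameter methods and return NumPy arrays that ARNet may expect.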

In fact, if you use LinearRegression, you effectively recover the linear AR(p) model as a special case! Here is an example of that.
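To see why: with p lags as inputs, a linear regressor learns y_t = c + φ_1 y_{t-1} + ... + φ_p y_{t-p}, which is exactly the linear AR(p) equation. The pure-Python sketch below uses a hypothetical `fit_ar1` helper for p = 1. It runs ordinary least squares on the lagged pairs (y_{t-1}, y_t) and recovers the coefficients of a noise-free AR(1) series.

```python
from typing import Sequence, Tuple


def fit_ar1(y: Sequence[float]) -> Tuple[float, float]:
    """Least-squares fit of y[t] = c + phi * y[t-1], i.e. a linear regression on one lag."""
    x = y[:-1]  # inputs: y[t-1]
    z = y[1:]   # targets: y[t]
    n = len(x)
    x_mean = sum(x) / n
    z_mean = sum(z) / n
    cov = sum((a - x_mean) * (b - z_mean) for a, b in zip(x, z))
    var = sum((a - x_mean) ** 2 for a in x)
    phi = cov / var
    c = z_mean - phi * x_mean
    return c, phi


# Noise-free series generated by y[t] = 1 + 0.5 * y[t-1]
series = [0.0]
for _ in range(10):
    series.append(1.0 + 0.5 * series[-1])
c, phi = fit_ar1(series)  # c is approximately 1.0, phi approximately 0.5
```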

Prediction intervals

Following the procedure here, the model can also produce prediction intervals for future values at given confidence levels (in percent):

model = ARNet()
model.fit(y_train)
predictions, intervals = model.predict(n_steps=y_test.size, return_intervals=True, alphas=[80, 95])

You can also obtain the simulated paths by passing return_paths=True.
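nnetar.R builds these intervals by simulation. It generates many future sample paths, adding a random error at every recursive step (normally distributed errors or resampled residuals), and then reads off percentiles at each horizon. Below is a simplified sketch of the resampled-residual variant. Whether ARNet matches it in detail is an assumption; `simulate_intervals` is a hypothetical helper whose `paths` output loosely corresponds to `return_paths=True`.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple


def simulate_intervals(
    history: Sequence[float],
    predict_one: Callable[[List[float]], float],
    residuals: Sequence[float],
    p: int,
    n_steps: int,
    levels: Sequence[int] = (80, 95),
    n_paths: int = 1000,
    seed: int = 0,
) -> Tuple[List[List[float]], Dict[int, List[Tuple[float, float]]]]:
    """Simulate future paths with bootstrapped residuals; return paths and percentile intervals."""
    rng = random.Random(seed)
    paths = []
    for _ in range(n_paths):
        window = list(history)
        path = []
        for _ in range(n_steps):
            lags = [window[-k] for k in range(1, p + 1)]
            # One-step prediction plus a resampled in-sample residual
            value = predict_one(lags) + rng.choice(residuals)
            path.append(value)
            window.append(value)  # the noisy value drives the next step
        paths.append(path)

    intervals = {}
    for level in levels:
        tail = (100 - level) / 200  # e.g. 0.10 in each tail for an 80% interval
        lo_idx = int(tail * (n_paths - 1))
        hi_idx = int((1 - tail) * (n_paths - 1))
        bounds = []
        for h in range(n_steps):
            # Simple empirical percentiles of the simulated values at horizon h
            values = sorted(sim[h] for sim in paths)
            bounds.append((values[lo_idx], values[hi_idx]))
        intervals[level] = bounds
    return paths, intervals
```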

Validation

There is a .validate method that performs time series validation (expanding window) over a parameter grid, using either a full or a randomized search:

X, y = ...
param_grid = {"p": [2, 5, 8, None], "estimator__solver": ["sgd", "adam"]}
n_iter = -1  # -1 (or None) means full grid search; any positive integer would mean a randomized search

search = ARNet.validate(y, param_grid, X=X, n_iter=n_iter)
best_model = search["best_estimator_"]
# Do something with the best model
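For intuition, the two ingredients of such a validation loop are sketched below. The first is expanding-window splits, where each fold trains on everything before its test block. The second is candidate generation from `param_grid`, mirroring the `n_iter` convention above. The fold sizes, the sampling scheme and the helpers `expanding_window_splits` and `candidates` are assumptions, not ARNet's documented behavior.

```python
import itertools
import random
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple


def expanding_window_splits(
    n: int, n_splits: int, test_size: int
) -> Iterator[Tuple[range, range]]:
    """Yield (train, test) index ranges; the training window grows each fold."""
    first_test_start = n - n_splits * test_size
    if first_test_start < 1:
        raise ValueError("series too short for this many folds")
    for i in range(n_splits):
        start = first_test_start + i * test_size
        yield range(0, start), range(start, start + test_size)


def candidates(
    param_grid: Dict[str, Sequence[Any]], n_iter: Optional[int], seed: int = 0
) -> List[Dict[str, Any]]:
    """Full grid when n_iter is None or -1, else n_iter random combinations."""
    keys = list(param_grid)
    combos = [dict(zip(keys, values)) for values in itertools.product(*param_grid.values())]
    if n_iter is None or n_iter == -1:
        return combos
    return random.Random(seed).sample(combos, min(n_iter, len(combos)))
```

Each candidate would then be fitted and scored on every fold, and the one with the best average error would be kept.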

Plotting

There is also a static helper method for plotting lines; it can be useful for visualizing the true values along with the predictions and intervals.

preds_in_sample = model.fitted_values_  # fitted values are available as a post-fit attribute
preds_out_sample, intervals = model.predict(n_steps=y_test.size, return_intervals=True, alphas=[80, 95])

ARNet.plot(lines=[y_train, preds_in_sample, y_test, preds_out_sample],
           labels=["y-train", "in-sample preds", "y-test", "out-sample preds"],
           true_indexes=[0, 2],
           intervals=intervals)

Here is an example plot output: (image: example plot)

For examples with a dataset in action, please see here; for the API reference, see here.



Download files


Source Distribution

arnet-0.0.1.tar.gz (24.1 kB)

Built Distribution

arnet-0.0.1-py3-none-any.whl (19.3 kB)

File details

Details for the file arnet-0.0.1.tar.gz.

File metadata

  • Download URL: arnet-0.0.1.tar.gz
  • Size: 24.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.0

File hashes

Hashes for arnet-0.0.1.tar.gz
Algorithm Hash digest
SHA256 31ddc4b64c83aec11cacf908ed3dfab324eadae0f629cf03f4bff564e9140a15
MD5 3d423a73347998b9176d22db7c53c8f6
BLAKE2b-256 e302cbe476d7e2295ad0dedd4ad92539e781b3df5b7567ae25198661b1488bce


File details

Details for the file arnet-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: arnet-0.0.1-py3-none-any.whl
  • Size: 19.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.0

File hashes

Hashes for arnet-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 44a8af6ec6b381557c5ad9ba99524bcae949a9cab50ade9a9937d1375abef703
MD5 bf4c6d1843aae1b88a8ea3b0b3c318dc
BLAKE2b-256 f86d4b4de476a840edda1a3087638ae3af5b231d49fc3cc4a65fae1da92cfcac

