ARNet: Nonlinear autoregression with feed-forward neural networks (or any estimator really)
A Python port of nnetar.R: nonlinear autoregression with feed-forward neural networks. We made it a bit more generic: any estimator that follows sklearn's interface can be used, though the default is a single-hidden-layer MLPRegressor from sklearn. Following nnetar.R, the non-seasonal (p) and seasonal (P) lags to look back are chosen automatically if left unspecified. The number of hidden neurons in the default estimator is also chosen automatically, following the heuristic in nnetar.R.
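To make the idea concrete, here is a minimal sketch of how an autoregressive design matrix with p non-seasonal and P seasonal lags can be built: lags 1, ..., p plus m, 2m, ..., Pm for a seasonal period m, as in nnetar.R's NNAR(p, P) inputs. The function make_lag_matrix is hypothetical and not part of ARNet's API; ARNet's internal feature construction may differ.

```python
import numpy as np


def make_lag_matrix(y, p, P=0, m=1):
    """Build an autoregressive design matrix from a 1-D series.

    Columns are y[t-1], ..., y[t-p] (non-seasonal lags) followed by
    y[t-m], ..., y[t-P*m] (seasonal lags); a lag already present is not repeated.
    Returns (X, target) where target[i] is y[t] for the i-th usable time t.
    """
    y = np.asarray(y, dtype=float)
    lags = list(range(1, p + 1))
    lags += [k * m for k in range(1, P + 1) if k * m not in lags]
    max_lag = max(lags)
    # For each lag, take y[t - lag] for t = max_lag, ..., len(y) - 1.
    X = np.column_stack([y[max_lag - lag: len(y) - lag] for lag in lags])
    return X, y[max_lag:]


y = np.arange(10.0)
X, target = make_lag_matrix(y, p=2, P=1, m=4)  # lags 1, 2 and 4
```

The resulting X and target can be passed to any sklearn-style regressor, for example MLPRegressor(hidden_layer_sizes=(k,)).fit(X, target), where k is the hidden-layer size.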
Installation:
pip install arnet
Here are some snippets to illustrate the usage:
Fit-predict flow
from arnet import ARNet
y_train, y_test = ...
model = ARNet()
model.fit(y_train)
predictions = model.predict(n_steps=y_test.size)
Instantiate the model, fit it to the data and predict; that's all.
If you have side information, i.e., exogenous regressors to help in prediction, you can supply them like so:
X_train, X_test, y_train, y_test = ...
model = ARNet()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
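Both predict calls above imply a recursive multi-step forecast: each one-step prediction is fed back in as the newest lag for the next step. The sketch below shows that loop for any object with a sklearn-style predict method. recursive_forecast and the toy HalfOfLast model are hypothetical, and the assumption that the horizon equals the number of rows in X_test is inferred from the snippet rather than documented.

```python
import numpy as np


def recursive_forecast(model, history, p, n_steps=None, X_future=None):
    """Forecast by feeding each prediction back in as the newest lag.

    model: any object whose predict(X) takes a 2-D array (sklearn style).
    Features per step: the last p values (most recent first), then the
    exogenous row for that step if X_future is given.
    """
    if X_future is not None:
        X_future = np.asarray(X_future, dtype=float)
        n_steps = len(X_future)  # horizon implied by the future regressors
    window = list(np.asarray(history, dtype=float)[-p:])
    preds = []
    for h in range(n_steps):
        features = window[::-1][:p]  # y[t-1], ..., y[t-p]
        if X_future is not None:
            features = features + list(X_future[h])
        y_next = float(model.predict(np.array([features]))[0])
        preds.append(y_next)
        window.append(y_next)  # the prediction becomes the newest lag
    return np.array(preds)


class HalfOfLast:
    """Toy stand-in for a fitted model: 0.5 * (lag 1) + sum of exogenous columns."""

    def predict(self, X):
        return 0.5 * X[:, 0] + X[:, 1:].sum(axis=1)


no_exog = recursive_forecast(HalfOfLast(), [4.0], p=1, n_steps=3)
with_exog = recursive_forecast(HalfOfLast(), [4.0], p=1, X_future=[[1.0], [0.0]])
```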
If you have seasonality in the series:
# `P` (seasonal lags) is left unspecified, so it will be chosen automatically
model = ARNet(seasonality=12)
model.fit(...)
predictions = model.predict(...)
The default base model is an MLPRegressor; if you want to use another one, that's okay too:
from sklearn.ensemble import RandomForestRegressor
model = ARNet(RandomForestRegressor(max_features=0.9))
model.fit(...).predict(...)
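Any object that follows sklearn's estimator conventions should work as the base model: hyperparameters stored unchanged in __init__, fit(X, y) returning self (which is what makes model.fit(...).predict(...) chain), and predict(X) returning a 1-D array. The hypothetical NumpyRidge below follows these conventions using only numpy. get_params and set_params are included because the estimator__solver key in the validation grid suggests sklearn-style nested parameter setting; exactly which methods ARNet calls is an assumption. In practice, inheriting from sklearn.base.BaseEstimator and RegressorMixin provides get_params, set_params and score automatically.

```python
import numpy as np


class NumpyRidge:
    """Hypothetical minimal ridge regressor following sklearn's estimator conventions."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha  # hyperparameters are stored as-is, no validation here

    def get_params(self, deep=True):
        return {"alpha": self.alpha}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        # Centering leaves the intercept unpenalized.
        self.x_mean_, self.y_mean_ = X.mean(axis=0), y.mean()
        Xc = X - self.x_mean_
        A = Xc.T @ Xc + self.alpha * np.eye(X.shape[1])
        self.coef_ = np.linalg.solve(A, Xc.T @ (y - self.y_mean_))
        return self  # returning self allows chaining fit(...).predict(...)

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        return (X - self.x_mean_) @ self.coef_ + self.y_mean_


X_demo = np.array([[0.0], [1.0], [2.0], [3.0]])
y_demo = 2 * X_demo[:, 0] + 1
demo_pred = NumpyRidge(alpha=0.0).fit(X_demo, y_demo).predict([[4.0]])
```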
In fact, if you use LinearRegression, you effectively get the linear AR(p) model as a special case! Here is an example of that.
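The reason is that LinearRegression fits ordinary least squares with an intercept (fit_intercept=True by default), so regressing y[t] on y[t-1], ..., y[t-p] gives the conditional least-squares estimate of an AR(p) model. The sketch below performs that regression with numpy.linalg.lstsq on a simulated AR(2) series and recovers the coefficients; fit_linear_ar is a hypothetical helper, not ARNet code.

```python
import numpy as np


def fit_linear_ar(y, p):
    """Fit y[t] = c + phi_1*y[t-1] + ... + phi_p*y[t-p] by ordinary least squares."""
    y = np.asarray(y, dtype=float)
    # Column k-1 holds y[t-k] for t = p, ..., len(y) - 1.
    lagged = np.column_stack([y[p - k: len(y) - k] for k in range(1, p + 1)])
    design = np.column_stack([np.ones(len(lagged)), lagged])  # intercept column
    coef, *_ = np.linalg.lstsq(design, y[p:], rcond=None)
    return coef[0], coef[1:]  # intercept, [phi_1, ..., phi_p]


# Simulate a stationary AR(2) process with known coefficients.
rng = np.random.default_rng(0)
phi_true = np.array([0.6, -0.3])
y = np.zeros(5000)
for t in range(2, len(y)):
    y[t] = phi_true[0] * y[t - 1] + phi_true[1] * y[t - 2] + rng.normal()

intercept, phi_hat = fit_linear_ar(y, p=2)  # phi_hat should be close to phi_true
```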
Prediction intervals
Following the procedure here, the model can produce prediction intervals for the future at given confidence levels:
model = ARNet()
model.fit(y_train)
predictions, intervals = model.predict(n_steps=y_test.size, return_intervals=True, alphas=[80, 95])
You can also obtain the simulated paths by passing return_paths=True.
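nnetar.R computes prediction intervals by simulation: it generates many future sample paths, adding a resampled (or normally distributed) error at each step and feeding the result back in as input, then takes percentiles across the paths. Whether ARNet uses exactly this scheme is an assumption here. In the sketch, simulate_paths, intervals_from_paths and the AR(1) stand-in ar1 are hypothetical; levels are given in percent, like alphas=[80, 95] above.

```python
import numpy as np


def simulate_paths(history, predict_one, residuals, n_steps, p, n_paths=1000, seed=0):
    """Simulate future sample paths by feeding noisy predictions back as inputs.

    predict_one: maps the last p values (oldest first) to a one-step forecast.
    residuals: in-sample one-step errors, resampled with replacement as noise.
    """
    rng = np.random.default_rng(seed)
    paths = np.empty((n_paths, n_steps))
    for i in range(n_paths):
        window = list(history[-p:])  # most recent p observations
        for h in range(n_steps):
            y_next = predict_one(np.array(window[-p:])) + rng.choice(residuals)
            paths[i, h] = y_next
            window.append(y_next)
    return paths


def intervals_from_paths(paths, levels=(80, 95)):
    """Return {level: (lower, upper)} with per-step percentile bounds."""
    out = {}
    for level in levels:
        tail = (100 - level) / 2  # probability mass left in each tail, in percent
        lower = np.percentile(paths, tail, axis=0)
        upper = np.percentile(paths, 100 - tail, axis=0)
        out[level] = (lower, upper)
    return out


history = np.array([1.0, 0.8, 0.5])
residuals = np.array([-0.2, -0.1, 0.0, 0.1, 0.2])
ar1 = lambda w: 0.7 * w[-1]  # stand-in one-step model
paths = simulate_paths(history, ar1, residuals, n_steps=5, p=1)
bounds = intervals_from_paths(paths, levels=(80, 95))
```

A point forecast can be taken as the mean or median of paths across axis 0.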
Validation
There is a .validate method to perform time series validation (expanding window) over a parameter grid, with either a full or a randomized search:
X, y = ...
param_grid = {"p": [2, 5, 8, None], "estimator__solver": ["sgd", "adam"]}
n_iter = -1 # -1 (or None) means full grid search; any positive integer would mean a randomized search
search = ARNet.validate(y, param_grid, X=X, n_iter=n_iter)
best_model = search["best_estimator_"]
# Do something with the best model
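The two ingredients of such a search can be sketched independently: expanding-window splits, where each fold trains on everything before its test window, and candidate generation from a parameter grid, either the full Cartesian product or a random subset of n_iter combinations. Both helpers are hypothetical illustrations of the technique, not ARNet's implementation. In a search, each candidate would be fit on every training window, scored on the test window that follows it, and the scores averaged.

```python
import itertools
import random


def expanding_window_splits(n_samples, n_splits=3, test_size=1):
    """Yield (train_idx, test_idx); the training window grows with each fold."""
    first_test_start = n_samples - n_splits * test_size
    if first_test_start < 1:
        raise ValueError("not enough samples for the requested splits")
    for k in range(n_splits):
        start = first_test_start + k * test_size
        yield list(range(start)), list(range(start, start + test_size))


def candidate_params(param_grid, n_iter=None, seed=0):
    """Full grid if n_iter is None or -1, else a random subset of n_iter combinations."""
    keys = sorted(param_grid)
    combos = [dict(zip(keys, values))
              for values in itertools.product(*(param_grid[k] for k in keys))]
    if n_iter is None or n_iter == -1 or n_iter >= len(combos):
        return combos
    return random.Random(seed).sample(combos, n_iter)


grid = {"p": [2, 5, 8, None], "estimator__solver": ["sgd", "adam"]}
splits = list(expanding_window_splits(10, n_splits=3, test_size=2))
full = candidate_params(grid)            # all 8 combinations
sampled = candidate_params(grid, n_iter=3)
```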
Plotting
There is also a static helper function for plotting lines; it can help visualize the true values along with the predictions and intervals.
preds_in_sample = model.fitted_values_ # fitted values are available as a post-fit attribute
preds_out_sample, intervals = model.predict(n_steps=y_test.size, return_intervals=True, alphas=[80, 95])
ARNet.plot(lines=[y_train, preds_in_sample, y_test, preds_out_sample],
labels=["y-train", "in-sample preds", "y-test", "out-sample preds"],
true_indexes=[0, 2],
intervals=intervals)
Here is an example plot output:
For examples with a dataset in action, please see here; for the API reference, see here.
File details
Details for the file arnet-0.0.1.tar.gz.
File metadata
- Download URL: arnet-0.0.1.tar.gz
- Upload date:
- Size: 24.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 31ddc4b64c83aec11cacf908ed3dfab324eadae0f629cf03f4bff564e9140a15
MD5 | 3d423a73347998b9176d22db7c53c8f6
BLAKE2b-256 | e302cbe476d7e2295ad0dedd4ad92539e781b3df5b7567ae25198661b1488bce
File details
Details for the file arnet-0.0.1-py3-none-any.whl.
File metadata
- Download URL: arnet-0.0.1-py3-none-any.whl
- Upload date:
- Size: 19.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 44a8af6ec6b381557c5ad9ba99524bcae949a9cab50ade9a9937d1375abef703
MD5 | bf4c6d1843aae1b88a8ea3b0b3c318dc
BLAKE2b-256 | f86d4b4de476a840edda1a3087638ae3af5b231d49fc3cc4a65fae1da92cfcac