Implementation of the State-Adaptive Neurofuzzy Inference System (S-ANFIS) network
sanfis
This is a PyTorch-based implementation of my project S-ANFIS: State-ANFIS: A Generalized Regime-Switching Model for Financial Modeling (2022). S-ANFIS is a generalization of Jang's ANFIS: adaptive-network-based fuzzy inference system (1993). The implementation can easily be used to fit a plain ANFIS network as well.
1. What is S-ANFIS?
S-ANFIS is a simple generalization of the ANFIS network in which the inputs to the premise part and to the consequence part of the model can be controlled separately. As general notation, I call the input to the premise part "state" variables s and the input to the consequence part "input" or "explanatory" variables x.
For an in-depth explanation, check out our paper.
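To make the separation concrete, here is a minimal pure-Python sketch of one forward pass of an ANFIS-style first-order Sugeno network in which the premise part sees only the state variables s and the consequence part sees only x. It assumes sigmoid membership functions, product firing strengths and one rule per combination of membership functions (grid partitioning, as in Jang's ANFIS). The names `sigmoid_mf` and `sanfis_forward` are hypothetical and not part of the sanfis API; the package's actual layers may differ in detail.

```python
import math
from itertools import product


def sigmoid_mf(v, c, gamma):
    """Sigmoid membership degree of scalar v (location c, steepness gamma)."""
    return 1.0 / (1.0 + math.exp(-gamma * (v - c)))


def sanfis_forward(s, x, mf_params, theta):
    """Hypothetical sketch of a first-order Sugeno forward pass with split inputs.

    s: state variables (premise part), one float per state variable
    x: explanatory variables (consequence part)
    mf_params: for each state variable, a list of (c, gamma) tuples
    theta: for each rule, len(x) + 1 coefficients (bias last)
    """
    # Premise part: membership degrees computed from the state variables only
    degrees = [[sigmoid_mf(v, c, g) for (c, g) in params]
               for v, params in zip(s, mf_params)]
    # Grid partitioning: one rule per combination of membership functions;
    # firing strength = product of the combined membership degrees
    strengths = [math.prod(combo) for combo in product(*degrees)]
    total = sum(strengths)
    weights = [w / total for w in strengths]  # normalized firing strengths
    # Consequence part: each rule is a linear function of x only
    consequents = [sum(t * xi for t, xi in zip(th[:-1], x)) + th[-1]
                   for th in theta]
    # Output: firing-strength-weighted average of the rule consequents
    return sum(w * f for w, f in zip(weights, consequents))


# Usage with two state variables, two sigmoids each (2 x 2 = 4 rules),
# and made-up consequent coefficients for illustration
mf_params = [[(0.0, -2.5), (0.0, 2.5)], [(0.0, -2.5), (0.0, 2.5)]]
theta = [[0.5, 0.1, 0.0], [-0.3, 0.2, 0.1], [0.2, -0.4, 0.0], [0.0, 0.0, 1.0]]
print(sanfis_forward(s=[1.0, -1.0], x=[0.3, 0.2], mf_params=mf_params, theta=theta))
```

Calling `sanfis_forward` with the same values for `s` and `x` reduces to the plain ANFIS case, where premise and consequence share one input; in general, s and x need not have the same length.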
2. Installation
This package is intended to be installed on top of PyTorch, so you need to do that first.
Step 1: Install PyTorch
Make sure to choose the installation command for your operating system: Windows, macOS (Intel / Apple Silicon) or Linux. All steps are explained on the PyTorch website.
To ensure that PyTorch was installed correctly, verify the installation by running sample PyTorch code:
import torch
x = torch.rand(5, 3)
print(x)
Step 2: Install sanfis
sanfis can be installed via pip:
pip install sanfis
3. Quick start
First, let's generate some data! The given example is an AR(2) process whose AR parameters depend on the regime of two independent state variables:
# Load modules
import numpy as np
import torch
from sanfis import SANFIS, plottingtools
from sanfis.datagenerators import sanfis_generator
# seed for reproducibility
np.random.seed(3)
torch.manual_seed(3)
## Generate Data ##
S, S_train, S_valid, X, X_train, X_valid, y, y_train, y_valid = sanfis_generator.gen_data_ts(
    n_obs=1000, test_size=0.33, plot_dgp=True)
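For intuition about this kind of data-generating process, the sketch below simulates a hypothetical regime-switching AR(2) process: two independent standard-normal state variables select, via their signs, which pair of AR coefficients drives y at each step. The coefficients, the noise level and the name `gen_regime_ar2` are made up for illustration and are not what `sanfis_generator.gen_data_ts` uses internally.

```python
import random


def gen_regime_ar2(n_obs, seed=3):
    """Hypothetical regime-switching AR(2) process driven by two state variables.

    The coefficients are illustrative only, not those of gen_data_ts.
    """
    rng = random.Random(seed)  # local RNG for reproducibility
    # (phi1, phi2) for each sign combination of the states (s1 > 0, s2 > 0);
    # |phi1| + |phi2| < 1 in every regime keeps the process stable
    regimes = {(True, True): (0.5, 0.2), (True, False): (-0.3, 0.1),
               (False, True): (0.2, -0.4), (False, False): (-0.6, 0.3)}
    y = [0.0, 0.0]  # two starting values for the AR(2) recursion
    states = []
    for _ in range(n_obs):
        s1, s2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)  # independent states
        phi1, phi2 = regimes[(s1 > 0, s2 > 0)]  # regime picks the AR parameters
        y.append(phi1 * y[-1] + phi2 * y[-2] + rng.gauss(0.0, 0.1))
        states.append((s1, s2))
    return states, y[2:]  # drop the two starting values


states, y = gen_regime_ar2(1000)
print(len(states), len(y))
```

Here the states play the role of s (premise part), while lagged values of y would be natural candidates for x (consequence part).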
Define a list of membership functions, with one entry for each state variable that enters the model:
# list of membership functions
membfuncs = [
    {'function': 'sigmoid',
     'n_memb': 2,
     'params': {'c': {'value': [0.0, 0.0],
                      'trainable': True},
                'gamma': {'value': [-2.5, 2.5],
                          'trainable': True}}},
    {'function': 'sigmoid',
     'n_memb': 2,
     'params': {'c': {'value': [0.0, 0.0],
                      'trainable': True},
                'gamma': {'value': [-2.5, 2.5],
                          'trainable': True}}}
]
The given example uses two sigmoid functions for each state variable.
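Assuming sanfis follows standard ANFIS grid partitioning (one rule per combination of membership functions across the state variables), this configuration yields 2 × 2 = 4 rules. The helpers `check_membfuncs` and `n_rules` below are hypothetical and not part of sanfis; they only illustrate two properties of the config format visible in the examples: every parameter lists one value per membership function, and the rule count grows multiplicatively with n_memb.

```python
import math

# Same structure as the membfuncs list above: two state variables,
# each with two sigmoid membership functions
membfuncs = [
    {'function': 'sigmoid', 'n_memb': 2,
     'params': {'c': {'value': [0.0, 0.0], 'trainable': True},
                'gamma': {'value': [-2.5, 2.5], 'trainable': True}}},
    {'function': 'sigmoid', 'n_memb': 2,
     'params': {'c': {'value': [0.0, 0.0], 'trainable': True},
                'gamma': {'value': [-2.5, 2.5], 'trainable': True}}},
]


def check_membfuncs(membfuncs):
    """Hypothetical check: each parameter needs exactly n_memb values."""
    for i, mf in enumerate(membfuncs):
        for name, spec in mf['params'].items():
            if len(spec['value']) != mf['n_memb']:
                raise ValueError(f"state variable {i}: '{name}' has "
                                 f"{len(spec['value'])} values, "
                                 f"expected {mf['n_memb']}")


def n_rules(membfuncs):
    """Number of rules under grid partitioning: product of all n_memb."""
    return math.prod(mf['n_memb'] for mf in membfuncs)


check_membfuncs(membfuncs)
print(n_rules(membfuncs))  # 4
```

Because the rule count is a product, adding state variables or membership functions quickly increases the number of rules and thus the number of consequent parameters.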
Now create the model, fit and evaluate:
# make model / set loss function and optimizer
fis = SANFIS(membfuncs=membfuncs, n_input=2, scale='Std')
loss_function = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(fis.parameters(), lr=0.005)
# fit model
history = fis.fit([S_train, X_train, y_train], [S_valid, X_valid, y_valid],
                  optimizer, loss_function, epochs=1000)
# eval model
y_pred = fis.predict([S, X])
plottingtools.plt_prediction(y, y_pred,
                             save_path='img/sanfis_prediction.pdf')
# plottingtools.plt_learningcurves(history)
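To quantify the fit beyond the plot, you can compute the same metric the model was trained on. `torch.nn.MSELoss(reduction='mean')` averages the squared errors over all elements; the hypothetical `mse` helper below does the same for flat lists of floats, so y and y_pred would first need converting (for example with `.flatten().tolist()` if they are NumPy arrays or PyTorch tensors).

```python
def mse(y_true, y_pred):
    """Mean squared error over flat lists of floats (what MSELoss(reduction='mean') computes)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


print(mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # 1.3333333333333333
```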
4. Features
4.1 Membership functions
The implementation allows a flexible use of membership functions: for each variable that enters the premise part of the model, the type and number of membership functions can be chosen individually. Currently, three membership functions are implemented:
Gaussian
The Gaussian is described by two parameters: mu for the location and sigma for the width.
# Example
gaussian_membfunc = {'function': 'gaussian',
                     'n_memb': 3,  # 3 membership functions
                     'params': {'mu': {'value': [-2.0, 0.0, 1.5],
                                       'trainable': True},
                                'sigma': {'value': [1.0, 0.5, 1.0],
                                          'trainable': True}}
                     }
This example defines three Gaussian membership functions for a single state variable.
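A minimal sketch of how the parameters above translate into membership degrees, assuming the common Gaussian form exp(-(v - mu)^2 / (2 sigma^2)). Whether sanfis divides by 2 sigma^2 or by sigma^2 is not stated in this README, so treat the exact form as an assumption; `gaussian_mf` is a hypothetical name.

```python
import math


def gaussian_mf(v, mu, sigma):
    """Gaussian membership degree, assuming exp(-(v - mu)^2 / (2 * sigma^2))."""
    return math.exp(-((v - mu) ** 2) / (2.0 * sigma ** 2))


# Parameters from gaussian_membfunc above (three membership functions)
mus = [-2.0, 0.0, 1.5]
sigmas = [1.0, 0.5, 1.0]

v = 0.0  # evaluate all three membership functions at one point
degrees = [gaussian_mf(v, m, s) for m, s in zip(mus, sigmas)]
print([round(d, 4) for d in degrees])  # [0.1353, 1.0, 0.3247]
```

The degree is 1 exactly at mu and decays with distance from it; a smaller sigma (0.5 for the middle function) makes that decay faster.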
General bell-shaped
The general bell-shaped function is described by three parameters: a (width), b (shape) and c (location).
bell_membfunc = {'function': 'bell',
                 'n_memb': 2,
                 'params': {'c': {'value': [-1.5, 1.5],
                                  'trainable': True},
                            'a': {'value': [3.0, 1.0],
                                  'trainable': False},
                            'b': {'value': [1.0, 3.0],
                                  'trainable': False}}
                 }
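Assuming sanfis uses Jang's generalized bell function 1 / (1 + |(v - c) / a|^(2b)), the sketch below (hypothetical `bell_mf`, not the package's code) shows the role of each parameter: the degree is 1 at the center c and exactly 0.5 at c ± a, while a larger b makes the flanks steeper.

```python
def bell_mf(v, a, b, c):
    """Generalized bell membership degree: 1 / (1 + |(v - c) / a| ** (2 * b))."""
    return 1.0 / (1.0 + abs((v - c) / a) ** (2.0 * b))


# Parameters from bell_membfunc above, as (a, b, c) per membership function
params = [(3.0, 1.0, -1.5), (1.0, 3.0, 1.5)]
for a, b, c in params:
    # degree at the center c, then at the half-width point c + a
    print(bell_mf(c, a, b, c), bell_mf(c + a, a, b, c))
```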
Sigmoid
The sigmoid is described by two parameters: c (location) and gamma (steepness).
sigmoid_membfunc = {'function': 'sigmoid',
                    'n_memb': 2,
                    'params': {'c': {'value': [0.0, 0.0],
                                     'trainable': True},
                               'gamma': {'value': [-2.5, 2.5],
                                         'trainable': True}}
                    }
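Assuming the logistic form 1 / (1 + exp(-gamma (v - c))), the sketch below (hypothetical `sigmoid_mf`) evaluates the example parameters. Because gamma = -2.5 and gamma = 2.5 share the same location c = 0, the two memberships are mirror images and sum to 1 for every v, so they act as a soft split of the state variable into a "low" and a "high" regime.

```python
import math


def sigmoid_mf(v, c, gamma):
    """Sigmoid membership degree: 1 / (1 + exp(-gamma * (v - c)))."""
    return 1.0 / (1.0 + math.exp(-gamma * (v - c)))


# Parameters from sigmoid_membfunc above: same location, opposite steepness
cs, gammas = [0.0, 0.0], [-2.5, 2.5]
for v in [-2.0, 0.0, 2.0]:
    low, high = (sigmoid_mf(v, c, g) for c, g in zip(cs, gammas))
    # "low" dominates for negative v, "high" for positive v; they sum to 1
    print(v, round(low, 3), round(high, 3), round(low + high, 3))
```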
Remember to pass a list of membership functions as the membfuncs argument when creating the SANFIS object, e.g.:
MEMBFUNCS = [gaussian_membfunc, bell_membfunc, sigmoid_membfunc]
model = SANFIS(MEMBFUNCS, n_input=2)
model.plotmfs(bounds=[[-2.0, 2.0],   # plot bounds for first membfunc
                      [-4.0, 2.0],   # plot bounds for second membfunc
                      [-5.0, 5.0]],  # plot bounds for third membfunc
              save_path='img/membfuncs.pdf')
4.2 Tensorboard
TensorBoard provides the visualization and tooling needed for machine learning experimentation. Further information can be found in the TensorBoard documentation.
Step 1: Install TensorBoard
pip install tensorboard
Step 2: Enable TensorBoard logging during training
TensorBoard logging can be enabled via arguments of the fit() method, e.g.:
history = model.fit( ...
                    use_tensorboard=True,
                    logdir='logs/tb',
                    hparams_dict={}
                    )
Note that hparams_dict is an optional argument where you can store additional hyperparameters of your model, e.g. hparams_dict={'n_input': 2}.
Step 3: Open tensorboard
tensorboard --logdir=logs/tb
5. Using the plain vanilla ANFIS network
To use the plain vanilla ANFIS network, simply omit the state variables s when calling fit() and predict(). The same input X is then automatically used for both the premise and the consequence part of the model.
import torch
from sanfis import SANFIS, plottingtools
from sanfis.datagenerators import datagenerator  # assumed import path for `datagenerator`; adjust if it differs in your version
# Set 4 input variables with 3 Gaussian membership functions each
MEMBFUNCS = [
    {'function': 'gaussian',
     'n_memb': 3,
     'params': {'mu': {'value': [-0.5, 0.0, 0.5],
                       'trainable': True},
                'sigma': {'value': [1.0, 1.0, 1.0],
                          'trainable': True}}}
    for _ in range(4)  # a separate, identical dict for each input variable
]
# generate some data (Mackey-Glass chaotic time series)
X, X_train, X_valid, y, y_train, y_valid = datagenerator.gen_data(data_id='mackey',
                                                                  n_obs=2080, n_input=4)
# create model
model = SANFIS(membfuncs=MEMBFUNCS,
               n_input=4,
               scale='Std')
optimizer = torch.optim.Adam(params=model.parameters())
loss_function = torch.nn.MSELoss(reduction='mean')
# fit model
history = model.fit(train_data=[X_train, y_train],
                    valid_data=[X_valid, y_valid],
                    optimizer=optimizer,
                    loss_function=loss_function,
                    epochs=200,
                    )
# predict data
y_pred = model.predict(X)
# plot learning curves
plottingtools.plt_learningcurves(history, save_path='img/learning_curves.pdf')
# plot prediction
plottingtools.plt_prediction(y, y_pred, save_path='img/mackey_prediction.pdf')
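For reference, the sketch below generates a Mackey-Glass series with the widely used parameters tau = 17, beta = 0.2, gamma = 0.1, exponent 10 and x(0) = 1.2 (with x(t) = 0 for t < 0), using a simple unit-step discretization, and builds the four lagged inputs x(t-18), x(t-12), x(t-6), x(t) to predict x(t+6), the setup of Jang's ANFIS benchmark. This matches n_input=4 above, but it is an assumption and not a description of what `datagenerator.gen_data(data_id='mackey')` does; `mackey_glass` and `make_windows` are hypothetical names.

```python
def mackey_glass(n, tau=17, beta=0.2, gamma=0.1, power=10, x0=1.2):
    """Mackey-Glass series via the unit-step discretization
    x(t+1) = x(t) + beta * x(t - tau) / (1 + x(t - tau)**power) - gamma * x(t),
    assuming x(t) = 0 for t < 0."""
    x = [x0]
    for t in range(n - 1):
        lagged = x[t - tau] if t >= tau else 0.0  # delayed value, 0 before start
        x.append(x[t] + beta * lagged / (1.0 + lagged ** power) - gamma * x[t])
    return x


def make_windows(x, lags=(18, 12, 6, 0), horizon=6):
    """Inputs x(t-18), x(t-12), x(t-6), x(t) -> target x(t+horizon)."""
    start = max(lags)
    X, y = [], []
    for t in range(start, len(x) - horizon):
        X.append([x[t - lag] for lag in lags])
        y.append(x[t + horizon])
    return X, y


x = mackey_glass(2080)
X, y = make_windows(x)
print(len(X), len(X[0]))  # 2056 4
```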
6. Related work
- AnfisTensorflow2.0 by me
- bare-bones implementation of ANFIS (manual derivatives) by twmeggs
- PyTorch implementation by James Power
- simple ANFIS based on Tensorflow 1.15.2 by Santiago Cuervo
Contact
I am grateful for any feedback. If you have questions, please contact gregor.lenhard92@gmail.com
References
If you use my work, please cite it appropriately:
G. Lenhard and D. Maringer, "State-ANFIS: A Generalized Regime-Switching Model for Financial Modeling," 2022 IEEE Symposium on Computational Intelligence for Financial Engineering and Economics (CIFEr), 2022, pp. 1-8, doi: 10.1109/CIFEr52523.2022.9776208.
BibTex:
@INPROCEEDINGS{lenhard2022sanfis,
author={Lenhard, Gregor and Maringer, Dietmar},
booktitle={2022 IEEE Symposium on Computational Intelligence for Financial Engineering and Economics ({CIFEr})},
title={State-{ANFIS}: A Generalized Regime-Switching Model for Financial Modeling},
year={2022},
pages={1--8},
doi={10.1109/CIFEr52523.2022.9776208},
organization={IEEE}
}