
Implementation of the State-Adaptive Neurofuzzy Inference System (S-ANFIS) network


sanfis

This is a PyTorch-based implementation of my project S-ANFIS, presented in the paper State-ANFIS: A Generalized Regime-Switching Model for Financial Modeling (2022). S-ANFIS is a generalization of Jang's ANFIS: adaptive-network-based fuzzy inference system (1993). The implementation can easily be used to fit a plain ANFIS network as well.

1. What is S-ANFIS

S-ANFIS is a simple generalization of the ANFIS network in which the inputs to the premise part and to the consequence part of the model can be controlled separately. As general notation, I call the input to the premise part "state" variables s and the input to the consequence part "input" or "explanatory" variables x.
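To make the separation between s and x concrete, here is a minimal pure-Python sketch of a single S-ANFIS forward pass for one observation. It assumes the standard ANFIS layer structure of Jang (1993): a product t-norm over a grid of rules, normalized firing strengths, and first-order (linear) Takagi-Sugeno consequents with an intercept. It uses sigmoid membership functions only for brevity. The function names and argument layout are hypothetical and are not the sanfis API.

```python
import itertools
import math


def sigmoid_mf(v, c, gamma):
    """Sigmoid membership degree of scalar v (location c, steepness gamma)."""
    return 1.0 / (1.0 + math.exp(-gamma * (v - c)))


def sanfis_forward(s, x, mf_params, consequents):
    """Hypothetical sketch of one S-ANFIS forward pass.

    s           -- list of state variables (premise input only)
    x           -- list of explanatory variables (consequence input only)
    mf_params   -- per state variable, a list of (c, gamma) pairs
    consequents -- per rule, coefficients [b0, b1, ..., bn] (intercept first)
    """
    # Layer 1: membership degrees of each state variable
    degrees = [[sigmoid_mf(s_j, c, g) for (c, g) in params]
               for s_j, params in zip(s, mf_params)]
    # Layer 2: one rule per combination of membership functions (product t-norm)
    firing = [math.prod(combo) for combo in itertools.product(*degrees)]
    # Layer 3: normalized firing strengths
    total = sum(firing)
    weights = [w / total for w in firing]
    # Layer 4: linear consequent of each rule, evaluated on x (not on s)
    rule_out = [b[0] + sum(bi * xi for bi, xi in zip(b[1:], x))
                for b in consequents]
    # Layer 5: weighted sum of the rule outputs
    return sum(w * f for w, f in zip(weights, rule_out))


# Usage: two state variables with two sigmoid MFs each -> 2 * 2 = 4 rules;
# two explanatory variables -> 3 coefficients (intercept + 2 slopes) per rule.
mf_params = [[(0.0, -2.5), (0.0, 2.5)], [(0.0, -2.5), (0.0, 2.5)]]
consequents = [[0.0, 0.5, 0.2], [0.0, -0.5, 0.2], [0.0, 0.2, -0.5], [0.0, 0.8, -0.3]]
y_hat = sanfis_forward(s=[0.3, -1.2], x=[0.5, 0.25],
                       mf_params=mf_params, consequents=consequents)
```

In this sketch s only affects layers 1 to 3 (which rule is active), while x only enters the linear consequents in layer 4. Passing the same values for s and x reduces it to a plain ANFIS forward pass.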

[Figure: S-ANFIS architecture]

For an in-depth explanation, check out our paper.

2. Installation

This package is built on top of PyTorch, so you need to install PyTorch first.

Step 1: Install PyTorch

Make sure to select the installation command for your operating system: Windows, macOS (Intel / Apple Silicon) or Linux. Everything is explained on the PyTorch website.

To ensure that PyTorch was installed correctly, verify the installation by running sample PyTorch code:

import torch
x = torch.rand(5, 3)
print(x)

Step 2: Install sanfis

sanfis can be installed via pip:

pip install sanfis

3. Quick start

First, let's generate some data. The given example is an AR(2) process whose AR parameters depend on the regimes of two independent state variables:

# Load modules
import numpy as np
import torch
from sanfis import SANFIS, plottingtools
from sanfis.datagenerators import sanfis_generator

# seed for reproducibility
np.random.seed(3)
torch.manual_seed(3)
## Generate Data ##
S, S_train, S_valid, X, X_train, X_valid, y, y_train, y_valid = sanfis_generator.gen_data_ts(
    n_obs=1000, test_size=0.33, plot_dgp=True)

[Figure: S-ANFIS data generating process]
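For intuition about what such a regime-switching data generating process looks like, here is a hypothetical stand-alone sketch using only the standard library. It is not sanfis_generator.gen_data_ts: the state variables are assumed to be i.i.d. standard normal, the regime is assumed to be determined by the sign of each state variable, and the AR coefficients below are made-up illustration values.

```python
import random


def simulate_regime_ar2(n_obs, params, seed=3, noise_sd=0.1):
    """Simulate an AR(2) process whose coefficients switch with two state variables.

    params maps a regime (s1 > 0, s2 > 0) to AR coefficients (phi1, phi2).
    Hypothetical stand-in, not the package's generator.
    """
    rng = random.Random(seed)       # seeded for reproducibility
    y = [0.0, 0.0]                  # two starting values for the AR(2) recursion
    states = []
    for _ in range(n_obs):
        s1, s2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)  # independent states
        phi1, phi2 = params[(s1 > 0, s2 > 0)]               # active regime
        y.append(phi1 * y[-1] + phi2 * y[-2] + rng.gauss(0.0, noise_sd))
        states.append((s1, s2))
    return states, y[2:]            # drop the starting values


# Illustrative, individually stationary AR(2) coefficients per regime
PARAMS = {(False, False): (0.5, 0.2), (False, True): (-0.5, 0.2),
          (True, False): (0.2, -0.5), (True, True): (0.8, -0.3)}
states, y = simulate_regime_ar2(1000, PARAMS)

# Lagged explanatory variables x_t = (y_{t-1}, y_{t-2}), states s_t and target y_t
X = [(y[t - 1], y[t - 2]) for t in range(2, len(y))]
S = states[2:]
target = y[2:]
```

The two lags give n_input=2 explanatory variables, and the two state variables each need one membership-function specification.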

Next, define a list containing one membership-function specification for each state variable that enters the model:

# list of membership functions
membfuncs = [
    {'function': 'sigmoid',
     'n_memb': 2,
     'params': {'c': {'value': [0.0, 0.0],
                      'trainable': True},
                'gamma': {'value': [-2.5, 2.5],
                          'trainable': True}}},

    {'function': 'sigmoid',
     'n_memb': 2,
     'params': {'c': {'value': [0.0, 0.0],
                      'trainable': True},
                'gamma': {'value': [-2.5, 2.5],
                          'trainable': True}}}
]

The given example uses two sigmoid membership functions for each of the two state variables.
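Because each specification is a nested dictionary, a mismatch between n_memb and the number of parameter values is easy to introduce. The following hypothetical helper (not part of sanfis) checks a membfuncs list before it is passed to the model. The required parameter names are the ones listed for each membership function in this README; the helper itself is an assumption.

```python
# Parameter names per membership-function type, as listed in this README
REQUIRED_PARAMS = {
    'gaussian': {'mu', 'sigma'},
    'bell': {'a', 'b', 'c'},
    'sigmoid': {'c', 'gamma'},
}


def check_membfuncs(membfuncs):
    """Hypothetical check: raise ValueError on a malformed spec, else return the count."""
    for i, spec in enumerate(membfuncs):
        kind = spec['function']
        if kind not in REQUIRED_PARAMS:
            raise ValueError(f"spec {i}: unknown function {kind!r}")
        missing = REQUIRED_PARAMS[kind] - set(spec['params'])
        if missing:
            raise ValueError(f"spec {i}: missing parameters {sorted(missing)}")
        for name, p in spec['params'].items():
            # every parameter needs exactly one value per membership function
            if len(p['value']) != spec['n_memb']:
                raise ValueError(
                    f"spec {i}: '{name}' has {len(p['value'])} values, "
                    f"expected n_memb={spec['n_memb']}")
    return len(membfuncs)


sigmoid_spec = {'function': 'sigmoid',
                'n_memb': 2,
                'params': {'c': {'value': [0.0, 0.0], 'trainable': True},
                           'gamma': {'value': [-2.5, 2.5], 'trainable': True}}}
n_specs = check_membfuncs([sigmoid_spec, sigmoid_spec])  # one spec per state variable
```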

Now create the model, fit and evaluate:

# make model / set loss function and optimizer
fis = SANFIS(membfuncs=membfuncs, n_input=2, scale='Std')
loss_function = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(fis.parameters(), lr=0.005)

# fit model
history = fis.fit([S_train, X_train, y_train], [S_valid, X_valid, y_valid],
                  optimizer, loss_function, epochs=1000)
# eval model
y_pred = fis.predict([S, X])
plottingtools.plt_prediction(y, y_pred,
                             save_path='img/sanfis_prediction.pdf')
# plottingtools.plt_learningcurves(history)

[Figure: S-ANFIS prediction]

4. Features

4.1 Membership functions

The implementation allows flexible use of membership functions: for each variable that enters the premise part of the model, both the type and the number of membership functions can be chosen freely. Currently, three membership functions are implemented:

Gaussian

The Gaussian is described by two parameters: mu (location) and sigma (width).

# Example
gaussian_membfunc = {'function': 'gaussian',
                     'n_memb': 3,  # 3 membership functions
                     'params': {'mu': {'value': [-2.0, 0.0, 1.5],
                                       'trainable': True},
                                'sigma': {'value': [1.0, 0.5, 1.0],
                                          'trainable': True}}
                     }

In this example, three membership functions are considered.
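To see what the three Gaussians do, the sketch below evaluates them at x = 0. It assumes the common parameterization exp(-(x - mu)^2 / (2 sigma^2)); the package may scale sigma differently, so treat the exact formula as an assumption.

```python
import math


def gaussian_mf(x, mu, sigma):
    """Assumed Gaussian membership degree: peak 1 at x = mu, width set by sigma."""
    return math.exp(-((x - mu) ** 2) / (2.0 * sigma ** 2))


# Parameters of the three membership functions in the example above
mus = [-2.0, 0.0, 1.5]
sigmas = [1.0, 0.5, 1.0]
degrees_at_zero = [gaussian_mf(0.0, m, s) for m, s in zip(mus, sigmas)]
```

The second function (mu = 0.0) is fully active at x = 0, while the outer two contribute partially, with weights decaying with the distance to their centers.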

General bell-shaped

The general bell-shaped function is described by three parameters: a (width), b (shape) and c (location).

bell_membfunc = {'function': 'bell',
                 'n_memb': 2,
                 'params': {'c': {'value': [-1.5, 1.5],
                                  'trainable': True},
                            'a': {'value': [3.0, 1.0],
                                  'trainable': False},
                            'b': {'value': [1.0, 3.0],
                                  'trainable': False}}
                 }
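The roles of a, b and c follow from the generalized bell formula used in Jang's ANFIS, 1 / (1 + |(x - c) / a|^(2b)). The sketch below assumes sanfis uses this same form. The degree is 1 at x = c and exactly 0.5 at x = c ± a for any b, so a sets the width, while a larger b gives a flatter top and steeper flanks.

```python
def bell_mf(x, a, b, c):
    """Assumed generalized bell membership degree: 1 / (1 + |(x - c) / a|^(2b))."""
    return 1.0 / (1.0 + abs((x - c) / a) ** (2 * b))


# First membership function of the example: c = -1.5, a = 3.0, b = 1.0
peak = bell_mf(-1.5, a=3.0, b=1.0, c=-1.5)       # at the center
crossover = bell_mf(1.5, a=3.0, b=1.0, c=-1.5)   # at c + a
# Effect of the shape parameter b at the same distance from the center
flat_top = bell_mf(2.0, a=1.0, b=3.0, c=1.5)
round_top = bell_mf(2.0, a=1.0, b=1.0, c=1.5)
```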

Sigmoid

The sigmoid is described by two parameters: c (location) and gamma (steepness).

sigmoid_membfunc = {'function': 'sigmoid',
                    'n_memb': 2,
                    'params': {'c': {'value': [0.0, 0.0],
                                     'trainable': True},
                               'gamma': {'value': [-2.5, 2.5],
                                         'trainable': True}}
                    }
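Assuming the usual logistic form 1 / (1 + exp(-gamma (x - c))), the example pair (same location c = 0, steepness -2.5 and +2.5) splits the state variable into a "low" and a "high" regime whose degrees always sum to 1, because sigma(z) + sigma(-z) = 1. The sketch below checks this numerically; the formula is an assumption about the package, not taken from its source.

```python
import math


def sigmoid_mf(x, c, gamma):
    """Assumed sigmoid membership degree: 1 / (1 + exp(-gamma * (x - c)))."""
    return 1.0 / (1.0 + math.exp(-gamma * (x - c)))


grid = [-2.0, -0.5, 0.0, 0.5, 2.0]
low = [sigmoid_mf(x, 0.0, -2.5) for x in grid]    # decreasing: "low" regime
high = [sigmoid_mf(x, 0.0, 2.5) for x in grid]    # increasing: "high" regime
pair_sums = [l + h for l, h in zip(low, high)]    # complementary pair
```

This complementarity is why two sigmoids per state variable, as in the quick start, are a natural choice for a two-regime split.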

Remember to pass a list of membership functions as the membfuncs argument when creating the SANFIS object, e.g.:

MEMBFUNCS = [gaussian_membfunc, bell_membfunc, sigmoid_membfunc]
model = SANFIS(MEMBFUNCS, n_input=2)
model.plotmfs(bounds=[[-2.0, 2.0],  # plot bounds for first membfunc
                      [-4.0, 2.0],  # plot bounds for second membfunc
                      [-5.0, 5.0]],  # plot bounds for third membfunc
              save_path='img/membfuncs.pdf')

[Figure: membership functions]
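The number of membership functions per variable drives the model size. The sketch below assumes the standard ANFIS grid partition (one rule per combination of membership functions) and first-order consequents with an intercept; both are assumptions about how sanfis builds its rule base. Under them, the three specifications above (3, 2 and 2 membership functions) with n_input=2 give 12 rules.

```python
import math


def n_rules(n_memb_per_variable):
    """Grid partition: one rule per combination of membership functions."""
    return math.prod(n_memb_per_variable)


def n_consequent_params(n_rules_total, n_input):
    """First-order consequents: an intercept plus one slope per input, per rule."""
    return n_rules_total * (n_input + 1)


# MEMBFUNCS above: gaussian (3), bell (2), sigmoid (2), with n_input=2
rules = n_rules([3, 2, 2])
coefs = n_consequent_params(rules, n_input=2)
# Four variables with three membership functions each grow to 3**4 rules
rules_four_vars = n_rules([3] * 4)
```

Because the rule count is a product, it grows exponentially with the number of premise variables, which is one reason to keep the state vector s small.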

4.2 TensorBoard

TensorBoard provides the visualization and tooling needed for machine learning experimentation. Further information can be found in the TensorBoard documentation.

Step 1: Install TensorBoard

pip install tensorboard

Step 2: Enable TensorBoard during training

TensorBoard logging can be enabled via arguments of the fit() function, e.g.:

history = model.fit( ...
                    use_tensorboard=True,
                    logdir='logs/tb',
                    hparams_dict={}
                   )

Note that hparams_dict is an optional argument in which you can store additional hyperparameters of your model, e.g. hparams_dict={'n_input': 2}.

Step 3: Open TensorBoard

tensorboard --logdir=logs/tb

5. Using the plain vanilla ANFIS network

[Figure: ANFIS architecture]

To use the plain vanilla ANFIS network, simply omit the state variables s when calling fit(). The same input is then automatically used for both the premise and the consequence part of the model.
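One way to implement this "omit s" convention is to dispatch on the length of the data list passed in. The function below is a hypothetical sketch of such dispatch, not the code inside SANFIS.fit(): a three-element list is read as [S, X, y], and a two-element list as [X, y] with S set to X.

```python
def split_inputs(data):
    """Hypothetical: unpack [S, X, y] (S-ANFIS) or [X, y] (plain ANFIS)."""
    if len(data) == 3:
        S, X, y = data
    elif len(data) == 2:
        X, y = data
        S = X  # plain ANFIS: the premise part sees the same input as the consequence part
    else:
        raise ValueError("expected [S, X, y] or [X, y]")
    return S, X, y


S, X, y = split_inputs([[0.1, 0.2], [1.0]])  # two elements -> plain ANFIS
```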

# Set 4 input variables with 3 gaussian membership functions each
MEMBFUNCS = [
    {'function': 'gaussian',
     'n_memb': 3,
     'params': {'mu': {'value': [-0.5, 0.0, 0.5],
                       'trainable': True},
                'sigma': {'value': [1.0, 1.0, 1.0],
                          'trainable': True}}},

    {'function': 'gaussian',
     'n_memb': 3,
     'params': {'mu': {'value': [-0.5, 0.0, 0.5],
                       'trainable': True},
                'sigma': {'value': [1.0, 1.0, 1.0],
                          'trainable': True}}},

    {'function': 'gaussian',
     'n_memb': 3,
     'params': {'mu': {'value': [-0.5, 0.0, 0.5],
                       'trainable': True},
                'sigma': {'value': [1.0, 1.0, 1.0],
                          'trainable': True}}},

    {'function': 'gaussian',
     'n_memb': 3,
     'params': {'mu': {'value': [-0.5, 0.0, 0.5],
                       'trainable': True},
                'sigma': {'value': [1.0, 1.0, 1.0],
                          'trainable': True}}},

]

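# generate some data (Mackey-Glass chaotic time series)
# note: `datagenerator` is not imported above; import it from the sanfis package before this line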
X, X_train, X_valid, y, y_train, y_valid = datagenerator.gen_data(data_id='mackey',
                                                                  n_obs=2080, n_input=4)

# create model
model = SANFIS(membfuncs=MEMBFUNCS,
               n_input=4,
               scale='Std')
optimizer = torch.optim.Adam(params=model.parameters())
loss_functions = torch.nn.MSELoss(reduction='mean')

# fit model
history = model.fit(train_data=[X_train, y_train],
                    valid_data=[X_valid, y_valid],
                    optimizer=optimizer,
                    loss_function=loss_functions,
                    epochs=200,
                    )

# predict data
y_pred = model.predict(X)

# plot learning curves
plottingtools.plt_learningcurves(history, save_path='img/learning_curves.pdf')

# plot prediction
plottingtools.plt_prediction(y, y_pred, save_path='img/mackey_prediction.pdf')

[Figure: learning curves]

[Figure: prediction of the Mackey-Glass time series]
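The Mackey-Glass data above comes from the package's generator. For readers who want to build a similar dataset themselves, here is a hedged stand-alone sketch. It integrates the Mackey-Glass delay equation dx/dt = beta x(t - tau) / (1 + x(t - tau)^n) - gamma x(t) with the commonly used values beta = 0.2, gamma = 0.1, n = 10, tau = 17, using a simple unit-step Euler scheme (cruder than the Runge-Kutta integration often used in the literature). It then builds the classic four-input design from Jang (1993): x(t-18), x(t-12), x(t-6), x(t) to predict x(t+6). Whether gen_data(data_id='mackey') uses exactly this discretization or lag structure is an assumption.

```python
def mackey_glass(n_steps, tau=17, beta=0.2, gamma=0.1, n=10, x0=1.2):
    """Unit-step Euler discretization of the Mackey-Glass delay equation."""
    x = [x0]
    for t in range(n_steps - 1):
        x_lag = x[t - tau] if t >= tau else 0.0  # history x(t) = 0 for t < 0
        x.append(x[t] + beta * x_lag / (1.0 + x_lag ** n) - gamma * x[t])
    return x


def make_lagged(series, lags=(18, 12, 6, 0), horizon=6):
    """Rows [x(t-18), x(t-12), x(t-6), x(t)] with target x(t + horizon)."""
    start = max(lags)
    rows, targets = [], []
    for t in range(start, len(series) - horizon):
        rows.append([series[t - lag] for lag in lags])
        targets.append(series[t + horizon])
    return rows, targets


series = mackey_glass(500)
rows, targets = make_lagged(series)  # four inputs per row, i.e. n_input=4
```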

6. Related work

Contact

I am very grateful for feedback. If you have any questions, please contact gregor.lenhard92@gmail.com.

References

If you use my work, please cite it appropriately:

G. Lenhard and D. Maringer, "State-ANFIS: A Generalized Regime-Switching Model for Financial Modeling," 2022 IEEE Symposium on Computational Intelligence for Financial Engineering and Economics (CIFEr), 2022, pp. 1-8, doi: 10.1109/CIFEr52523.2022.9776208.

BibTeX:

@INPROCEEDINGS{lenhard2022sanfis,
  author={Lenhard, Gregor and Maringer, Dietmar},
  booktitle={2022 IEEE Symposium on Computational Intelligence for Financial Engineering and Economics ({CIFEr})}, 
  title={State-{ANFIS}: A Generalized Regime-Switching Model for Financial Modeling}, 
  year={2022},
  pages={1--8},
  doi={10.1109/CIFEr52523.2022.9776208},
  organization={IEEE}
  }

