Neural network framework based on Generalized Additive Models.
neuralGAM: Interpretable Neural Network Based on Generalized Additive Models
Neural networks are among the most popular machine learning methods today, given their high performance on diverse tasks such as computer vision, anomaly detection, computer-aided disease detection and diagnosis, and natural language processing. However, it is usually unclear how a neural network reaches its decisions, and current methods that try to make neural networks interpretable are not robust enough.
We introduce neuralGAM, a fully explainable neural network based on Generalized Additive Models, which trains a separate neural network to estimate the contribution of each feature to the response variable. In contrast to other Neural Additive Model implementations, neuralGAM trains the neural networks independently, leveraging the local scoring and backfitting algorithms to ensure that the resulting Generalized Additive Model converges and remains additive. The result is a highly accurate and explainable deep learning model, suitable for high-risk AI applications where decision-making must be based on accountable and interpretable algorithms.
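The backfitting idea that neuralGAM builds on can be sketched in a few lines of plain Python. This is a stand-alone illustration, not the package's actual implementation: each per-feature neural network is replaced here by a simple least-squares line, and each component is repeatedly refitted to the partial residuals of the others until the additive fit stabilizes.

```python
def fit_line(x, r):
    """Least-squares slope/intercept of r ~ x (stands in for one per-feature network)."""
    n = len(x)
    mx = sum(x) / n
    mr = sum(r) / n
    sxx = sum((xi - mx) ** 2 for xi in x)
    b = sum((xi - mx) * (ri - mr) for xi, ri in zip(x, r)) / sxx
    return b, mr - b * mx

def backfit(X, y, n_iter=50):
    """X: list of feature columns; returns the mean of y and per-feature (slope, intercept)."""
    n, p = len(y), len(X)
    mean_y = sum(y) / n
    fits = [(0.0, 0.0)] * p
    contrib = [[0.0] * n for _ in range(p)]
    for _ in range(n_iter):
        for j in range(p):
            # Partial residuals: remove the mean and every other feature's current effect
            r = [y[i] - mean_y - sum(contrib[k][i] for k in range(p) if k != j)
                 for i in range(n)]
            b, a = fit_line(X[j], r)
            fits[j] = (b, a)
            contrib[j] = [b * xi + a for xi in X[j]]
    return mean_y, fits
```

On data generated as y = 2*x1 + 3*x2, the loop recovers a slope near 2 for the first component and near 3 for the second; neuralGAM applies the same cycling-over-partial-residuals scheme with neural networks as the smoothers.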
neuralGAM is also available as an R package on CRAN.
Installation
To install the neuralGAM package, run:

```shell
pip install neuralGAM
```
Usage
Linear Regression
To perform linear regression using the neuralGAM package, follow these steps:
- Import the necessary libraries and the NeuralGAM class:

```python
from neuralGAM.model import NeuralGAM
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
```
- Load your dataset and split it into training and testing sets:

```python
data = pd.read_csv('path/to/your/dataset.csv')
X = data.drop(columns=['target'])
y = data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
- Initialize the NeuralGAM model. You might need to adjust the `num_units` parameter depending on the complexity and amount of your data. Each number in the list defines the number of hidden units in the corresponding layer of the Deep Neural Network. For simple problems we recommend a single-layer neural network with 1024 units:

```python
ngam = NeuralGAM(family="gaussian", num_units=[1024], learning_rate=0.00053)
```
- Fit the model to the training data:

```python
ngam.fit(X_train=X_train, y_train=y_train, max_iter_ls=10, bf_threshold=1e-5,
         ls_threshold=0.01, max_iter_backfitting=10, parallel=True)
```
- Make predictions on the test data and compute the mean squared error:

```python
y_pred = ngam.predict(X_test, type="response")
pred_err = mean_squared_error(y_test, y_pred)
print(f"MSE in the test set = {pred_err}")
```
- Plot the partial dependencies:

```python
from neuralGAM.plot import plot_partial_dependencies
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8')
plot_partial_dependencies(x=X_train, fs=ngam.feature_contributions,
                          title="Estimated Training Partial Effects")

fs_test_est = ngam.predict(X_test, type="terms")
plot_partial_dependencies(x=X_test, fs=fs_test_est,
                          title="Estimated Test Partial Effects")
```
Logistic Regression
To perform logistic regression using the neuralGAM package, follow these steps:
- Import the necessary libraries and the NeuralGAM class:

```python
from neuralGAM.model import NeuralGAM
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
```
- Load your dataset and split it into training and testing sets:

```python
data = pd.read_csv('path/to/your/dataset.csv')
X = data.drop(columns=['target'])
y = data['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
- Initialize the NeuralGAM model. You might need to adjust the `num_units` parameter depending on the complexity and amount of your data. Each number in the list defines the number of hidden units in the corresponding layer of the Deep Neural Network. For simple problems we recommend a single-layer neural network with 1024 units. If you want to force a linear fit for a specific covariate, you can do so with the `linear_terms` parameter:

```python
ngam = NeuralGAM(family="binomial", num_units=[1024], learning_rate=0.00053, linear_terms=[1])
```
- Fit the model to the training data:

```python
ngam.fit(X_train=X_train, y_train=y_train, max_iter_ls=10, bf_threshold=1e-5,
         ls_threshold=0.01, max_iter_backfitting=10, parallel=True)
```
- Make predictions on the test data and compute the accuracy:

```python
y_pred = ngam.predict(X_test, type="response")  # predicted probabilities
y_pred_class = (y_pred > 0.5).astype(int)
accuracy = accuracy_score(y_test, y_pred_class)  # assumes y_test labels are in {0, 1}
print(f"Accuracy in the test set = {accuracy}")
```
- Plot the partial dependencies:

```python
from neuralGAM.plot import plot_partial_dependencies
import matplotlib.pyplot as plt

plt.style.use('seaborn-v0_8')
plot_partial_dependencies(x=X_train, fs=ngam.feature_contributions,
                          title="Estimated Training Partial Effects")

fs_test_est = ngam.predict(X_test, type="terms")
plot_partial_dependencies(x=X_test, fs=fs_test_est,
                          title="Estimated Test Partial Effects")
```
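The two prediction types are related through the model's link function: for the binomial family, summing the per-feature terms on the logit scale and applying the sigmoid should recover the response-scale probabilities. The following pure-Python sketch illustrates that relationship with made-up term values; the exact intercept handling is an assumption for illustration, not taken from the package's documentation.

```python
import math

def sigmoid(eta):
    """Inverse of the logit link: maps the additive predictor to a probability."""
    return 1.0 / (1.0 + math.exp(-eta))

# Hypothetical per-feature contributions on the logit scale for 3 samples,
# as predict(..., type="terms") would return them, plus a model intercept.
terms = [
    [0.8, -1.2],   # sample 1: f1(x1), f2(x2)
    [-0.3, 0.5],   # sample 2
    [2.0, 1.1],    # sample 3
]
intercept = -0.4

# type="response" then corresponds to sigmoid(intercept + sum of terms),
# and classes follow from thresholding at 0.5 as in the accuracy step above.
probs = [sigmoid(intercept + sum(row)) for row in terms]
classes = [int(p > 0.5) for p in probs]
```

Inspecting the per-feature terms this way is what makes the model interpretable: each column of `terms` shows how much one feature pushed the log-odds up or down for a given sample.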
Examples
You can find detailed examples for both linear and logistic regression in the examples folder. These examples are provided as Jupyter notebooks:
- Linear Regression Example
- Logistic Regression Example
Citation
If you use neuralGAM in your research, please cite the following paper:
Ortega-Fernandez, I., Sestelo, M. & Villanueva, N.M. Explainable generalized additive neural networks with independent neural network training. Stat Comput 34, 6 (2024). https://doi.org/10.1007/s11222-023-10320-5
```bibtex
@article{ortega2024explainable,
  title={Explainable generalized additive neural networks with independent neural network training},
  author={Ortega-Fernandez, Ines and Sestelo, Marta and Villanueva, Nora M},
  journal={Statistics and Computing},
  volume={34},
  number={1},
  pages={6},
  year={2024},
  publisher={Springer}
}
```
File details
Details for the file neuralgam-1.0.1.tar.gz.
File metadata
- Download URL: neuralgam-1.0.1.tar.gz
- Upload date:
- Size: 29.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.21
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `9d2214cd4c9fa4e52ced1b350b20b796a595283fba1694597486675402134772` |
| MD5 | `6f95c02e23c10e39d46a4f75c4b95745` |
| BLAKE2b-256 | `e9943fde496e0865792da062cec5f394dee7fcf91c802081371b873d99157728` |
File details
Details for the file neuralgam-1.0.1-py3-none-any.whl.
File metadata
- Download URL: neuralgam-1.0.1-py3-none-any.whl
- Upload date:
- Size: 32.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.9.21
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `99b90a406a16b25a71aaa73b02c7e8ffc0fe3d76314bbb8975173942d0426773` |
| MD5 | `ae3e5c495dcf649d4a035c182a295e8c` |
| BLAKE2b-256 | `77e82b7517a8e01bc63aac69149177d6e00dd273050ec580f8fd3a64ec4de450` |
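To check a downloaded file against the digests listed above, you can compute its hash locally with Python's standard hashlib module. This is a generic sketch; the file name is taken from the listings above and the truncated digest in the comment is only a pointer back to the table.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 20):
    """Stream the file in chunks and return its hex SHA256 digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Compare the result against the SHA256 entry under "File hashes", e.g.
# sha256_of_file("neuralgam-1.0.1.tar.gz") should equal "9d2214cd...2134772"
```

Reading in fixed-size chunks keeps memory use constant regardless of file size, which matters more for large wheels than for a 30 kB package but is good practice either way.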