
Neural network framework based on Generalized Additive Models.

Project description

neuralGAM: Interpretable Neural Network Based on Generalized Additive Models

Neural networks are among the most popular machine learning methods today, given their strong performance on diverse tasks such as computer vision, anomaly detection, computer-aided disease detection and diagnosis, and natural language processing. However, it is usually unclear how a neural network reaches its decisions, and current methods for making neural networks interpretable are not robust enough.

We introduce neuralGAM, a fully explainable neural network based on Generalized Additive Models, which trains a separate neural network to estimate the contribution of each feature to the response variable. In contrast to other Neural Additive Model implementations, neuralGAM trains the neural networks independently, leveraging the local scoring and backfitting algorithms to ensure that the Generalized Additive Model converges and remains additive. The result is a highly accurate and explainable deep learning model that can be used in high-risk AI settings where decisions must be based on accountable and interpretable algorithms.
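
For intuition, here is a rough sketch of the backfitting idea (illustrative pseudocode, not neuralGAM's internal code): each additive term is repeatedly refitted to the partial residuals left by the other terms until the fit stabilizes. In neuralGAM, the role of fit_smoother below is played by a per-feature neural network, and the local scoring algorithm wraps this loop to handle non-Gaussian families:

    import numpy as np

    def backfitting_sketch(X, y, fit_smoother, max_iter=10, tol=1e-5):
        """Fit y = alpha + sum_j f_j(x_j) by backfitting (illustration only).

        fit_smoother(x, r) is any 1-D smoother that returns fitted values
        of the partial residuals r at the points x.
        """
        n, p = X.shape
        alpha = y.mean()            # intercept
        f = np.zeros((n, p))        # current estimate of each additive term
        for _ in range(max_iter):
            f_prev = f.copy()
            for j in range(p):
                # partial residuals: what the other terms leave unexplained
                r = y - alpha - (f.sum(axis=1) - f[:, j])
                f[:, j] = fit_smoother(X[:, j], r)
                f[:, j] -= f[:, j].mean()   # center terms for identifiability
            if np.max(np.abs(f - f_prev)) < tol:
                break                       # converged
        return alpha, f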

neuralGAM is also available as an R package on CRAN.

Installation

To install the neuralGAM package, you can use the following command:

pip install neuralGAM
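
You can then check that the package is importable (the module path is the one used in the examples below):

    python -c "from neuralGAM.model import NeuralGAM"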

Usage

Linear Regression

To perform linear regression using the neuralGAM package, follow these steps (a self-contained sketch on synthetic data is shown after the list):

  1. Import the necessary libraries and the NeuralGAM class:

    from neuralGAM.model import NeuralGAM
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error
    
  2. Load your dataset and split it into training and testing sets:

    data = pd.read_csv('path/to/your/dataset.csv')
    X = data.drop(columns=['target'])
    y = data['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
  3. Initialize the NeuralGAM model. You may need to adjust the num_units parameter depending on the complexity of your data and how much of it is available. Each number in the list defines the number of hidden units in the corresponding layer of each feature's deep neural network. For simple problems we recommend a single hidden layer with 1024 units.

    ngam = NeuralGAM(family="gaussian", num_units=[1024], learning_rate=0.00053)
    
  4. Fit the model to the training data:

    ngam.fit(X_train=X_train, y_train=y_train,
             max_iter_ls=10,             # max. local scoring iterations
             ls_threshold=0.01,          # local scoring convergence threshold
             max_iter_backfitting=10,    # max. backfitting iterations per local scoring step
             bf_threshold=1e-5,          # backfitting convergence threshold
             parallel=True)              # fit the per-feature networks in parallel
    
  5. Make predictions on the test data and compute the mean squared error:

    y_pred = ngam.predict(X_test, type="response")
    pred_err = mean_squared_error(y_test, y_pred)
    print(f"MSE in the test set = {pred_err}")
    
  6. Plot the partial dependencies:

    from neuralGAM.plot import plot_partial_dependencies
    import matplotlib.pyplot as plt
    
    plt.style.use('seaborn-v0_8')
    plot_partial_dependencies(x=X_train, fs=ngam.feature_contributions, title="Estimated Training Partial Effects")
    fs_test_est = ngam.predict(X_test, type="terms")
    plot_partial_dependencies(x=X_test, fs=fs_test_est, title="Estimated Test Partial Effects")
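
Putting the steps together, the following is a self-contained sketch on synthetic additive data. The data-generating functions are invented for illustration; the NeuralGAM calls mirror the steps above:

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error
    from neuralGAM.model import NeuralGAM

    # Synthetic additive data: y = x0^2 + sin(x1) + x2 + noise (illustrative choice)
    rng = np.random.default_rng(42)
    n = 5000
    X = pd.DataFrame({
        "x0": rng.uniform(-2.5, 2.5, n),
        "x1": rng.uniform(-4.0, 4.0, n),
        "x2": rng.uniform(-2.5, 2.5, n),
    })
    y = X["x0"] ** 2 + np.sin(X["x1"]) + X["x2"] + rng.normal(0, 0.25, n)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)

    ngam = NeuralGAM(family="gaussian", num_units=[1024], learning_rate=0.00053)
    ngam.fit(X_train=X_train, y_train=y_train, max_iter_ls=10, bf_threshold=1e-5,
             ls_threshold=0.01, max_iter_backfitting=10, parallel=True)

    y_pred = ngam.predict(X_test, type="response")
    print(f"MSE in the test set = {mean_squared_error(y_test, y_pred)}")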
    

Logistic Regression

To perform logistic regression using the neuralGAM package, follow these steps (an evaluation sketch with scikit-learn metrics follows the list):

  1. Import the necessary libraries and the NeuralGAM class:

    from neuralGAM.model import NeuralGAM
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    
  2. Load your dataset and split it into training and testing sets:

    data = pd.read_csv('path/to/your/dataset.csv')
    X = data.drop(columns=['target'])
    y = data['target']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
  3. Initialize the NeuralGAM model. You may need to adjust the num_units parameter depending on the complexity of your data and how much of it is available. Each number in the list defines the number of hidden units in the corresponding layer of each feature's deep neural network. For simple problems we recommend a single hidden layer with 1024 units. If you want to force a linear fit for a specific covariate, you can do so using the linear_terms parameter:

    ngam = NeuralGAM(family="binomial", num_units=[1024], learning_rate=0.00053, linear_terms=[1])
    
  4. Fit the model to the training data:

    ngam.fit(X_train=X_train, y_train=y_train,
             max_iter_ls=10,             # max. local scoring iterations
             ls_threshold=0.01,          # local scoring convergence threshold
             max_iter_backfitting=10,    # max. backfitting iterations per local scoring step
             bf_threshold=1e-5,          # backfitting convergence threshold
             parallel=True)              # fit the per-feature networks in parallel
    
  5. Make predictions on the test data and compute the accuracy:

    y_pred = ngam.predict(X_test, type="response")  # get predicted probabilities
    y_pred_class = (y_pred > 0.5).astype(int)       # apply a 0.5 decision threshold
    accuracy = accuracy_score(y_test, y_pred_class) # assuming y_test is in the discrete set {0,1}
    print(f"Accuracy in the test set = {accuracy}")
    
  6. Plot the partial dependencies:

    from neuralGAM.plot import plot_partial_dependencies
    import matplotlib.pyplot as plt
    
    plt.style.use('seaborn-v0_8')
    plot_partial_dependencies(x=X_train, fs=ngam.feature_contributions, title="Estimated Training Partial Effects")
    fs_test_est = ngam.predict(X_test, type="terms")
    plot_partial_dependencies(x=X_test, fs=fs_test_est, title="Estimated Test Partial Effects")
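
Beyond plain accuracy, the probabilities returned by type="response" can feed any threshold-free metric. For example, with standard scikit-learn metrics (these are sklearn functions, not part of neuralGAM):

    from sklearn.metrics import roc_auc_score, classification_report

    y_pred = ngam.predict(X_test, type="response")       # predicted probabilities
    print(f"ROC AUC = {roc_auc_score(y_test, y_pred)}")  # threshold-free ranking quality
    print(classification_report(y_test, (y_pred > 0.5).astype(int)))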
    

Examples

You can find detailed examples for both linear and logistic regression in the examples folder. These examples are provided as Jupyter notebooks:

  • Linear Regression Example
  • Logistic Regression Example

Citation

If you use neuralGAM in your research, please cite the following paper:

Ortega-Fernandez, I., Sestelo, M. & Villanueva, N.M. Explainable generalized additive neural networks with independent neural network training. Stat Comput 34, 6 (2024). https://doi.org/10.1007/s11222-023-10320-5

@article{ortega2024explainable,
  title={Explainable generalized additive neural networks with independent neural network training},
  author={Ortega-Fernandez, Ines and Sestelo, Marta and Villanueva, Nora M},
  journal={Statistics and Computing},
  volume={34},
  number={1},
  pages={6},
  year={2024},
  publisher={Springer}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

neuralgam-1.0.1.tar.gz (29.7 kB)

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

neuralgam-1.0.1-py3-none-any.whl (32.7 kB)

File details

Details for the file neuralgam-1.0.1.tar.gz.

File metadata

  • Download URL: neuralgam-1.0.1.tar.gz
  • Size: 29.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.21

File hashes

Hashes for neuralgam-1.0.1.tar.gz:

  • SHA256: 9d2214cd4c9fa4e52ced1b350b20b796a595283fba1694597486675402134772
  • MD5: 6f95c02e23c10e39d46a4f75c4b95745
  • BLAKE2b-256: e9943fde496e0865792da062cec5f394dee7fcf91c802081371b873d99157728

See more details on using hashes here.
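
For example, the SHA256 of a downloaded file can be checked against the value above with Python's standard hashlib (assuming the file sits in the current directory):

    import hashlib

    expected = "9d2214cd4c9fa4e52ced1b350b20b796a595283fba1694597486675402134772"
    with open("neuralgam-1.0.1.tar.gz", "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    print("OK" if digest == expected else "MISMATCH")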

File details

Details for the file neuralgam-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: neuralgam-1.0.1-py3-none-any.whl
  • Size: 32.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.21

File hashes

Hashes for neuralgam-1.0.1-py3-none-any.whl:

  • SHA256: 99b90a406a16b25a71aaa73b02c7e8ffc0fe3d76314bbb8975173942d0426773
  • MD5: ae3e5c495dcf649d4a035c182a295e8c
  • BLAKE2b-256: 77e82b7517a8e01bc63aac69149177d6e00dd273050ec580f8fd3a64ec4de450

See more details on using hashes here.
