NodeGAM - an interpretable deep learning GAM model.

NODE GAM: Differentiable Generalized Additive Model for Interpretable Deep Learning:

NodeGAM is an interpretable deep learning GAM model proposed in our ICLR 2022 paper, NODE GAM: Differentiable Generalized Additive Model for Interpretable Deep Learning. In short, it trains a GAM with multi-layer differentiable trees, which makes the model accurate, interpretable, and differentiable. See this blog post for an introduction, and visit our documentation website!
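To make the additive structure concrete, here is a minimal NumPy sketch. It is an illustration, not the package's implementation. Each feature j gets its own shape function f_j, and the prediction is a bias plus the sum of the f_j(x_j). Each f_j here is a single depth-1 tree whose hard split is replaced by a sigmoid so that it is differentiable. The names soft_split, soft_stump, and gam_predict are hypothetical. NodeGAM's actual trees are deeper and stacked in multiple layers, and they may use a different gating function.

```python
import numpy as np

def soft_split(x, threshold, temperature):
    """Differentiable stand-in for the hard test `x > threshold`.

    Returns values in (0, 1); as temperature -> 0 it approaches a step function.
    """
    return 1.0 / (1.0 + np.exp(-(x - threshold) / temperature))

def soft_stump(x, threshold, left_value, right_value, temperature=0.1):
    """Depth-1 soft tree on ONE feature: a smooth blend of its two leaf values."""
    p_right = soft_split(x, threshold, temperature)
    return (1.0 - p_right) * left_value + p_right * right_value

def gam_predict(X, stumps, bias=0.0):
    """Additive prediction: bias + sum_j f_j(X[:, j]); each f_j sees only feature j."""
    out = np.full(X.shape[0], bias, dtype=float)
    for j, params in stumps.items():
        out += soft_stump(X[:, j], **params)
    return out

# Usage: two features, one soft stump (shape function) per feature.
X = np.array([[0.0, 5.0],
              [1.0, -5.0]])
stumps = {
    0: dict(threshold=0.5, left_value=-1.0, right_value=1.0),
    1: dict(threshold=0.0, left_value=0.0, right_value=2.0),
}
preds = gam_predict(X, stumps, bias=0.5)
```

Because every term depends on only one feature, each f_j can be plotted on its own. That per-feature view is what makes a GAM interpretable.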

Installation

pip install nodegam

The performance and the runtime of the NodeGAM package

We compare NodeGAM with other GAMs (EBM, XGB-GAM) and with XGB on 6 datasets. All models use default hyperparameters, so NodeGAM's performance here is lower than what the paper reports. NodeGAM tends to perform better on larger datasets. In the tables, N is the number of samples and P is the number of features.

3 classification datasets (AUROC, higher is better):

Dataset   Domain    N     P   NodeGAM        EBM            XGB-GAM        XGB
MIMIC-II  Medicine  25K   17  0.844 ± 0.018  0.842 ± 0.019  0.833 ± 0.02   0.845 ± 0.019
Adult     Finance   33K   14  0.916 ± 0.002  0.927 ± 0.003  0.925 ± 0.002  0.927 ± 0.002
Credit    Finance   285K  30  0.989 ± 0.008  0.984 ± 0.007  0.985 ± 0.008  0.984 ± 0.01

3 regression datasets (RMSE, lower is better):

Dataset    Domain   N     P   NodeGAM         EBM             XGB-GAM          XGB
Wine       Nature   5K    12  0.705 ± 0.012   0.69 ± 0.011    0.713 ± 0.006    0.682 ± 0.023
Bikeshare  Retail   17K   16  57.438 ± 3.899  55.676 ± 0.327  101.093 ± 0.946  45.212 ± 1.254
Year       Music    515K  90  9.013 ± 0.004   9.204 ± 0.0     9.257 ± 0.0      9.049 ± 0.0
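The scores above are AUROC for classification and RMSE for regression, reported as "mean ± spread". The sketch below shows one way to compute both metrics with NumPy and to format per-run scores that way. It assumes ± is the standard deviation across runs or folds (population std, ddof=0). The reproducing notebook may compute it differently. The helper names rmse, auroc, and mean_pm_std are hypothetical.

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error (lower is better)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def auroc(y_true, scores):
    """Area under the ROC curve via the Mann-Whitney U statistic (higher is better).

    Tied scores get the average of their ranks, so a tie counts as half a win.
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(scores, kind="mergesort")
    sorted_scores = scores[order]
    ranks = np.empty(len(scores), dtype=float)
    i = 0
    while i < len(scores):
        j = i
        while j + 1 < len(scores) and sorted_scores[j + 1] == sorted_scores[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2.0 + 1.0  # average 1-based rank of the tie group
        i = j + 1
    n_pos = int(y_true.sum())
    n_neg = len(y_true) - n_pos
    u = ranks[y_true].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))

def mean_pm_std(values, digits=3):
    """Format per-run scores as 'mean ± std' (population std, ddof=0)."""
    values = np.asarray(values, dtype=float)
    return f"{values.mean():.{digits}f} ± {values.std():.{digits}f}"

# Usage: one AUROC per fold, then summarize.
fold_aurocs = [auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]),
               auroc([0, 1, 0, 1], [0.2, 0.9, 0.3, 0.7])]
print(mean_pm_std(fold_aurocs))  # prints 0.875 ± 0.125
```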

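NodeGAM's run time grows only mildly with dataset size because it uses mini-batch training, while the baselines' training time grows much faster. NodeGAM is slower than the baselines on the small datasets. On the largest dataset (Year, 515K samples), it trains faster than all three baselines.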

3 classification datasets (training time in seconds):

Dataset   Domain    N     P   NodeGAM       EBM         XGB-GAM     XGB
MIMIC-II  Medicine  25K   17  105.0 ± 14.0  6.0 ± 2.0   0.0 ± 1.0   1.0 ± 1.0
Adult     Finance   33K   14  196.0 ± 56.0  15.0 ± 8.0  6.0 ± 0.0   1.0 ± 0.0
Credit    Finance   285K  30  113.0 ± 36.0  37.0 ± 2.0  26.0 ± 7.0  16.0 ± 2.0

3 regression datasets (training time in seconds):

Dataset    Domain   N     P   NodeGAM       EBM          XGB-GAM      XGB
Wine       Nature   5K    12  157.0 ± 86.0  4.0 ± 2.0    0.0 ± 0.0    0.0 ± 0.0
Bikeshare  Retail   17K   16  223.0 ± 23.0  15.0 ± 3.0   1.0 ± 1.0    2.0 ± 1.0
Year       Music    515K  90  318.0 ± 20.0  501.0 ± 8.0  376.0 ± 1.0  537.0 ± 1.0

The notebook that reproduces these results is here.

See Tables 1 and 2 of our paper for more comparisons.

NodeGAM Training

Sklearn interface

To use NodeGAM on your own dataset, run:

from nodegam.sklearn import NodeGAMClassifier, NodeGAMRegressor

# X: training features, y: labels (use NodeGAMRegressor for a regression target)
model = NodeGAMClassifier()
model.fit(X, y)

Understand the model:

model.visualize()

or

from nodegam.vis_utils import vis_GAM_effects

vis_GAM_effects({
    'nodegam': model.get_GAM_df(),
})

See the notebooks/toy dataset with nodegam sklearn.ipynb here.

Notebook training

This is useful if you want to customize NodeGAM training or fit it into your own PyTorch pipeline. The training details are in this notebook: https://colab.research.google.com/drive/1C_gBoSc1AlQ7VvCXVWiU-7X3YjQZTiZI?usp=sharing

More examples are under notebooks/.

Python file

You can also train NodeGAM with our main script, main.py. For example, to reproduce our NODE-GA2M result on fold 0 (of 5 folds) of the bikeshare dataset, run:

hparams="resources/best_hparams/node_ga2m/0519_f0_best_bikeshare_GAM_ga2m_s83_nl4_nt125_td1_d6_od0.0_ld0.3_cs0.5_lr0.01_lo0_la0.0_pt0_pr0_mn0_ol0_ll1"
python main.py \
--name 0603_best_bikeshare_f0 \
--load_from_hparams ${hparams} \
--fold 0
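Here is a minimal sketch of what "fold 0 of 5" means, assuming a plain shuffled 5-fold split. The rows are partitioned into 5 disjoint folds. Fold 0 is held out for testing, and the remaining folds are used for training. How main.py actually splits the data (seed, stratification, validation split) is not described here, so kfold_indices and train_test_for_fold are hypothetical helpers.

```python
import numpy as np

def kfold_indices(n_samples, n_folds=5, seed=0):
    """Shuffle the row indices once, then split them into n_folds near-equal folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_samples), n_folds)

def train_test_for_fold(n_samples, fold, n_folds=5, seed=0):
    """Hold out fold `fold` as the test set; train on the remaining folds."""
    folds = kfold_indices(n_samples, n_folds, seed)
    test_idx = folds[fold]
    train_idx = np.concatenate([f for k, f in enumerate(folds) if k != fold])
    return train_idx, test_idx

# Usage: the analogue of "--fold 0" for a toy dataset of 100 rows.
train_idx, test_idx = train_test_for_fold(n_samples=100, fold=0)
```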

The model is stored in logs/0603_best_bikeshare_f0/, and the results, including test and validation error, are written to results/bikeshare_GAM.csv.

We provide the best hyperparameters we found in resources/best_hparams/.

Baseline GAMs

We also provide code to train other GAMs for comparison: Spline, EBM, and XGB-GAM.

Sklearn interface

To train baselines on your dataset, just run:

from nodegam.gams.MySpline import MySplineLogisticGAM, MySplineGAM
from nodegam.gams.MyEBM import MyExplainableBoostingClassifier, MyExplainableBoostingRegressor
from nodegam.gams.MyXGB import MyXGBOnehotClassifier, MyXGBOnehotRegressor
from nodegam.gams.MyBagging import MyBaggingClassifier, MyBaggingRegressor


ebm = MyExplainableBoostingClassifier()
ebm.fit(X, y)

spline = MySplineLogisticGAM()
bagged_spline = MyBaggingClassifier(base_estimator=spline, n_estimators=3)
bagged_spline.fit(X, y)

xgb_gam = MyXGBOnehotClassifier()
bagged_xgb = MyBaggingClassifier(base_estimator=xgb_gam, n_estimators=3)
bagged_xgb.fit(X, y)
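MyBaggingClassifier(..., n_estimators=3) presumably trains several copies of the base GAM and combines them. The sketch below illustrates bagging in general, not the library's implementation. Each hypothetical base model is a least-squares fit with one straight-line shape function per feature. Each member is fit on a bootstrap resample of the rows, and the members' predictions are averaged. An average of additive models is itself additive, so the bagged ensemble can still be read as a single GAM whose shape functions are the averages of the members' shape functions.

```python
import numpy as np

def fit_linear_additive(X, y):
    """Least squares for y ~ b + sum_j w_j * x_j, i.e. a GAM whose shape functions are lines."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef  # coef[0]: intercept, coef[1:]: one slope per feature

def bagged_fit(X, y, n_estimators=3, seed=0):
    """Bagging: fit one base model on each bootstrap resample of the rows."""
    rng = np.random.default_rng(seed)
    n = len(X)
    models = []
    for _ in range(n_estimators):
        idx = rng.integers(0, n, size=n)  # draw n rows with replacement
        models.append(fit_linear_additive(X[idx], y[idx]))
    return np.array(models)  # shape (n_estimators, 1 + n_features)

def bagged_predict(models, X):
    """Average the members' predictions."""
    A = np.column_stack([np.ones(len(X)), X])
    return (A @ models.T).mean(axis=1)

# Usage on noiseless additive data: y = 1 + 2*x0 - 3*x1.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = 1.0 + 2.0 * X[:, 0] - 3.0 * X[:, 1]
models = bagged_fit(X, y)
averaged_gam = models.mean(axis=0)  # one averaged shape (here: a slope) per feature
```

Predicting with averaged_gam gives the same output as averaging the three members' predictions. This equivalence is what lets a bagged GAM be visualized as one set of shape functions.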

Understand the models:

from nodegam.vis_utils import vis_GAM_effects

fig, ax = vis_GAM_effects(
    all_dfs={
        'EBM': ebm.get_GAM_df(),
        'Spline': bagged_spline.get_GAM_df(),
        'XGB-GAM': bagged_xgb.get_GAM_df(),
    },
)

See the notebooks/toy dataset with nodegam sklearn.ipynb here for an example.

Python file

You can train Spline, EBM, and XGB-GAM with the following commands:

python baselines.py --name 0603_bikeshare_spline_f0 --fold 0 --model_name spline --dataset bikeshare
python baselines.py --name 0603_bikeshare_ebm_f0 --fold 0 --model_name ebm-o100-i100 --dataset bikeshare
python baselines.py --name 0603_bikeshare_xgb-o5_f0 --fold 0 --model_name xgb-o5 --dataset bikeshare

The results are written to results/baselines_bikeshare.csv, and each model is stored in logs/{name}/.

Visualization of the trained models stored under logs/

To visualize and compare multiple trained GAM models stored under logs/, run this in a notebook:

from nodegam.vis_utils import vis_GAM_effects
from nodegam.utils import average_GAMs

df_dict = {
    'node_ga2m': average_GAMs([
        '0603_best_bikeshare_f0',
        '0603_best_bikeshare_f1',
    ], max_n_bins=256),
    'ebm': average_GAMs([
        '0603_bikeshare_ebm_f0',
        '0603_bikeshare_ebm_f1',
    ], max_n_bins=256),
}

fig, ax = vis_GAM_effects(df_dict)

To keep visualization fast, max_n_bins quantile-bins each feature into at most 256 bins (the default). average_GAMs averages the GAMs from multiple runs to produce the mean and standard deviation shown on the GAM graphs.

See notebooks/bikeshare visualization.ipynb here, where we show the bikeshare graphs of all GAMs in our paper (NODE-GA2M, NODE-GAM, EBM, and Spline).

Citations

If you find the code useful, please cite:

@inproceedings{chang2021node,
  title={NODE-GAM: Neural Generalized Additive Model for Interpretable Deep Learning},
  author={Chang, Chun-Hao and Caruana, Rich and Goldenberg, Anna},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

@inproceedings{chang2021interpretable,
  title={How interpretable and trustworthy are gams?},
  author={Chang, Chun-Hao and Tan, Sarah and Lengerich, Ben and Goldenberg, Anna and Caruana, Rich},
  booktitle={Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
  pages={95--105},
  year={2021}
}

Contributing

All content in this repository is licensed under the MIT license.


Download files


Source Distribution

nodegam-0.3.0.tar.gz (67.6 kB)

Built Distribution

nodegam-0.3.0-py3-none-any.whl (77.1 kB)

File details

Details for the file nodegam-0.3.0.tar.gz.

File metadata

  • Download URL: nodegam-0.3.0.tar.gz
  • Size: 67.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.3 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.9 tqdm/4.43.0 importlib-metadata/4.8.1 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.10

File hashes

Hashes for nodegam-0.3.0.tar.gz
Algorithm Hash digest
SHA256 7753dfc2457ec3c8f345898764678f85264d7f49c9c2afea2d019e22dc286921
MD5 fd998f12b1f52a23e831b0685e3a4502
BLAKE2b-256 759f4470a096773f9729e2fa3388c53253bacc0d620133a412b3ab1f90850e0d

File details

Details for the file nodegam-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: nodegam-0.3.0-py3-none-any.whl
  • Size: 77.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.3 readme-renderer/34.0 requests/2.27.1 requests-toolbelt/0.9.1 urllib3/1.26.9 tqdm/4.43.0 importlib-metadata/4.8.1 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.5 CPython/3.6.10

File hashes

Hashes for nodegam-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 f95e04869b69183ccdc6689d142973cc690f5fb712add0b9a5ffb55506df8358
MD5 1d7115b15b46cdbf04511ea38bd24cc2
BLAKE2b-256 43b8d1c12b8d0e8ef891707447dd7b2f234543f7f36be6e0c031c1a4c11dc445
