
neural-feature-importance

Requires Python 3.10+ · MIT License

Variance-based feature importance for deep learning models.

neural-feature-importance implements the method described by C. Rebelo de Sá in "Variance-Based Feature Importance in Neural Networks" (Discovery Science 2019). During training it tracks the variance of the weights in the first trainable layer using Welford's online algorithm, then produces a normalized importance score for each input feature.
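The package's internals may differ, but Welford's online variance update can be sketched as follows. This is a minimal illustration, not the library's actual code; the 4-dimensional weight vectors are random stand-ins for per-epoch snapshots of a first-layer weight:

```python
import numpy as np

def welford_update(count, mean, m2, x):
    """One Welford step: update the running mean and sum of squared deviations."""
    count += 1
    delta = x - mean
    mean = mean + delta / count
    m2 = m2 + delta * (x - mean)
    return count, mean, m2

rng = np.random.default_rng(0)
count, mean, m2 = 0, np.zeros(4), np.zeros(4)
history = []  # kept only to verify the running result below
for epoch in range(10):
    weights = rng.normal(size=4)  # stand-in for a first-layer weight snapshot
    history.append(weights)
    count, mean, m2 = welford_update(count, mean, m2, weights)

variance = m2 / count               # population variance per weight
scores = variance / variance.max()  # normalized to [0, 1]
```

Welford's algorithm needs only the running mean and sum of squared deviations, so the variance is available after training without storing a weight snapshot per epoch.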

Features

  • VarianceImportanceKeras — drop-in callback for TensorFlow/Keras models
  • VarianceImportanceTorch — helper class for PyTorch training loops
  • MetricThreshold — early-stopping callback based on a monitored metric
  • Example scripts to reproduce the experiments from the paper

Installation

pip install "neural-feature-importance[tensorflow]"  # for Keras
pip install "neural-feature-importance[torch]"       # for PyTorch

Retrieve the package version via:

from neural_feature_importance import __version__
print(__version__)

Quick start

Keras

from neural_feature_importance import VarianceImportanceKeras
from neural_feature_importance.utils import MetricThreshold

# `model`, `X`, and `y` are a compiled Keras model and its training data.
viann = VarianceImportanceKeras()
monitor = MetricThreshold(monitor="val_accuracy", threshold=0.95)
model.fit(X, y, validation_split=0.05, epochs=30, callbacks=[viann, monitor])
print(viann.feature_importances_)

PyTorch

from neural_feature_importance import VarianceImportanceTorch

# `model`, `optimizer`, `dataloader`, `num_epochs`, and `train_one_epoch`
# come from your existing training setup.
tracker = VarianceImportanceTorch(model)
tracker.on_train_begin()
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, dataloader)
    tracker.on_epoch_end()
tracker.on_train_end()
print(tracker.feature_importances_)

Example scripts

Run scripts/compare_feature_importance.py to train a small network on the Iris dataset and compare the scores with a random forest baseline:

python scripts/compare_feature_importance.py
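As a rough illustration of the baseline side of that comparison (an assumption about what the script does, not its actual code), a random forest's built-in importances on Iris can be obtained with scikit-learn and ranked for comparison against the variance-based scores:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
ranking = np.argsort(rf.feature_importances_)[::-1]  # most important first
print(ranking)  # the petal measurements typically dominate on Iris
```

Comparing rankings rather than raw scores sidesteps the fact that the two methods normalize their importances differently.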

Run scripts/full_experiment.py to reproduce the experiments from the paper:

python scripts/full_experiment.py

Convolutional models

To compute importances for convolutional networks, use ConvVarianceImportanceKeras from neural_feature_importance.conv_callbacks. scripts/conv_visualization_example.py trains small Conv2D models on the MNIST and scikit-learn digits datasets and displays per-filter heatmaps. An equivalent notebook is available in notebooks/conv_visualization_example.ipynb:

python scripts/conv_visualization_example.py

Embedding layers

To compute token importances from embedding weights, use EmbeddingVarianceImportanceKeras or EmbeddingVarianceImportanceTorch from neural_feature_importance.embedding_callbacks. Run scripts/token_importance_topk_example.py to train a small text classifier on IMDB and display the most important tokens. A matching notebook lives in notebooks/token_importance_topk_example.ipynb:

python scripts/token_importance_topk_example.py

Development

After making changes, run the following checks:

python -m py_compile neural_feature_importance/callbacks.py
python -m py_compile "variance-based feature importance in artificial neural networks.ipynb" 2>&1 | head
jupyter nbconvert --to script "variance-based feature importance in artificial neural networks.ipynb" --stdout | head

Citation

If you use this package in your research, please cite:

@inproceedings{DBLP:conf/dis/Sa19,
  author       = {Cl{\'a}udio Rebelo de S{\'a}},
  editor       = {Petra Kralj Novak and
                  Tomislav Smuc and
                  Saso Dzeroski},
  title        = {Variance-Based Feature Importance in Neural Networks},
  booktitle    = {Discovery Science - 22nd International Conference, {DS} 2019, Split,
                  Croatia, October 28-30, 2019, Proceedings},
  series       = {Lecture Notes in Computer Science},
  volume       = {11828},
  pages        = {306--315},
  publisher    = {Springer},
  year         = {2019},
  url          = {https://doi.org/10.1007/978-3-030-33778-0\_24},
  doi          = {10.1007/978-3-030-33778-0\_24},
  timestamp    = {Thu, 07 Nov 2019 09:20:36 +0100},
  biburl       = {https://dblp.org/rec/conf/dis/Sa19.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}

We appreciate citations as they help the community discover this work.

License

This project is licensed under the MIT License.
