A lightweight toolkit to explain PyTorch vision models with gradient-based and perturbation-based XAI.
XAIToolkit — Explain any PyTorch CNN in minutes
A lightweight, batteries-included toolkit to explain image-classification models (your own CNN, torchvision models, or timm models) using:
- Gradient-based XAI: Saliency, SmoothGrad, Integrated Gradients
- Model-agnostic XAI: RISE, Occlusion
- Local surrogate XAI: LIME-Stratified (superpixels, stable neighborhood sampling)
- Region / tree-style XAI: Axis-aligned SHAP-like attributions (rectangle partitioning)
This repo starts with a teaching-first notebook and also ships as a small Python package + CLI.
Quick start
1) Install (editable for development)
pip install -e .
2) Run the notebook
Open:
notebooks/01_grad_based_xai.ipynb
It loads a CNN, predicts Top-5 classes, and visualizes three gradient-based explainers in a 2×3 figure.
3) Explain an image from the CLI
Torchvision model:
xai-explain --image assets/cat_dog.jpg --model tv:resnet50 --methods saliency smoothgrad ig --outdir outputs
timm model (pretrained):
xai-explain --image assets/flamingo.jpg --model timm:swin_tiny_patch4_window7_224 --methods rise lime_strat shap_axis --outdir outputs
Your checkpoint + architecture:
xai-explain --image assets/cat_dog.jpg --model ckpt:checkpoints/best.pt --arch tv:resnet50 --methods rise occlusion --outdir outputs
What gets saved
For each method:
- outputs/<method>_heatmap.png
- outputs/<method>_overlay.png
Plus outputs/original.png.
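When scripting around the CLI, it can help to know which files to expect before checking the output directory. The sketch below only encodes the naming pattern listed above; `expected_outputs` and `missing_outputs` are hypothetical helper names, not part of the package.

```python
from pathlib import Path


def expected_outputs(outdir, methods):
    """Return the files the naming pattern above implies for one run."""
    out = Path(outdir)
    files = [out / "original.png"]
    for method in methods:
        files.append(out / f"{method}_heatmap.png")
        files.append(out / f"{method}_overlay.png")
    return files


def missing_outputs(outdir, methods):
    """List expected files that are not on disk (yet)."""
    return [p for p in expected_outputs(outdir, methods) if not p.exists()]


# One original image plus two files per requested method.
paths = expected_outputs("outputs", ["saliency", "smoothgrad", "ig"])
for p in paths:
    print(p)
```

Calling `missing_outputs` after a run gives a quick way to spot a method that failed silently.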
Methods
Gradient-based
- saliency: ∂logit/∂input (absolute, channel-mean)
- smoothgrad: noise-averaged saliency
- ig: Integrated Gradients (manual implementation, no Captum)
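To make the three gradient-based ideas concrete, here is a minimal framework-free sketch, not the package's implementation. It assumes a toy "logit" `tanh(w · x)` with a hand-written gradient (`WEIGHTS`, `logit`, and `grad` are illustrative names); with a real model the gradient would come from PyTorch autograd instead. SmoothGrad is shown as an average of saliency maps over noisy copies, which is one common variant.

```python
import math
import random

# Hypothetical stand-in for one class logit of a CNN: f(x) = tanh(w . x).
WEIGHTS = [0.5, -1.0, 2.0, 0.0]


def logit(x):
    return math.tanh(sum(w * xi for w, xi in zip(WEIGHTS, x)))


def grad(x):
    """Analytic d(logit)/dx; autograd would supply this for a real model."""
    s = 1.0 - logit(x) ** 2
    return [s * w for w in WEIGHTS]


def saliency(x):
    """Absolute gradient of the logit with respect to each input."""
    return [abs(g) for g in grad(x)]


def smoothgrad(x, n_samples=50, sigma=0.1, seed=0):
    """Average the saliency over Gaussian-perturbed copies of the input."""
    rng = random.Random(seed)
    total = [0.0] * len(x)
    for _ in range(n_samples):
        noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
        for i, s in enumerate(saliency(noisy)):
            total[i] += s
    return [t / n_samples for t in total]


def integrated_gradients(x, baseline=None, steps=200):
    """Midpoint Riemann sum of gradients along the straight path baseline -> x."""
    if baseline is None:
        baseline = [0.0] * len(x)
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point)):
            total[i] += g
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


x = [1.0, 0.5, 0.25, 3.0]
sal = saliency(x)
sg = smoothgrad(x)
ig = integrated_gradients(x)
# Completeness check: IG attributions should sum to f(x) - f(baseline).
gap = abs(sum(ig) - (logit(x) - logit([0.0] * len(x))))
print(sal, sg, ig, gap)
```

The last input has weight zero, so all three methods assign it no attribution even though its value is the largest; that is the behaviour to expect from any gradient-based explainer.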
Model-agnostic
- rise: Randomized Input Sampling for Explanation (RISE)
- occlusion: Sliding-window occlusion sensitivity
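Both model-agnostic methods only need a way to score a perturbed image, which the following sketch illustrates on a tiny grid of floats. It is a simplified illustration under stated assumptions, not the package's code: `toy_score` is a hypothetical stand-in for a model's class score, and the RISE masks here are per-pixel Bernoulli masks, whereas the original RISE method upsamples low-resolution random masks with random shifts.

```python
import random


def occlusion_map(image, score_fn, window=2, stride=2, fill=0.0):
    """Average score drop over every occluding window that covers a pixel."""
    h, w = len(image), len(image[0])
    base = score_fn(image)
    heat = [[0.0] * w for _ in range(h)]
    hits = [[0] * w for _ in range(h)]
    for top in range(0, h - window + 1, stride):
        for left in range(0, w - window + 1, stride):
            occluded = [row[:] for row in image]
            for r in range(top, top + window):
                for c in range(left, left + window):
                    occluded[r][c] = fill
            drop = base - score_fn(occluded)
            for r in range(top, top + window):
                for c in range(left, left + window):
                    heat[r][c] += drop
                    hits[r][c] += 1
    return [[heat[r][c] / hits[r][c] if hits[r][c] else 0.0
             for c in range(w)] for r in range(h)]


def rise_map(image, score_fn, n_masks=2000, keep_prob=0.5, seed=0):
    """Score-weighted average of random binary masks (simplified RISE)."""
    h, w = len(image), len(image[0])
    rng = random.Random(seed)
    sal = [[0.0] * w for _ in range(h)]
    for _ in range(n_masks):
        mask = [[1.0 if rng.random() < keep_prob else 0.0 for _ in range(w)]
                for _ in range(h)]
        masked = [[image[r][c] * mask[r][c] for c in range(w)]
                  for r in range(h)]
        s = score_fn(masked)
        for r in range(h):
            for c in range(w):
                sal[r][c] += s * mask[r][c]
    norm = n_masks * keep_prob  # expected number of masks keeping a pixel
    return [[v / norm for v in row] for row in sal]


def toy_score(img):
    """Hypothetical 'model' that only looks at two pixels."""
    return img[1][2] + img[2][2]


image = [[1.0] * 4 for _ in range(4)]
occ = occlusion_map(image, toy_score)
rise = rise_map(image, toy_score)
print(occ[1][2], occ[0][0], rise[1][2], rise[0][0])
```

Occlusion attributes importance at the granularity of the window, so every pixel in a window that covers a relevant pixel receives the same value; RISE resolves individual pixels at the cost of many more forward passes.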
Surrogate / region-based
- lime_strat: LIME Image using stratified sampling of the neighborhood (bins on model output)
- shap_axis: Axis-aligned SHAP-like attributions using hierarchical rectangle splits
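One possible reading of "stratified sampling of the neighborhood (bins on model output)" is sketched below; the package's actual procedure may differ. The idea shown: draw many random superpixel on/off vectors, score each, bin the candidates by score, and keep an equal number per bin so the surrogate sees low-, mid-, and high-score perturbations alike. `stratified_neighborhood` and `toy_score` are hypothetical names, and the bin count and sample sizes are arbitrary.

```python
import random


def stratified_neighborhood(score_fn, n_segments, n_candidates=400,
                            n_bins=4, per_bin=10, seed=0):
    """Sample perturbations evenly across equal-width bins of the model score."""
    rng = random.Random(seed)
    # Each candidate switches every superpixel on (1) or off (0).
    candidates = [[rng.randint(0, 1) for _ in range(n_segments)]
                  for _ in range(n_candidates)]
    scored = [(score_fn(z), z) for z in candidates]
    lo = min(s for s, _ in scored)
    hi = max(s for s, _ in scored)
    width = (hi - lo) / n_bins
    if width == 0:
        width = 1.0  # all scores identical: everything lands in bin 0
    bins = [[] for _ in range(n_bins)]
    for s, z in scored:
        idx = min(int((s - lo) / width), n_bins - 1)
        bins[idx].append((s, z))
    sample = []
    for bucket in bins:
        rng.shuffle(bucket)
        sample.extend(bucket[:per_bin])
    return sample


def toy_score(z):
    """Hypothetical model score: fraction of the first three superpixels kept."""
    return sum(z[:3]) / 3


sample = stratified_neighborhood(toy_score, n_segments=6)
print(len(sample))
```

In standard LIME, the selected `(score, z)` pairs would then be weighted by proximity to the unperturbed image and used to fit a sparse linear surrogate whose coefficients are the superpixel attributions.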
Each method file includes references and canonical links at the top.
Bring your own model
You can load models in four ways:
- tv:<name> - torchvision, e.g. tv:resnet50
- timm:<name> - timm pretrained models
- ckpt:<path> + --arch tv:<name>|timm:<name> - load checkpoint into a known architecture
- py:<file.py>:<factory_fn> - load a custom model factory that returns torch.nn.Module
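The four prefixes above can be told apart with plain string handling. The following is a minimal sketch of how such a spec might be parsed and how a `py:` factory could be imported from a file; `ModelSpec`, `parse_model_spec`, and `load_factory` are hypothetical names and do not describe the package's internals.

```python
import importlib.util
from typing import NamedTuple, Optional


class ModelSpec(NamedTuple):
    kind: str                      # "tv", "timm", "ckpt", or "py"
    target: str                    # model name or file path
    factory: Optional[str] = None  # factory function name, "py" only


def parse_model_spec(spec: str) -> ModelSpec:
    """Split a --model string such as 'tv:resnet50' into its parts."""
    kind, sep, rest = spec.partition(":")
    if not sep or not rest:
        raise ValueError(f"expected '<prefix>:<value>', got {spec!r}")
    if kind in ("tv", "timm", "ckpt"):
        return ModelSpec(kind, rest)
    if kind == "py":
        # Split on the last colon so the file path may itself contain colons.
        path, sep, fn = rest.rpartition(":")
        if not sep or not path or not fn:
            raise ValueError(f"expected 'py:<file.py>:<factory_fn>', got {spec!r}")
        return ModelSpec("py", path, fn)
    raise ValueError(f"unknown model prefix: {kind!r}")


def load_factory(path: str, fn_name: str):
    """Import a Python file by path and return the named factory callable."""
    module_spec = importlib.util.spec_from_file_location("user_model", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, fn_name)


print(parse_model_spec("tv:resnet50"))
print(parse_model_spec("py:models/net.py:build_model"))
```

A `ckpt:` spec carries only the checkpoint path; the architecture still has to come from the separate `--arch` value, which can be parsed with the same function.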
Project layout
- src/xaitoolkit/ - package code
- scripts/ - small wrappers (CLI lives here)
- notebooks/ - teaching notebook(s)
- assets/ - demo images
- outputs/ - generated artifacts (gitignored)
Citation & credits
- ResNet: https://arxiv.org/abs/1512.03385
- Integrated Gradients: https://arxiv.org/abs/1703.01365
- SmoothGrad: https://arxiv.org/abs/1706.03825
- Grad-CAM: https://arxiv.org/abs/1610.02391
- RISE: https://arxiv.org/abs/1806.07421
Download files
File details
Details for the file xaitoolkit-0.1.0.tar.gz.
File metadata
- Download URL: xaitoolkit-0.1.0.tar.gz
- Upload date:
- Size: 13.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 58524beaec5f4663897745b2c8937d6cd0b927692ae821422c70636468c1d237 |
| MD5 | 4a107e1371ca771b1e263f9de31d2ee8 |
| BLAKE2b-256 | 4f258b2d04fc2d1448a2821fa13c533955912e857f243367f42ad6b36c84f1b0 |
File details
Details for the file xaitoolkit-0.1.0-py3-none-any.whl.
File metadata
- Download URL: xaitoolkit-0.1.0-py3-none-any.whl
- Upload date:
- Size: 17.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3f633a80fdc08933a8e697d8703509817474c9cc46326e2c74e5fc706662ab0d |
| MD5 | 4cb5ef91a4cb1573549ce02716cdf391 |
| BLAKE2b-256 | bf3fd5d7e4dfffa321219ebc8e578ee01e06c2dd92657928eeee4a9d11b22276 |