
A lightweight toolkit to explain PyTorch vision models with gradient and perturbation-based XAI.


XAIToolkit — Explain any PyTorch CNN in minutes

A lightweight, batteries-included toolkit to explain image-classification models (your own CNN, torchvision models, or timm models) using:

  • Gradient-based XAI: Saliency, SmoothGrad, Integrated Gradients
  • Model-agnostic XAI: RISE, Occlusion
  • Local surrogate XAI: LIME-Stratified (superpixel-based LIME with stratified, more stable neighborhood sampling)
  • Region / tree-style XAI: Axis-aligned SHAP-like attributions (rectangle partitioning)

The repository is built around a teaching-first notebook and also ships as a small Python package with a command-line interface (xai-explain).

Quick start

1) Install in editable mode for development (from the repository root)

pip install -e .

2) Run the notebook

Open:

  • notebooks/01_grad_based_xai.ipynb

It loads a CNN, predicts the Top-5 classes, and visualizes the three gradient-based explainers (saliency, SmoothGrad, Integrated Gradients) side by side in a 2×3 figure.
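The notebook's prediction step (logits to probabilities to the five most likely classes) can be sketched as follows. This is a minimal sketch, not the notebook's code: `predict_top_k` and the tiny stand-in network are hypothetical, and a real run would use a pretrained model together with the preprocessing that matches its weights.

```python
import torch
import torch.nn as nn


def predict_top_k(model: nn.Module, x: torch.Tensor, k: int = 5):
    """Return (probabilities, class indices) of the k most likely classes.

    x is an already-preprocessed batch of shape (N, 3, H, W).
    """
    model.eval()                                   # disable dropout, use running BN stats
    with torch.no_grad():                          # no gradients needed for prediction
        probs = model(x).softmax(dim=1)            # logits -> probabilities
        top_p, top_idx = probs.topk(k, dim=1)      # sorted, largest first
    return top_p, top_idx


# Hypothetical stand-in for a real CNN such as torchvision's resnet50.
tiny_cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),                              # 10 classes for the toy example
)
torch.manual_seed(0)
top_p, top_idx = predict_top_k(tiny_cnn, torch.rand(1, 3, 32, 32))
for p, i in zip(top_p[0].tolist(), top_idx[0].tolist()):
    print(f"class {i}: {p:.3f}")
```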

3) Explain an image from the CLI

Torchvision model:

xai-explain --image assets/cat_dog.jpg --model tv:resnet50 --methods saliency smoothgrad ig --outdir outputs

timm model (pretrained):

xai-explain --image assets/flamingo.jpg --model timm:swin_tiny_patch4_window7_224 --methods rise lime_strat shap_axis --outdir outputs

Your checkpoint + architecture:

xai-explain --image assets/cat_dog.jpg --model ckpt:checkpoints/best.pt --arch tv:resnet50 --methods rise occlusion --outdir outputs

What gets saved

For each method:

  • outputs/<method>_heatmap.png
  • outputs/<method>_overlay.png

The input image is also saved as outputs/original.png.
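A heatmap and an overlay like the files listed above can be produced roughly as follows. This sketch assumes min-max normalization and a simple blue-to-red ramp alpha-blended onto the image; the toolkit's actual colormap and blending may differ, and `normalize_heatmap` and `overlay` are hypothetical names.

```python
import numpy as np


def normalize_heatmap(heatmap: np.ndarray) -> np.ndarray:
    """Min-max scale an (H, W) attribution map to [0, 1]."""
    h = heatmap.astype(np.float64)
    span = h.max() - h.min()
    if span < 1e-12:                               # constant map: nothing to highlight
        return np.zeros_like(h)
    return (h - h.min()) / span


def overlay(image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend an (H, W) heatmap onto an (H, W, 3) float image in [0, 1]."""
    h = normalize_heatmap(heatmap)
    # Hypothetical colormap: low attribution -> blue, high attribution -> red.
    color = np.stack([h, np.zeros_like(h), 1.0 - h], axis=-1)
    return (1.0 - alpha) * image + alpha * color   # convex blend stays in [0, 1]
```

Because both inputs lie in [0, 1], the result can be written to PNG directly, for example with matplotlib.pyplot.imsave.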

Methods

Gradient-based

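  • saliency: absolute gradient of the target-class logit with respect to the input, ∂logit/∂input, averaged over color channels
  • smoothgrad: saliency averaged over several noisy copies of the input
  • ig: Integrated Gradients, implemented directly (no Captum dependency)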
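The three gradient-based methods need nothing beyond PyTorch autograd, in the same spirit as the package's Captum-free implementation. This is an illustrative sketch, not the toolkit's code: the function names, the all-zero IG baseline, the midpoint Riemann sum, and the SmoothGrad noise level are assumptions.

```python
import torch
import torch.nn as nn


def saliency(model: nn.Module, x: torch.Tensor, target: int) -> torch.Tensor:
    """|d logit_target / d input|, averaged over color channels -> (N, H, W)."""
    x = x.detach().clone().requires_grad_(True)
    logit = model(x)[:, target].sum()              # sum over batch: one backward pass
    (grad,) = torch.autograd.grad(logit, x)
    return grad.abs().mean(dim=1)


def smoothgrad(model, x, target, n_samples=25, noise_level=0.15):
    """Average saliency over noisy copies of x.

    Noise std is noise_level times the input's value range (an assumption).
    """
    sigma = noise_level * (x.max() - x.min()).item()
    maps = [saliency(model, x + sigma * torch.randn_like(x), target)
            for _ in range(n_samples)]
    return torch.stack(maps).mean(dim=0)


def integrated_gradients(model, x, target, baseline=None, steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    Returns per-pixel, per-channel attributions with the same shape as x.
    """
    if baseline is None:
        baseline = torch.zeros_like(x)             # all-black baseline image
    total_grad = torch.zeros_like(x)
    for i in range(steps):
        alpha = (i + 0.5) / steps                  # midpoint of each path segment
        point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        logit = model(point)[:, target].sum()
        (grad,) = torch.autograd.grad(logit, point)
        total_grad += grad
    return (x - baseline) * total_grad / steps


# Hypothetical linear stand-in model; a real CNN would be used in eval() mode.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))
x = torch.rand(1, 3, 8, 8)
sal = saliency(model, x, target=2)                 # (1, 8, 8)
ig = integrated_gradients(model, x, target=2)      # (1, 3, 8, 8)
```

For a linear model the gradient is constant along the path, so the IG attributions sum to f(x) - f(baseline) up to floating-point error (the completeness property). For a real CNN the sum is only approximately equal, and the gap typically shrinks as `steps` grows.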

Model-agnostic

  • rise: Randomized Input Sampling for Explanation (RISE)
  • occlusion: Sliding-window occlusion sensitivity
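Both model-agnostic methods only need a black-box prediction function. In the sketch below, `predict_fn` is a hypothetical callable mapping a batch of (N, H, W, C) images to (N, num_classes) scores, and all defaults are illustrative. The RISE part is simplified: it upsamples the random grids by nearest-neighbour repetition plus a random shift, rather than the bilinear upsampling described in the RISE paper.

```python
import numpy as np


def rise(predict_fn, image, target, n_masks=1000, grid=7, p_keep=0.5,
         batch_size=100, seed=0):
    """RISE: score-weighted average of random masks -> (H, W) saliency."""
    rng = np.random.default_rng(seed)
    H, W = image.shape[:2]
    cell_h, cell_w = -(-H // grid), -(-W // grid)          # ceil division
    saliency = np.zeros((H, W))
    for start in range(0, n_masks, batch_size):
        n = min(batch_size, n_masks - start)
        # Coarse binary grids; one extra row/column leaves room for the shift.
        grids = (rng.random((n, grid + 1, grid + 1)) < p_keep).astype(np.float64)
        up = grids.repeat(cell_h, axis=1).repeat(cell_w, axis=2)
        masks = np.empty((n, H, W))
        for i in range(n):                                 # random shift, then crop
            dy, dx = rng.integers(0, cell_h), rng.integers(0, cell_w)
            masks[i] = up[i, dy:dy + H, dx:dx + W]
        scores = predict_fn(image[None] * masks[..., None])[:, target]
        saliency += np.tensordot(scores, masks, axes=1)    # sum_i score_i * mask_i
    return saliency / (n_masks * p_keep)


def occlusion(predict_fn, image, target, window=8, stride=4, fill=0.0):
    """Average score drop when a window covering each pixel is occluded."""
    H, W = image.shape[:2]
    base = predict_fn(image[None])[0, target]
    heat, counts = np.zeros((H, W)), np.zeros((H, W))
    for y in range(0, max(H - window, 0) + 1, stride):
        for x in range(0, max(W - window, 0) + 1, stride):
            occluded = image.copy()
            occluded[y:y + window, x:x + window] = fill
            drop = base - predict_fn(occluded[None])[0, target]
            heat[y:y + window, x:x + window] += drop
            counts[y:y + window, x:x + window] += 1
    # Pixels never covered (stride not dividing the image) keep a score of 0.
    return heat / np.maximum(counts, 1)


# Hypothetical toy "model": class 0 scores the mean brightness of the
# top-left 8x8 patch, class 1 is its complement.
def toy_predict(batch):
    s = batch[:, :8, :8, :].mean(axis=(1, 2, 3))
    return np.stack([s, 1.0 - s], axis=1)


img = np.ones((16, 16, 3))
rise_map = rise(toy_predict, img, target=0, n_masks=2000, grid=4)
occ_map = occlusion(toy_predict, img, target=0)
```

Both maps should be brightest in the top-left quadrant, the only region the toy model looks at.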

Surrogate / region-based

  • lime_strat: LIME for images, with the perturbation neighborhood sampled in strata (bins) of the model's output
  • shap_axis: Axis-aligned SHAP-like attributions using hierarchical rectangle splits
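The stratified-LIME idea can be sketched as: sample random on/off patterns over superpixels, query the model, bin the outputs, draw the same number of samples from every bin, then fit a proximity-weighted linear surrogate. This is one plausible reading of "bins on model output", not the toolkit's implementation: a regular grid stands in for real superpixels (such as SLIC), and the bin count, per-bin quota, and ridge penalty are illustrative assumptions. The exponential kernel on cosine distance with width 0.25 resembles the default in the reference lime package.

```python
import numpy as np


def grid_segments(H, W, n_rows=4, n_cols=4):
    """Label map assigning each pixel to a rectangular stand-in 'superpixel'."""
    rows = np.arange(H) * n_rows // H
    cols = np.arange(W) * n_cols // W
    return rows[:, None] * n_cols + cols[None, :]


def lime_stratified(predict_fn, image, target, segments, n_pool=1000, n_bins=5,
                    per_bin=40, kernel_width=0.25, ridge=1e-3, fill=0.0, seed=0):
    """LIME with a neighborhood stratified by model output -> one weight per segment."""
    rng = np.random.default_rng(seed)
    n_seg = int(segments.max()) + 1
    # 1) Candidate neighborhood: random on/off patterns over segments.
    Z = rng.integers(0, 2, size=(n_pool, n_seg)).astype(np.float64)
    Z[0] = 1.0                                       # include the unperturbed image
    masks = Z[:, segments][..., None]                # (n_pool, H, W, 1)
    y = predict_fn(image[None] * masks + fill * (1.0 - masks))[:, target]
    # 2) Stratify: equal-width bins over the model output, same quota per bin.
    edges = np.linspace(y.min(), y.max(), n_bins + 1)
    bin_id = np.digitize(y, edges[1:-1])             # values in 0 .. n_bins-1
    picked = []
    for b in range(n_bins):
        members = np.flatnonzero(bin_id == b)
        if members.size:
            picked.append(rng.choice(members, size=min(per_bin, members.size),
                                     replace=False))
    sel = np.concatenate(picked)
    Zs, ys = Z[sel], y[sel]
    # 3) Proximity weights: cosine distance to the all-on vector, exponential kernel.
    dist = 1.0 - np.sqrt(Zs.sum(axis=1) / n_seg)
    w = np.sqrt(np.exp(-dist ** 2 / kernel_width ** 2))
    # 4) Weighted ridge regression with an unpenalized intercept.
    X = np.hstack([np.ones((len(sel), 1)), Zs])
    reg = ridge * np.eye(n_seg + 1)
    reg[0, 0] = 0.0
    coef = np.linalg.solve(X.T @ (w[:, None] * X) + reg, X.T @ (w * ys))
    return coef[1:]


# Hypothetical toy model: class 0 = mean brightness of the top-left 8x8 patch.
def toy_predict(batch):
    s = batch[:, :8, :8, :].mean(axis=(1, 2, 3))
    return np.stack([s, 1.0 - s], axis=1)


img = np.ones((16, 16, 3))
segs = grid_segments(16, 16)                         # 4x4 grid of 4x4-pixel segments
weights = lime_stratified(toy_predict, img, target=0, segments=segs)
```

Equal quotas per bin mean the surrogate sees low, medium, and high model outputs in similar proportions, rather than whatever mix plain random sampling happens to produce.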

Each method file includes references and canonical links at the top.

Bring your own model

You can load models in four ways:

  • tv:<name> — torchvision, e.g. tv:resnet50
  • timm:<name> — timm pretrained models
  • ckpt:<path> + --arch tv:<name>|timm:<name> — load checkpoint into a known architecture
  • py:<file.py>:<factory_fn> — load a custom model factory that returns torch.nn.Module
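The four model specifiers can be parsed with a few lines of standard-library code, and the py: form can load a factory from an arbitrary file with importlib. A minimal sketch: `ModelSpec`, `parse_model_spec`, and `load_factory` are hypothetical names, not the package's API. For tv: and timm:, the actual construction would presumably go through torchvision.models.get_model and timm.create_model, but that is an assumption.

```python
import importlib.util
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional


@dataclass
class ModelSpec:
    source: str                     # "tv", "timm", "ckpt" or "py"
    target: str                     # model name, checkpoint path, or .py file path
    factory: Optional[str] = None   # factory function name (py: only)


def parse_model_spec(spec: str) -> ModelSpec:
    source, sep, rest = spec.partition(":")        # split at the FIRST colon
    if not sep or not rest:
        raise ValueError(f"expected <source>:<value>, got {spec!r}")
    if source in ("tv", "timm", "ckpt"):
        return ModelSpec(source, rest)
    if source == "py":
        # Split at the LAST colon so paths like C:\models\net.py still work.
        path, sep, fn = rest.rpartition(":")
        if not sep or not path or not fn:
            raise ValueError(f"expected py:<file.py>:<factory_fn>, got {spec!r}")
        return ModelSpec("py", path, fn)
    raise ValueError(f"unknown model source {source!r}")


def load_factory(path: str, fn_name: str) -> Callable[[], object]:
    """Import a .py file by path and return its factory function.

    In real use, calling the factory should return a torch.nn.Module.
    """
    module_spec = importlib.util.spec_from_file_location(Path(path).stem, path)
    if module_spec is None or module_spec.loader is None:
        raise ImportError(f"cannot import {path}")
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)         # runs the file's top-level code
    return getattr(module, fn_name)
```

Note that exec_module runs arbitrary code from the given file, so py: specs should only point at trusted files.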

Project layout

  • src/xaitoolkit/ — package code
  • scripts/ — small wrappers (CLI lives here)
  • notebooks/ — teaching notebook(s)
  • assets/ — demo images
  • outputs/ — generated artifacts (gitignored)

Citation & credits



Download files

Download the file for your platform.

Source Distribution

xaitoolkit-0.1.0.tar.gz (13.6 kB)

Uploaded Source

Built Distribution


xaitoolkit-0.1.0-py3-none-any.whl (17.3 kB)

Uploaded Python 3

File details

Details for the file xaitoolkit-0.1.0.tar.gz.

File metadata

  • Download URL: xaitoolkit-0.1.0.tar.gz
  • Upload date:
  • Size: 13.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.0

File hashes

Hashes for xaitoolkit-0.1.0.tar.gz
  • SHA256: 58524beaec5f4663897745b2c8937d6cd0b927692ae821422c70636468c1d237
  • MD5: 4a107e1371ca771b1e263f9de31d2ee8
  • BLAKE2b-256: 4f258b2d04fc2d1448a2821fa13c533955912e857f243367f42ad6b36c84f1b0


File details

Details for the file xaitoolkit-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: xaitoolkit-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 17.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.0

File hashes

Hashes for xaitoolkit-0.1.0-py3-none-any.whl
  • SHA256: 3f633a80fdc08933a8e697d8703509817474c9cc46326e2c74e5fc706662ab0d
  • MD5: 4cb5ef91a4cb1573549ce02716cdf391
  • BLAKE2b-256: bf3fd5d7e4dfffa321219ebc8e578ee01e06c2dd92657928eeee4a9d11b22276
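To check a downloaded file against the SHA256 digests listed above, hash it locally and compare. A minimal standard-library sketch, assuming the file sits in the current directory; `sha256_of` and `verify` are hypothetical helper names. pip can also enforce this itself in hash-checking mode, for example with a requirements line `xaitoolkit==0.1.0 --hash=sha256:<digest>` installed via `pip install --require-hashes -r requirements.txt`.

```python
import hashlib

# SHA256 digests as published for xaitoolkit 0.1.0.
EXPECTED_SHA256 = {
    "xaitoolkit-0.1.0.tar.gz":
        "58524beaec5f4663897745b2c8937d6cd0b927692ae821422c70636468c1d237",
    "xaitoolkit-0.1.0-py3-none-any.whl":
        "3f633a80fdc08933a8e697d8703509817474c9cc46326e2c74e5fc706662ab0d",
}


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in chunks so large files need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


# Usage, after downloading the sdist:
# verify("xaitoolkit-0.1.0.tar.gz", EXPECTED_SHA256["xaitoolkit-0.1.0.tar.gz"])
```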

