
A lightweight toolkit to explain PyTorch vision models with gradient- and perturbation-based XAI.


XAIToolkit — Explain any PyTorch CNN in minutes

A lightweight, batteries-included toolkit to explain image-classification models (your own CNN, torchvision models, or timm models) using:

  • Gradient-based XAI: Saliency, SmoothGrad, Integrated Gradients
  • Model-agnostic XAI: RISE, Occlusion
  • Local surrogate XAI: LIME-Stratified (superpixel-based, with stable neighborhood sampling)
  • Region / tree-style XAI: Axis-aligned SHAP-like attributions (rectangle partitioning)

The repository is built around a teaching-first notebook and also ships as a small Python package with a command-line interface (CLI).

Quick start

1) Install (editable for development)

pip install -e .

2) Run the notebook

Open:

  • notebooks/01_grad_based_xai.ipynb

It loads a CNN, predicts Top-5 classes, and visualizes three gradient-based explainers in a 2×3 figure.
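Predicting the Top-5 classes amounts to applying a softmax to the logits and keeping the five largest probabilities. The minimal NumPy sketch below assumes the model's logits are already available as a 1-D array. `top_k` is a hypothetical helper, not part of the XAIToolkit API.

```python
import numpy as np

def top_k(logits, k=5):
    """Return (class indices, probabilities) of the k most likely classes."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # argsort is ascending: take the last k entries and reverse them for descending order.
    idx = np.argsort(probs)[-k:][::-1]
    return idx, probs[idx]

idx, p = top_k([0.1, 2.0, -1.0, 3.5, 0.7, 1.2, -0.3], k=5)
print(idx)  # [3 1 5 4 0]
```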

3) Explain an image from the CLI

Torchvision model:

xai-explain --image assets/cat_dog.jpg --model tv:resnet50 --methods saliency smoothgrad ig --outdir outputs

timm model (pretrained):

xai-explain --image assets/flamingo.jpg --model timm:swin_tiny_patch4_window7_224 --methods rise lime_strat shap_axis --outdir outputs

Your checkpoint + architecture:

xai-explain --image assets/cat_dog.jpg --model ckpt:checkpoints/best.pt --arch tv:resnet50 --methods rise occlusion --outdir outputs

What gets saved

For each method:

  • outputs/<method>_heatmap.png
  • outputs/<method>_overlay.png

Plus outputs/original.png (the input image).
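This README does not document how the overlays are rendered. A common approach, and the assumption behind this sketch, is to min-max normalize the heatmap to [0, 1], map it to a color, and alpha-blend it with the image. The sketch uses plain NumPy and a red-channel "colormap" in place of a real one such as matplotlib's. `normalize` and `overlay` are hypothetical names.

```python
import numpy as np

def normalize(heatmap, eps=1e-12):
    """Min-max scale a 2-D heatmap to [0, 1]."""
    h = np.asarray(heatmap, dtype=np.float64)
    return (h - h.min()) / (h.max() - h.min() + eps)

def overlay(image, heatmap, alpha=0.5):
    """Alpha-blend an (H, W) heatmap onto an (H, W, 3) RGB image with values in [0, 1].

    The heatmap is drawn in the red channel only, as a stand-in for a real colormap.
    """
    h = normalize(heatmap)
    color = np.zeros_like(image, dtype=np.float64)
    color[..., 0] = h  # red intensity follows attribution strength
    return (1.0 - alpha) * image + alpha * color

img = np.full((4, 4, 3), 0.5)
heat = np.arange(16, dtype=float).reshape(4, 4)
out = overlay(img, heat, alpha=0.5)
print(out.shape)  # (4, 4, 3)
```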

Methods

Gradient-based

  • saliency: absolute gradient of the target logit with respect to the input (∂logit/∂input), averaged over color channels
  • smoothgrad: noise-averaged saliency
  • ig: Integrated Gradients (manual implementation, no Captum)
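The three gradient-based methods can be written in a framework-agnostic way. The sketch below assumes a callable `grad_fn(x)` that returns the gradient of the target logit with respect to an input of shape (C, H, W); in PyTorch that gradient would come from autograd. The function names, the defaults (`n`, `sigma`, `steps`) and the all-zeros IG baseline are illustrative assumptions, not the package's actual signatures.

```python
import numpy as np

def saliency(grad_fn, x):
    """|d logit / d input|, averaged over the channel axis -> (H, W)."""
    return np.abs(grad_fn(x)).mean(axis=0)

def smoothgrad(grad_fn, x, n=25, sigma=0.15, seed=0):
    """Average saliency over n copies of x with Gaussian noise of std sigma * (max - min)."""
    rng = np.random.default_rng(seed)
    scale = sigma * (x.max() - x.min())
    maps = [saliency(grad_fn, x + rng.normal(0.0, scale, size=x.shape)) for _ in range(n)]
    return np.mean(maps, axis=0)

def integrated_gradients(grad_fn, x, baseline=None, steps=64):
    """Midpoint Riemann approximation of IG along the straight path baseline -> x."""
    if baseline is None:
        baseline = np.zeros_like(x)
    alphas = (np.arange(steps) + 0.5) / steps
    total = np.zeros_like(x, dtype=np.float64)
    for a in alphas:
        total += grad_fn(baseline + a * (x - baseline))
    return (x - baseline) * total / steps

# Toy differentiable "model": f(x) = sum(w * x**2), whose gradient is 2 * w * x.
rng = np.random.default_rng(1)
w = rng.normal(size=(3, 8, 8))

def f(x):
    return float(np.sum(w * x**2))

def grad_f(x):
    return 2.0 * w * x

x = rng.uniform(size=(3, 8, 8))
attr = integrated_gradients(grad_f, x)
print(np.isclose(attr.sum(), f(x) - f(np.zeros_like(x))))  # True
```

The final check illustrates IG's completeness property: the attributions sum to f(x) - f(baseline). For this quadratic toy model the midpoint rule happens to be exact.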

Model-agnostic

  • rise: Randomized Input Sampling for Explanation (RISE)
  • occlusion: Sliding-window occlusion sensitivity
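Both model-agnostic methods need only forward passes. The sketch below assumes a hypothetical `score_fn(image)` that returns the target-class score for a single (C, H, W) array. The RISE part is simplified: it upsamples the random low-resolution grids by nearest-neighbour repetition (`np.kron`) with a random shift, where the original paper uses bilinear upsampling. Mask count, grid size, window and stride are illustrative defaults, not the package's.

```python
import numpy as np

def rise(score_fn, image, n_masks=500, grid=7, p=0.5, seed=0):
    """RISE: average random masks, each weighted by the score of the masked image."""
    _, h, w = image.shape
    rng = np.random.default_rng(seed)
    cell_h, cell_w = -(-h // grid), -(-w // grid)  # ceiling division
    sal = np.zeros((h, w))
    for _ in range(n_masks):
        small = (rng.random((grid + 1, grid + 1)) < p).astype(np.float64)
        # Nearest-neighbour upsampling (simplification of bilinear) plus a random shift.
        big = np.kron(small, np.ones((cell_h, cell_w)))
        dy, dx = rng.integers(0, cell_h), rng.integers(0, cell_w)
        mask = big[dy:dy + h, dx:dx + w]
        sal += score_fn(image * mask) * mask
    return sal / (n_masks * p)

def occlusion(score_fn, image, window=8, stride=4, fill=0.0):
    """Occlusion: score drop when a square patch is replaced by `fill`."""
    _, h, w = image.shape
    base = score_fn(image)
    heat, counts = np.zeros((h, w)), np.zeros((h, w))
    for y in range(0, h - window + 1, stride):
        for x in range(0, w - window + 1, stride):
            occluded = image.copy()
            occluded[:, y:y + window, x:x + window] = fill
            heat[y:y + window, x:x + window] += base - score_fn(occluded)
            counts[y:y + window, x:x + window] += 1
    return heat / np.maximum(counts, 1)  # average over overlapping windows

def toy_score(img):
    # Hypothetical "model": mean intensity of the top-left 8x8 patch.
    return float(img[:, :8, :8].mean())

img = np.ones((1, 16, 16))
occ = occlusion(toy_score, img)
print(occ[0, 0], occ[15, 15])  # 1.0 0.0
sal = rise(toy_score, img, n_masks=2000)
```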

Surrogate / region-based

  • lime_strat: LIME for images, with the perturbation neighborhood sampled in strata defined by bins of the model output
  • shap_axis: Axis-aligned SHAP-like attributions using hierarchical rectangle splits
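The README describes these two methods only briefly, so the sketch below rests on stated assumptions rather than on the package's code. `lime_stratified` is read as follows: draw a pool of random superpixel on/off vectors, bin them by quantiles of the model output, sample evenly from each bin, then fit a weighted ridge surrogate using an exponential kernel on cosine distance, as LIME does. `shap_axis` is read as an Owen-value-style recursion: each rectangle is halved along its longer side, and each half is explained both with its sibling absent and with it present. The callables are hypothetical. `score_fn(z)` maps a 0/1 superpixel vector to a score (the segmentation step is omitted). `score_img(image)` scores a (C, H, W) array.

```python
import numpy as np

def lime_stratified(score_fn, n_segments, pool=2000, n_bins=5, per_bin=100,
                    width=0.25, ridge=1e-3, seed=0):
    """Weighted linear surrogate over superpixel on/off vectors z in {0,1}^K."""
    rng = np.random.default_rng(seed)
    z = rng.integers(0, 2, size=(pool, n_segments)).astype(np.float64)
    y = np.array([score_fn(row) for row in z])
    # Stratify by model output: quantile bins, then at most per_bin draws from each bin.
    edges = np.quantile(y, np.linspace(0.0, 1.0, n_bins + 1))
    bins = np.clip(np.searchsorted(edges, y, side="right") - 1, 0, n_bins - 1)
    picks = [rng.choice(np.flatnonzero(bins == b),
                        size=min(per_bin, int(np.sum(bins == b))), replace=False)
             for b in range(n_bins) if np.any(bins == b)]
    Z, Y = z[np.concatenate(picks)], y[np.concatenate(picks)]
    # Exponential kernel on cosine distance to the all-ones (unperturbed) vector.
    dist = 1.0 - np.sqrt(Z.sum(axis=1) / n_segments)
    wts = np.sqrt(np.exp(-dist**2 / width**2))
    # Weighted ridge regression with an unpenalised intercept column.
    X = np.hstack([np.ones((len(Z), 1)), Z])
    reg = ridge * np.eye(X.shape[1])
    reg[0, 0] = 0.0
    beta = np.linalg.solve(X.T @ (wts[:, None] * X) + reg, X.T @ (wts * Y))
    return beta[1:]  # one coefficient per superpixel

def shap_axis(score_img, image, min_size=4, baseline=0.0):
    """Owen-style attributions over a binary tree of axis-aligned rectangles."""
    _, h, w = image.shape

    def value(mask):
        # Pixels outside `mask` are replaced by the baseline value.
        return score_img(np.where(mask, image, baseline))

    def rect_mask(r):
        m = np.zeros((h, w), dtype=bool)
        m[r[0]:r[1], r[2]:r[3]] = True
        return m

    def recurse(r, ctx):
        """Attributions inside rectangle r; they sum to v(ctx | r) - v(ctx)."""
        y0, y1, x0, x1 = r
        rm = rect_mask(r)
        if max(y1 - y0, x1 - x0) <= min_size:
            out = np.zeros((h, w))
            out[rm] = (value(ctx | rm) - value(ctx)) / rm.sum()  # spread evenly over the leaf
            return out
        if y1 - y0 >= x1 - x0:
            mid = (y0 + y1) // 2
            a, b = (y0, mid, x0, x1), (mid, y1, x0, x1)
        else:
            mid = (x0 + x1) // 2
            a, b = (y0, y1, x0, mid), (y0, y1, mid, x1)
        am, bm = rect_mask(a), rect_mask(b)
        # Average each half's explanation over "sibling absent" and "sibling present".
        return (0.5 * (recurse(a, ctx) + recurse(a, ctx | bm))
                + 0.5 * (recurse(b, ctx) + recurse(b, ctx | am)))

    return recurse((0, h, 0, w), np.zeros((h, w), dtype=bool))

true_w = np.array([0.5, -1.0, 2.0, 0.0, 0.3, -0.2])
coef = lime_stratified(lambda z: float(z @ true_w) + 1.0, n_segments=6)

rng = np.random.default_rng(0)
img, W = rng.uniform(size=(1, 16, 16)), rng.normal(size=(1, 16, 16))
phi = shap_axis(lambda x: float(np.sum(W * x)), img)
```

The recursion evaluates the model a number of times that grows roughly fourfold per tree level, so it is only practical with coarse rectangles. By construction, the attributions sum to the score of the full image minus the score of the all-baseline image.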

Each method file includes references and canonical links at the top.

Bring your own model

You can load models in four ways:

  • tv:<name> — torchvision, e.g. tv:resnet50
  • timm:<name> — timm pretrained models
  • ckpt:<path> with --arch tv:<name> or --arch timm:<name> — load a checkpoint into a known architecture
  • py:<file.py>:<factory_fn> — load a custom model factory that returns a torch.nn.Module
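The four model-spec formats can be parsed with plain string handling. This is a minimal sketch; `ModelSpec` and `parse_model_spec` are hypothetical names, not the package's internals. `rpartition` splits `py:` specs at the last colon, so a file path that itself contains a colon (such as a Windows drive letter) stays intact.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelSpec:
    kind: str                            # "tv", "timm", "ckpt" or "py"
    target: str                          # model name, checkpoint path, or file path
    factory: Optional[str] = None        # only for py:<file.py>:<factory_fn>
    arch: Optional["ModelSpec"] = None   # only for ckpt:<path>

def parse_model_spec(spec: str, arch: Optional[str] = None) -> ModelSpec:
    kind, sep, rest = spec.partition(":")
    if not sep or not rest:
        raise ValueError(f"expected '<kind>:<value>', got {spec!r}")
    if kind in ("tv", "timm"):
        return ModelSpec(kind, rest)
    if kind == "ckpt":
        if arch is None:
            raise ValueError("ckpt:<path> requires --arch tv:<name> or timm:<name>")
        arch_spec = parse_model_spec(arch)
        if arch_spec.kind not in ("tv", "timm"):
            raise ValueError("--arch must be tv:<name> or timm:<name>")
        return ModelSpec(kind, rest, arch=arch_spec)
    if kind == "py":
        path, sep, fn = rest.rpartition(":")
        if not sep or not path or not fn:
            raise ValueError("expected py:<file.py>:<factory_fn>")
        return ModelSpec(kind, path, factory=fn)
    raise ValueError(f"unknown model kind {kind!r}")

spec = parse_model_spec("ckpt:checkpoints/best.pt", arch="tv:resnet50")
print(spec.arch.target)  # resnet50
```

A loader could then dispatch on `kind`, for example to `torchvision.models.get_model(name)` or `timm.create_model(name, pretrained=True)`. This README does not say whether XAIToolkit does exactly that.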

Project layout

  • src/xaitoolkit/ — package code
  • scripts/ — small wrappers (CLI lives here)
  • notebooks/ — teaching notebook(s)
  • assets/ — demo images
  • outputs/ — generated artifacts (gitignored)

Citation & credits



Download files


Source Distribution

xaitoolkit-0.1.1.tar.gz (14.5 kB)

Built Distribution


xaitoolkit-0.1.1-py3-none-any.whl (18.4 kB)

File details

Details for the file xaitoolkit-0.1.1.tar.gz.

File metadata

  • Download URL: xaitoolkit-0.1.1.tar.gz
  • Size: 14.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.0

File hashes

Hashes for xaitoolkit-0.1.1.tar.gz
  • SHA256: 18eccec47567998237925d13d2c4b68d12d743802842311598a5c091ff1a7b7b
  • MD5: e66910ed7892c9d71e2545a2d1180fe3
  • BLAKE2b-256: 4cdca33ef6006b2817cf35e58784f973b50abe3cece2a1ba13abf5ace7ef51e9
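Before installing a downloaded archive, you can check its digest against the SHA256 listed for it. This minimal standard-library sketch reads the file in chunks, so large files are not loaded into memory at once. `sha256_of` is a hypothetical helper, and the commented-out check assumes the archive is in the current directory.

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

EXPECTED_SDIST_SHA256 = "18eccec47567998237925d13d2c4b68d12d743802842311598a5c091ff1a7b7b"
# assert sha256_of("xaitoolkit-0.1.1.tar.gz") == EXPECTED_SDIST_SHA256
```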


File details

Details for the file xaitoolkit-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: xaitoolkit-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 18.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.0

File hashes

Hashes for xaitoolkit-0.1.1-py3-none-any.whl
  • SHA256: 523ab65fb76ff052b2cc84a491993ef481019a330d2e6a3ee77be7b327a527bb
  • MD5: 2a092ad81f1209a44b5fc38bb18c716a
  • BLAKE2b-256: fa41a286116c817b3ef42cdef7a11f42cfa62912feaadadf95c978a5fde0a82e

