SIDU: SImilarity Difference and Uniqueness method for explainable AI
pytorch-sidu
SIDU: SImilarity Difference and Uniqueness method for explainable AI, as introduced in the original paper.
- PyTorch implementation of the SIDU method.
- Simple interface for loading pretrained models by specifying a weights name as a string (e.g. ResNet18_Weights.IMAGENET1K_V1).
- Clear interface for generating saliency maps.
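To give some intuition for what the name refers to, here is a simplified NumPy sketch of the weighting step. It is written from our reading of the paper, not from this package's source. It assumes that the masks (one per feature map, shape K × H × W, values in [0, 1]) and the model's class-probability vectors for the original image (p_orig) and for each masked image (p_masked) have already been computed. In the paper, the masks are derived from the feature maps of the last convolutional layer. The Gaussian kernel on the Euclidean distance, the value of sigma, and the final min-max normalization are illustrative choices. The function names are hypothetical.

```python
import numpy as np


def sidu_weights(p_orig: np.ndarray, p_masked: np.ndarray, sigma: float = 0.25) -> np.ndarray:
    """Hypothetical helper: one weight per mask = similarity difference * uniqueness.

    p_orig:   (C,) class probabilities for the unmasked image
    p_masked: (K, C) class probabilities for each of the K masked images
    """
    # Similarity difference (assumed Gaussian kernel): masked images whose
    # prediction stays close to the original prediction get a weight near 1.
    dist_to_orig = np.linalg.norm(p_masked - p_orig[None, :], axis=1)          # (K,)
    similarity_difference = np.exp(-dist_to_orig**2 / (2.0 * sigma**2))        # (K,)
    # Uniqueness: sum of distances from each masked prediction to all others,
    # so masks that produce unusual predictions are emphasized.
    pairwise = np.linalg.norm(p_masked[:, None, :] - p_masked[None, :, :], axis=2)  # (K, K)
    uniqueness = pairwise.sum(axis=1)                                           # (K,)
    return similarity_difference * uniqueness


def sidu_saliency(masks: np.ndarray, p_orig: np.ndarray, p_masked: np.ndarray,
                  sigma: float = 0.25) -> np.ndarray:
    """Hypothetical helper: weighted average of the K masks, min-max scaled to [0, 1]."""
    w = sidu_weights(p_orig, p_masked, sigma)
    saliency = np.tensordot(w, masks, axes=1) / len(masks)  # (H, W)
    lo, hi = saliency.min(), saliency.max()
    return (saliency - lo) / (hi - lo) if hi > lo else np.zeros_like(saliency)
```

The package may differ in how it builds the masks and in the exact kernel. The sketch only shows how the two terms in the name combine into per-mask weights.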
Some examples made with VGG19 on the Caltech-101 dataset:
[Figure: example saliency maps computed with VGG19 on Caltech-101]
Installation
```shell
pip install pytorch-sidu
```
Usage
Load one of the pretrained models available in PyTorch (torchvision):
```python
import pytorch_sidu as sidu

weights = "ResNet18_Weights.IMAGENET1K_V1"
backbone = sidu.load_torch_backbone(weights)
```
After instantiating your model, generate saliency maps for a batch of images taken from a DataLoader:
```python
data_loader = ...  # your torch.utils.data.DataLoader
image, _ = next(iter(data_loader))  # one batch of images; labels are not needed
saliency_maps = sidu.sidu(backbone, image)
```
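The usage above passes a whole batch from a DataLoader, so a single image tensor presumably needs a leading batch dimension first. The README lists single-image support as an upcoming feature. Here is a minimal sketch, assuming sidu.sidu expects input of shape (N, C, H, W). as_batch is a hypothetical helper, not part of pytorch-sidu:

```python
import torch


def as_batch(image: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a 4-D (N, C, H, W) tensor.

    A single (C, H, W) image gets a batch dimension of size 1;
    an already batched (N, C, H, W) tensor is returned unchanged.
    """
    if image.dim() == 3:
        return image.unsqueeze(0)
    if image.dim() == 4:
        return image
    raise ValueError(f"expected a 3-D or 4-D tensor, got shape {tuple(image.shape)}")
```

With this helper, you would explain a single image with sidu.sidu(backbone, as_batch(single_image)) and take element [0] of the result.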
A complete example on CIFAR-10
```python
import torch
import torchvision
from matplotlib import pyplot as plt
import pytorch_sidu as sidu

# Upscale the 32x32 CIFAR-10 images to 256x256 and convert them to tensors
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform),
    batch_size=2,
)

weights = "ResNet18_Weights.IMAGENET1K_V1"
backbone = sidu.load_torch_backbone(weights)

for image, _ in data_loader:
    saliency_maps = sidu.sidu(backbone, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()

    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left: original image
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right: saliency map overlaid on the image
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()

    break  # only visualize the first batch; remove to iterate over the whole dataset
```
Upcoming features:
- integration of XAI metrics
- make methods work on both single images and dataloaders
- adding a device flag to the sidu function to allow device selection
Download files
Source distribution: pytorch-sidu-1.0.13.tar.gz (17.3 kB)
Built distribution: pytorch_sidu-1.0.13-py3-none-any.whl

Hashes for pytorch_sidu-1.0.13-py3-none-any.whl:
| Algorithm | Hash digest |
|---|---|
| SHA256 | fcb15af5e7c29efa39fa058614c81f7e4725e3014cd8ad5cbd2e861831e192f5 |
| MD5 | aa1d77c0711274c716ca036d853a072f |
| BLAKE2b-256 | cefb4f64c4281186d33754c972c22deb2cebb318807cdb520e91bc54ce47889e |