SIDU: SImilarity Difference and Uniqueness method for explainable AI
pytorch-sidu
SIDU: SImilarity Difference and Uniqueness method for explainable AI from the original paper
- PyTorch implementation of the SIDU method.
- Simple interface for loading pretrained models by specifying their name as a string (e.g. 'ResNet34_Weights.IMAGENET1K_V1')
- Clear interface for generating saliency maps
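For intuition about what the saliency maps represent, here is a simplified sketch of the SIDU idea as described in the paper. It is not the code shipped in pytorch-sidu, and its normalization details may differ from both the paper and the package. The activations of a target layer are thresholded into binary masks and upsampled to the input size. The model is then run on the image multiplied by each mask, and each mask is weighted by a similarity-difference term (high when its prediction stays close to the original prediction) times a uniqueness term (high when its prediction differs from those of the other masks). The saliency map is the weighted average of the masks. The name `sidu_sketch` is hypothetical, and the defaults `threshold=0.5` and `sigma=0.25` are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


@torch.no_grad()
def sidu_sketch(model: nn.Module, layer: nn.Module, image: torch.Tensor,
                threshold: float = 0.5, sigma: float = 0.25) -> torch.Tensor:
    """Simplified SIDU-style saliency map for one image of shape (1, C, H, W)."""
    model.eval()
    captured = {}

    def hook(module, inputs, output):
        captured["acts"] = output  # activations of the target layer

    handle = layer.register_forward_hook(hook)
    p_orig = F.softmax(model(image), dim=1)            # (1, num_classes)
    handle.remove()

    acts = captured["acts"][0]                         # (K, h, w)
    # 1. Binarize each feature map and upsample it to the input resolution.
    masks = (acts > threshold).float().unsqueeze(1)    # (K, 1, h, w)
    masks = F.interpolate(masks, size=image.shape[-2:], mode="bilinear",
                          align_corners=False)          # (K, 1, H, W)

    # 2. Class probabilities for each masked copy of the image (batch of K).
    p_masked = F.softmax(model(image * masks), dim=1)  # (K, num_classes)

    # 3. Similarity difference: high when a mask preserves the original prediction.
    sd = torch.exp(-torch.linalg.norm(p_masked - p_orig, dim=1) / (2 * sigma ** 2))

    # 4. Uniqueness: high when a mask's prediction differs from the other masks'.
    uniq = torch.cdist(p_masked, p_masked).sum(dim=1)

    # 5. Saliency: average of the masks weighted by SD * uniqueness.
    weights = sd * uniq                                 # (K,)
    saliency = (weights.view(-1, 1, 1, 1) * masks).mean(dim=0)
    return saliency[0]                                  # (H, W)


# Usage on a tiny random CNN (a hypothetical stand-in for a pretrained model).
torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
sal = sidu_sketch(net, net[1], torch.rand(1, 3, 32, 32))
print(sal.shape)  # torch.Size([32, 32])
```

Every masked copy goes through the model in a single batch. With 512 channels in the target layer, this means 512 forward passes per image, so the cost grows with the number of channels.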
Some examples made with VGG19 on the Caltech-101 dataset:
Installation
pip install pytorch-sidu
Usage
Load one of the pretrained models available in PyTorch (torchvision):
import pytorch_sidu as sidu
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = sidu.load_torch_model_by_string(model_name)
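After instantiating your model, generate saliency maps for a batch of images from a DataLoader:
data_loader = ...  # replace with your own torch.utils.data.DataLoader
target_layer = 'layer4.2.conv2'  # dotted name of the target layer (last conv layer of ResNet34)
image, _ = next(iter(data_loader))  # take the first batch of images
saliency_maps = sidu.sidu(model, target_layer, image)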
A complete example on CIFAR-10
import torch
import torchvision
from matplotlib import pyplot as plt
import pytorch_sidu as sidu
# Resize CIFAR-10 images to the 224x224 input size used by ImageNet models
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform),
    batch_size=2)
target_layer = 'layer4.2.conv2'  # last convolutional layer of ResNet34
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = sidu.load_torch_model_by_string(model_name)
for image, _ in data_loader:
    saliency_maps = sidu.sidu(model, target_layer, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()
    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left: original image, (C, H, W) -> (H, W, C) for matplotlib
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right: saliency map overlaid on the image
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()
    break  # only the first batch; remove to plot every image in the dataset
Upcoming features:
- integration of XAI metrics
- make methods work on both single images and dataloaders
- adding a device flag to the sidu function to allow device selection
Source Distribution
pytorch-sidu-1.1.3.tar.gz (17.4 kB)
Built Distribution
Hashes for pytorch_sidu-1.1.3-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | f1c16242b80d4423c8cbbe77561c23c11730bf1ffa8a421c986387cc7e454a52
MD5 | fc07938465dde53989e8798006318fa6
BLAKE2b-256 | 35eae0a155f4dc484bb4c91d0a5a032551cf34b3839f9a0e67f09d5e20de126b