SIDU: SImilarity Difference and Uniqueness method for explainable AI
pytorch-sidu
SIDU: SImilarity Difference and Uniqueness method for explainable AI from the original paper
- PyTorch implementation of the SIDU method.
- Simple interface for loading pretrained models by specifying a weights string name (for example "ResNet18_Weights.IMAGENET1K_V1")
- Clear interface for generating saliency maps
Some examples made with VGG19 on the Caltech-101 dataset:
Installation
pip install pytorch-sidu
Usage
Load a model from the pretrained ones available in PyTorch
import pytorch_sidu as sidu
weights = "ResNet18_Weights.IMAGENET1K_V1"
backbone = sidu.load_torch_backbone(weights)
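The weights string above follows the torchvision naming pattern "<Enum>.<MEMBER>". The sketch below only illustrates how such a string decomposes; it is not taken from pytorch-sidu, and `split_weights_name` and `guess_builder_name` are hypothetical helpers. The builder-name rule (drop the "_Weights" suffix and lowercase) is an assumption that holds for names such as ResNet18_Weights but is not guaranteed for every model.

```python
from typing import Tuple


def split_weights_name(weights: str) -> Tuple[str, str]:
    """Split "ResNet18_Weights.IMAGENET1K_V1" into (enum name, member name)."""
    enum_name, sep, member = weights.partition(".")
    if not sep or not enum_name or not member:
        raise ValueError(f"expected '<Enum>.<MEMBER>', got {weights!r}")
    return enum_name, member


def guess_builder_name(enum_name: str) -> str:
    """Guess the torchvision builder name, e.g. ResNet18_Weights -> resnet18.

    Assumption: the builder is the enum name without its "_Weights" suffix,
    lowercased. This is a naming convention, not an API guarantee.
    """
    suffix = "_Weights"
    if not enum_name.endswith(suffix):
        raise ValueError(f"{enum_name!r} does not end with {suffix!r}")
    return enum_name[: -len(suffix)].lower()


enum_name, member = split_weights_name("ResNet18_Weights.IMAGENET1K_V1")
print(enum_name, member, guess_builder_name(enum_name))
```

Recent torchvision releases also provide `torchvision.models.get_weight`, which resolves a string of this form to the corresponding weights enum member.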
After instantiating your model, generate saliency maps from a DataLoader
data_loader = <your dataloader>
image, _ = next(iter(data_loader))
saliency_maps = sidu.sidu(backbone, image)
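To try the snippet above without downloading a dataset, the `<your dataloader>` placeholder can be filled with a small synthetic loader. This is a minimal sketch, assuming the backbone accepts float batches shaped (N, 3, H, W); the 224x224 size and the dummy labels are illustrative choices, not requirements stated by the package.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Four random RGB "images" with values in [0, 1) and placeholder labels.
images = torch.rand(4, 3, 224, 224)
labels = torch.zeros(4, dtype=torch.long)

# Each batch is an (image, label) pair, matching the unpacking used above.
data_loader = DataLoader(TensorDataset(images, labels), batch_size=2)

image, _ = next(iter(data_loader))
print(image.shape)
```

Random inputs only check that the pipeline runs end to end; the resulting saliency maps carry no meaning.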
A complete example on CIFAR-10
import torch
import torchvision
from matplotlib import pyplot as plt
import pytorch_sidu as sidu
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((256, 256)), torchvision.transforms.ToTensor()])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform), batch_size=2)
weights = "ResNet18_Weights.IMAGENET1K_V1"
backbone = sidu.load_torch_backbone(weights)
for image, _ in data_loader:
    # One saliency map per image in the batch
    saliency_maps = sidu.sidu(backbone, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()
    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left panel: the input image (CHW -> HWC for matplotlib)
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right panel: the same image with the saliency map overlaid
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()