SIDU: SImilarity Difference and Uniqueness method for explainable AI
pytorch-sidu
SIDU: SImilarity Difference and Uniqueness method for explainable AI, as introduced in the original paper.
- PyTorch implementation of the SIDU method.
- Simple interface for loading pretrained models by specifying the string name of their weights (for example, "ResNet18_Weights.IMAGENET1K_V1").
- Clear interface for generating saliency maps.
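For intuition about what a SIDU saliency map contains, here is a minimal NumPy sketch of the computation as we read it from the paper. It is not the code used inside pytorch-sidu. Feature maps from the last convolutional layer are binarized and upsampled into masks, the model is run on each masked image, and each mask is weighted by its similarity difference (a Gaussian kernel on the distance between the original and masked predictions) times its uniqueness (the summed distance to the predictions of all other masked images). The helper names, the `threshold=0.5` and `sigma=0.25` defaults, the nearest-neighbour upsampling (used here in place of bilinear interpolation), and the toy classifier are all illustrative assumptions.

```python
import numpy as np


def feature_masks(feature_maps, image_hw, threshold=0.5):
    """Binarize (K, h, w) feature maps and upsample them to (K, H, W) masks.

    Assumption: nearest-neighbour upsampling via np.kron stands in for
    bilinear interpolation; H and W must be integer multiples of h and w.
    """
    _, h, w = feature_maps.shape
    binary = (feature_maps > threshold).astype(float)
    block = np.ones((image_hw[0] // h, image_hw[1] // w))
    return np.stack([np.kron(m, block) for m in binary])


def sidu_weights(p_org, p_masked, sigma=0.25):
    """Weight of each mask = similarity difference * uniqueness.

    p_org: (C,) class probabilities for the original image.
    p_masked: (K, C) class probabilities for the K masked images.
    sigma: kernel width (illustrative default).
    """
    # Similarity difference: Gaussian kernel on the distance to the original prediction
    dist_to_org = np.linalg.norm(p_masked - p_org, axis=1)
    similarity_diff = np.exp(-dist_to_org**2 / (2 * sigma**2))
    # Uniqueness: summed distance from each masked prediction to all the others
    pairwise = np.linalg.norm(p_masked[:, None, :] - p_masked[None, :, :], axis=-1)
    uniqueness = pairwise.sum(axis=1)
    return similarity_diff * uniqueness


def sidu_saliency(masks, weights):
    """Average of the masks weighted by their SIDU weights, min-max normalized to [0, 1]."""
    sal = np.tensordot(weights, masks, axes=1) / len(masks)
    span = sal.max() - sal.min()
    return (sal - sal.min()) / span if span > 0 else np.zeros_like(sal)


def toy_model(x):
    """Hypothetical 2-class classifier: softmax over top-half and bottom-half pixel sums."""
    half = x.shape[0] // 2
    logits = np.array([x[:half].sum(), x[half:].sum()])
    e = np.exp(logits - logits.max())
    return e / e.sum()


rng = np.random.default_rng(0)
image = rng.random((4, 4))        # toy grayscale "image"
features = rng.random((3, 2, 2))  # toy last-layer feature maps (K=3)
masks = feature_masks(features, image.shape)
p_org = toy_model(image)
p_masked = np.stack([toy_model(image * m) for m in masks])
saliency = sidu_saliency(masks, sidu_weights(p_org, p_masked))
print(saliency.shape)  # (4, 4)
```

Masks whose predictions stay close to the original (high similarity difference) and differ from the other masks (high uniqueness) contribute most to the final map.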
Some examples generated with VGG19 on the Caltech-101 dataset:
Installation
```bash
pip install pytorch-sidu
```
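To check that the installation worked, you can query the installed distribution version with the standard library. This is a small sketch; the helper name `installed_version` is our own and is not part of pytorch-sidu.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "pytorch-sidu") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version())  # a version string such as "1.0.0", or None
```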
Usage
Load a pretrained model from those available in PyTorch (torchvision) by passing the name of its weights:
```python
import pytorch_sidu as sidu

weights = "ResNet18_Weights.IMAGENET1K_V1"
backbone = sidu.load_torch_backbone(weights)
```
After instantiating your model, generate saliency maps for a batch of images from a DataLoader:
```python
data_loader = ...  # replace with your own torch.utils.data.DataLoader
image, _ = next(iter(data_loader))
saliency_maps = sidu.sidu(backbone, image)
```
A complete example on CIFAR-10
```python
import torch
import torchvision
from matplotlib import pyplot as plt

import pytorch_sidu as sidu

# Upscale the 32x32 CIFAR-10 images to 256x256 and convert them to tensors
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform),
    batch_size=2,
)

# ImageNet-pretrained ResNet-18 backbone
weights = "ResNet18_Weights.IMAGENET1K_V1"
backbone = sidu.load_torch_backbone(weights)

for image, _ in data_loader:
    saliency_maps = sidu.sidu(backbone, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()

    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left: the input image, permuted from (C, H, W) to (H, W, C) for matplotlib
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right: the saliency map overlaid on the input image
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()

    break  # only visualize the first batch; remove to iterate over the whole dataset
```
License
This project is licensed under the GNU General Public License 3.0. For more details, see the LICENSE file in the root directory of this project or the full text of the GNU General Public License 3.0.
Download files
Source Distribution
pytorch-sidu-1.0.0.tar.gz (16.7 kB)
Built Distribution
Hashes for pytorch_sidu-1.0.0-py3-none-any.whl:

Algorithm | Hash digest
---|---
SHA256 | 5a08041a272be8225431604a856f1ac20baca2e56aab0f1a518135e39186cbb1
MD5 | 94527483e11e17a725cedf75b2c420be
BLAKE2b-256 | 147fef5fd7f059440138c610b7050405b9d97b1213552397e290aec3278777ae