SIDU: SImilarity Difference and Uniqueness method for explainable AI
pytorch-sidu
SIDU: SImilarity Difference and Uniqueness method for explainable AI from the original paper
- PyTorch implementation of the SIDU method.
- Simple interface for loading pretrained models by specifying their weights name as a string
- Clear interface for generating saliency maps
Some examples made with VGG19 on Caltech-101 dataset:
Installation
pip install pytorch-sidu
Usage
Load a model from the pretrained ones available in PyTorch
import pytorch_sidu as sidu
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = sidu.load_torch_model_by_string(model_name)
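The model name above follows torchvision's weights naming, `<Architecture>_Weights.<TAG>`. As a rough illustration of how such a string splits into an architecture name and a weights tag, here is a minimal sketch. `split_weights_name` is a hypothetical helper and not part of pytorch-sidu, and the package may resolve names differently. Lowercasing the architecture to get a builder name is also an assumption that holds for names like `ResNet34` but is not guaranteed in general.

```python
def split_weights_name(name: str) -> tuple[str, str]:
    """Split e.g. 'ResNet34_Weights.IMAGENET1K_V1' into ('resnet34', 'IMAGENET1K_V1').

    Hypothetical helper for illustration only; not the pytorch-sidu implementation.
    """
    enum_name, sep, tag = name.partition(".")
    suffix = "_Weights"
    # Reject strings that do not look like '<Arch>_Weights.<TAG>'.
    if not sep or not tag or not enum_name.endswith(suffix) or len(enum_name) == len(suffix):
        raise ValueError(f"expected '<Arch>_Weights.<TAG>', got {name!r}")
    # Assumption: the lowercased architecture matches the model builder name.
    arch = enum_name[: -len(suffix)].lower()
    return arch, tag


arch, tag = split_weights_name("ResNet34_Weights.IMAGENET1K_V1")
print(arch, tag)
```

A malformed string such as `'resnet34'` raises `ValueError` in this sketch instead of silently picking a default model.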
After instantiating your model, generate saliency maps from a DataLoader
data_loader = <your dataloader>
target_layer = 'layer4.2.conv2'
image, _ = next(iter(data_loader))
saliency_maps = sidu.sidu(model, target_layer, image)
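The `target_layer` string is a dotted path into the model. For torchvision's ResNet34, `'layer4.2.conv2'` reads as: attribute `layer4`, then child `2`, then attribute `conv2`. The sketch below walks such a path over plain Python objects to make that reading concrete. `resolve_dotted` and the stand-in objects are hypothetical and only for illustration. On a real model, PyTorch's own `model.get_submodule(target_layer)` performs this lookup, and how pytorch-sidu resolves the name internally is not shown here.

```python
from types import SimpleNamespace


def resolve_dotted(root, path: str):
    """Follow a dotted path such as 'layer4.2.conv2' starting from root."""
    obj = root
    for part in path.split("."):
        # Numeric parts index into a sequence of children; other parts are attributes.
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj


# Stand-in for a model: layer4 holds three blocks, each with conv1 and conv2.
blocks = [SimpleNamespace(conv1=f"block{i}.conv1", conv2=f"block{i}.conv2") for i in range(3)]
toy_model = SimpleNamespace(layer4=blocks)

print(resolve_dotted(toy_model, "layer4.2.conv2"))  # prints "block2.conv2"
```

To pick a layer name for a different architecture, listing `model.named_modules()` on the loaded model shows the dotted names that are available.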
A complete example on CIFAR-10
import torch
import torchvision
from matplotlib import pyplot as plt
import pytorch_sidu as sidu
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224, 224)), torchvision.transforms.ToTensor()])
data_loader = torch.utils.data.DataLoader(
torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform), batch_size=2)
target_layer = 'layer4.2.conv2'
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = sidu.load_torch_model_by_string(model_name)
for image, _ in data_loader:
    # One saliency map per image in the batch
    saliency_maps = sidu.sidu(model, target_layer, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()
    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left panel: the input image (CHW -> HWC for matplotlib)
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right panel: the same image with the saliency map overlaid
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()
upcoming features:
- integration of xai metrics
- make methods work on both single images and dataloaders
- adding a device flag to the sidu function to allow device selection
File details
Details for the file pytorch-sidu-1.1.3.tar.gz.
File metadata
- Download URL: pytorch-sidu-1.1.3.tar.gz
- Upload date:
- Size: 17.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 23aef4fb24c5ee81cba4723feb364ac915de869245326c37c6551ce743d78a93 |
| MD5 | 3f38e2fa159da0cee673a57359da34b7 |
| BLAKE2b-256 | dd960783223825d6fc019a91bf0ce6be7785cf1b0487cb2e33e4afbada947874 |
File details
Details for the file pytorch_sidu-1.1.3-py3-none-any.whl.
File metadata
- Download URL: pytorch_sidu-1.1.3-py3-none-any.whl
- Upload date:
- Size: 17.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f1c16242b80d4423c8cbbe77561c23c11730bf1ffa8a421c986387cc7e454a52 |
| MD5 | fc07938465dde53989e8798006318fa6 |
| BLAKE2b-256 | 35eae0a155f4dc484bb4c91d0a5a032551cf34b3839f9a0e67f09d5e20de126b |