
SIDU: SImilarity Difference and Uniqueness method for explainable AI


pytorch-sidu


SIDU: SImilarity Difference and Uniqueness method for explainable AI, as described in the original paper.

  • PyTorch implementation of the SIDU method.
  • Simple interface for loading pretrained torchvision models by specifying their weights name as a string (e.g. 'ResNet34_Weights.IMAGENET1K_V1')
  • Clear interface for generating saliency maps
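The two quantities in the method's name can be illustrated with a small, dependency-free sketch. As we read the paper, SIDU binarises each activation map of the target layer into a mask, upsamples it to the input size and runs the model on the masked image. Each mask is then weighted by its similarity difference SD_i = exp(-‖p_org − p_i‖ / (2σ²)), which is high when masking preserves the original prediction, times its uniqueness U_i = Σ_j ‖p_i − p_j‖, which is high when that masked prediction differs from the others; the saliency map is the mean of the weighted masks. This sketch is not the pytorch-sidu implementation: it skips the upsampling and the model call (prediction vectors are given directly), and the threshold of 0.5 and σ = 0.25 are assumed defaults.

```python
import math


def binary_mask(activation_map, threshold=0.5):
    """Binarise one activation map (list of rows): 1.0 where value > threshold.

    The threshold is an assumed default for this sketch.
    """
    return [[1.0 if v > threshold else 0.0 for v in row] for row in activation_map]


def l2(a, b):
    """Euclidean distance between two prediction vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def sidu_weights(p_org, p_masked, sigma=0.25):
    """Weight W_i = SD_i * U_i for each masked prediction p_i (sigma is assumed)."""
    weights = []
    for p_i in p_masked:
        # Similarity difference: close to 1 when masking barely changes the prediction
        sd = math.exp(-l2(p_org, p_i) / (2 * sigma ** 2))
        # Uniqueness: summed distance to every masked prediction (self-distance is 0)
        u = sum(l2(p_i, p_j) for p_j in p_masked)
        weights.append(sd * u)
    return weights


def sidu_saliency(masks, weights):
    """S = (1/N) * sum_i W_i * M_i, computed pixel by pixel."""
    n = len(masks)
    rows, cols = len(masks[0]), len(masks[0][0])
    return [
        [sum(w * m[r][c] for w, m in zip(weights, masks)) / n for c in range(cols)]
        for r in range(rows)
    ]


# Toy usage: three 2x2 activation maps and made-up two-class predictions
activations = [
    [[0.9, 0.1], [0.2, 0.3]],
    [[0.1, 0.8], [0.0, 0.2]],
    [[0.2, 0.1], [0.7, 0.4]],
]
masks = [binary_mask(a) for a in activations]
p_org = [0.9, 0.1]                                 # prediction on the full image
p_masked = [[0.85, 0.15], [0.2, 0.8], [0.5, 0.5]]  # predictions on the masked images
weights = sidu_weights(p_org, p_masked)
saliency = sidu_saliency(masks, weights)
```

Here the first mask keeps the prediction almost unchanged (high SD) and is far from the second masked prediction (high U), so the pixel it covers receives the highest saliency.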

Some examples made with VGG19 on the Caltech-101 dataset:

[Images: img1, img7, img9]

Installation

pip install pytorch-sidu

Usage

Load one of the pretrained models available in PyTorch (torchvision):

import pytorch_sidu as sidu

model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = sidu.load_torch_model_by_string(model_name)
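The model name is the name of a torchvision weights enum member (`<Model>_Weights.<VARIANT>`, from torchvision's multi-weight API). How pytorch_sidu resolves this string internally is not documented here; as a sketch, the hypothetical helper `parse_weights_name` below splits such a name into a model-builder name and a weights variant. torchvision itself provides `torchvision.models.get_weight`, which accepts the full string, and `torchvision.models.get_model` (both available from torchvision 0.14); the commented line assumes that API.

```python
def parse_weights_name(weights_name):
    """Split a torchvision weights string into (builder name, weights variant).

    Hypothetical helper, e.g. 'ResNet34_Weights.IMAGENET1K_V1'
    -> ('resnet34', 'IMAGENET1K_V1').
    """
    enum_name, _, variant = weights_name.partition(".")
    suffix = "_Weights"
    if not enum_name.endswith(suffix) or not variant:
        raise ValueError(f"expected '<Model>_Weights.<VARIANT>', got {weights_name!r}")
    # For most torchvision models the builder function is the lower-cased model part
    return enum_name[: -len(suffix)].lower(), variant


builder, variant = parse_weights_name("ResNet34_Weights.IMAGENET1K_V1")
# With torchvision >= 0.14 installed, the pair could then be used as:
#   torchvision.models.get_model(builder, weights=variant)
```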

After instantiating your model, generate saliency maps for a batch from a DataLoader:

data_loader = ...  # replace with your torch.utils.data.DataLoader
target_layer = 'layer4.2.conv2'
image, _ = next(iter(data_loader))
saliency_maps = sidu.sidu(model, target_layer, image)
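The target layer is given as a dotted module path. In PyTorch, `model.get_submodule('layer4.2.conv2')` resolves such a path and `model.named_modules()` lists every valid name, so printing those names is a quick way to pick a layer for other architectures. For ResNet34, `layer4` contains three residual blocks, so `layer4.2.conv2` is the second convolution of the last block, i.e. the final convolutional layer. The hypothetical helper `resolve_layer` below only illustrates the resolution step in plain Python, using stand-in objects instead of a real model; it is not part of pytorch_sidu.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_layer(model, dotted_path):
    """Follow a dotted path such as 'layer4.2.conv2' one attribute at a time.

    Hypothetical helper: every component, including numeric ones like '2',
    is looked up with getattr, since PyTorch registers the children of
    nn.Sequential under string names ('0', '1', ...).
    """
    return reduce(getattr, dotted_path.split("."), model)


# Stand-in objects shaped like ResNet34's layer4: three blocks, each with conv1/conv2
layer4 = SimpleNamespace()
for i in range(3):
    setattr(layer4, str(i), SimpleNamespace(conv1=f"block{i}.conv1", conv2=f"block{i}.conv2"))
model = SimpleNamespace(layer4=layer4)

target = resolve_layer(model, "layer4.2.conv2")
```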

A complete example on CIFAR-10

import torch
import torchvision
from matplotlib import pyplot as plt
import pytorch_sidu as sidu


# Resize CIFAR-10's 32x32 images to the 224x224 input size of the ImageNet models
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224, 224)), torchvision.transforms.ToTensor()])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform), batch_size=2)

# Last convolution of the last residual block in ResNet34
target_layer = 'layer4.2.conv2'
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = sidu.load_torch_model_by_string(model_name)

for image, _ in data_loader:
    saliency_maps = sidu.sidu(model, target_layer, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()

    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left: original image (CHW -> HWC for matplotlib)
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right: saliency map overlaid on the image
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()
    break  # only visualise the first batch; remove to iterate over the whole dataset

Upcoming features:

  • Integration of XAI metrics
  • Make methods work on both single images and DataLoaders
  • Add a device flag to the sidu function to allow device selection



Download files


Source Distribution

pytorch-sidu-1.1.3.tar.gz (17.4 kB)

Built Distribution

pytorch_sidu-1.1.3-py3-none-any.whl (17.7 kB)
