
A pip package for XAI Inferencing

Project description

License: MIT

XAI Inference Engine

This package wraps a trained CNN-based PyTorch model so that it can perform explainable AI (XAI) inference. Given a preprocessed input, the wrapped model returns its predictions, the sorted prediction indices and saliency maps. Saliency maps are generated with the FM-G-CAM method, and support for more saliency map methods is planned. The package can also superimpose the saliency maps on the input image.
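To illustrate what superimposing a saliency map involves, here is a minimal sketch. It assumes the saliency map is a 2-D array of non-negative scores and the input is a PIL image. The helper name `overlay_saliency`, the red-channel coloring and the `alpha` default are illustrative choices, not the package's own implementation.

```python
import numpy as np
from PIL import Image


def overlay_saliency(img: Image.Image, saliency: np.ndarray, alpha: float = 0.5) -> Image.Image:
    """Blend a 2-D saliency map onto an image (hypothetical helper, not the package API)."""
    sal = saliency.astype(np.float32)
    # Min-max normalize the scores to [0, 1]; a constant map stays all zeros
    sal = sal - sal.min()
    peak = sal.max()
    if peak > 0:
        sal = sal / peak
    # 8-bit grayscale heatmap, resized to the input image's (width, height)
    heat = Image.fromarray((sal * 255).astype(np.uint8))
    heat = heat.resize(img.size, resample=Image.BILINEAR)
    # Put the heat values in the red channel only (illustrative coloring)
    zeros = Image.new("L", img.size, 0)
    heat_rgb = Image.merge("RGB", (heat, zeros, zeros))
    # Linear blend: (1 - alpha) * image + alpha * heatmap
    return Image.blend(img.convert("RGB"), heat_rgb, alpha)
```

A real implementation would more likely apply a colormap such as "jet" instead of a single channel. The blending step stays the same either way.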

Users can also use the package to create their own inference engine by extending the XAIInferenceEngine class.

Advanced Tutorials: Coming Soon...

GitHub | PyPI

Requirements

  • Python 3.8+
  • PyTorch 2.0+

Installation

Execute the following command in your terminal to install the package.

pip install xai-inference-engine

Usage

Follow the example below to get started: copy the code into a Python script and run it. Besides the requirements above, the example uses torchvision, requests and Pillow, so make sure they are installed. 😊

print("[INFO]: Testing XAIInferenceEngine...")

print("[INFO]: Importing Libraries...")
import requests
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
from xai_inference_engine import XAIInferenceEngine

# Use the first GPU if one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[INFO]: Device: {}".format(device))

print("[INFO]: Loading Model...")
# Model: ResNet-50 with ImageNet weights
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights).to(device)
# Preprocessing transform that matches the chosen weights
preprocess = weights.transforms()

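# Model config
# Set model to eval mode (batch norm then uses its running statistics)
model.eval()
# Target layer for the saliency maps: the last convolution
# of the final ResNet-50 bottleneck block
last_conv_layer = model.layer4[2].conv3
class_count = 5
class_list = weights.meta["categories"]
img_h = 224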

print("[INFO]: Image Preprocessing...")
# Download a sample image showing a cat and a dog
url = "https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-visualizations/master/input_images/cat_dog.png"
r = requests.get(url, allow_redirects=True)
r.raise_for_status()
with open("cat_dog.png", "wb") as f:
    f.write(r.content)
# Convert to 3-channel RGB, as the ImageNet transforms expect
img = Image.open("cat_dog.png").convert("RGB")
img = img.resize((img_h, img_h), resample=Image.BICUBIC)
img_tensor = preprocess(img).to(device)


print("[INFO]: Creating XAIInferenceEngine...")
xai_inferencer = XAIInferenceEngine(
    model=model,
    last_conv_layer=last_conv_layer,
    device=device,
)

print("[INFO]: Running XAIInferenceEngine.predict()...")
preds, sorted_pred_indices, super_imp_img, saliency_maps = xai_inferencer.predict(
    img=img,
    img_tensor=img_tensor,
)

print("[INFO]: Saving Results to the current working directory...")
super_imp_img.save("super_imp_img.jpg")
saliency_maps.save("saliency_maps.jpg")

print("[INFO]: Displaying Results...")
print("        Predictions shape: {}".format(preds.shape))
print("        Sorted Prediction Indices (first 10): {}".format(sorted_pred_indices.cpu().numpy()[:10]))
print("        Saliency Maps: {}".format(saliency_maps))
print("        Superimposed Image: {}".format(super_imp_img))
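The sorted indices are positions in the ImageNet category list (`class_list` in the example). The following minimal sketch turns raw class scores into a ranked list of (label, probability) pairs. It assumes `preds` holds raw logits for a single image, which is an assumption because the package's exact output format is not documented here. `top_k_labels` is a hypothetical helper and is not part of the package.

```python
from typing import List, Sequence, Tuple

import torch


def top_k_labels(logits: torch.Tensor, categories: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k most likely (label, probability) pairs for ONE image (hypothetical helper)."""
    # Flatten a single-image output, e.g. shape (1, 1000) -> (1000,)
    scores = logits.detach().reshape(-1)
    # Softmax turns raw scores into probabilities that sum to 1
    probs = torch.softmax(scores, dim=0)
    # Class indices ordered by descending probability
    order = torch.argsort(probs, descending=True)
    return [(categories[i], probs[i].item()) for i in order[:k].tolist()]
```

With the variables from the usage example, `top_k_labels(preds, class_list, k=class_count)` would list the five most likely ImageNet categories, provided `preds` really contains raw scores for one image. If it already holds probabilities, drop the softmax.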

Results

The following image compares the saliency maps generated by FM-G-CAM with those generated by Grad-CAM.

[Figure: FM-G-CAM comparison with Grad-CAM]
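For reference, the Grad-CAM baseline in this comparison can be sketched as follows. This is plain single-class Grad-CAM, written from its standard formulation. It is not the package's FM-G-CAM code. `grad_cam` is a hypothetical helper name. It hooks a target layer, which plays the same role as `last_conv_layer` in the usage example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model: nn.Module, layer: nn.Module, x: torch.Tensor, class_idx: int) -> torch.Tensor:
    """Standard Grad-CAM for one image x of shape (1, C, H, W); returns an (H, W) map in [0, 1]."""
    store = {}

    def fwd_hook(module, inputs, output):
        # Keep the layer's activations and capture their gradient during backward
        store["act"] = output
        output.register_hook(lambda grad: store.__setitem__("grad", grad))

    handle = layer.register_forward_hook(fwd_hook)
    try:
        model.zero_grad()
        scores = model(x)
        # Backpropagate the score of the requested class only
        scores[0, class_idx].backward()
    finally:
        handle.remove()

    act, grad = store["act"].detach(), store["grad"]  # both (1, K, h, w)
    # Channel weights: spatially averaged gradients
    weights = grad.mean(dim=(2, 3), keepdim=True)
    # Weighted sum of activation maps, keeping only positive evidence
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))  # (1, 1, h, w)
    # Upsample to the input resolution and min-max normalize
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)[0, 0]
    cam = cam - cam.min()
    return cam / cam.max().clamp(min=1e-8)
```

With the ResNet-50 from the usage example, a call would look like `grad_cam(model, model.layer4[2].conv3, img_tensor.unsqueeze(0), class_idx)`. This yields one map per class. FM-G-CAM is the package's alternative to this baseline.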

Author



Download files


Source Distribution

xai_inference_engine-0.1.4.tar.gz (9.3 kB)


Built Distribution


xai_inference_engine-0.1.4-py3-none-any.whl (8.4 kB)


File details

Details for the file xai_inference_engine-0.1.4.tar.gz.

File metadata

  • Download URL: xai_inference_engine-0.1.4.tar.gz
  • Upload date:
  • Size: 9.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.5

File hashes

Hashes for xai_inference_engine-0.1.4.tar.gz
  • SHA256: ffdab73bb7d5767fe919f67079abce2789cfb8b4049a71717ade1e909795db71
  • MD5: 0b8a4f025a0529790f0be25f56ea9fbb
  • BLAKE2b-256: 3e663942fb87b4e0464b44435b77fd3f55bbc4a15bdd893233b349b0d0ae5d98


File details

Details for the file xai_inference_engine-0.1.4-py3-none-any.whl.

File metadata

File hashes

Hashes for xai_inference_engine-0.1.4-py3-none-any.whl
  • SHA256: 85910a8e32be74018800cc149e00d7c0a782cfd680367c2be50f2ebc90688cca
  • MD5: 1dc840b577648c0d5d052e73f86166b2
  • BLAKE2b-256: 17285c73b428cdb556ccd84ca7c6e434601b8a2852ae9759a57bbd667cf3c719

