
A pip package for explainable AI (XAI) inference

Project description

License: MIT

XAI Inference Engine

This package is a wrapper for running inference with CNN-based PyTorch models. Given a trained PyTorch CNN model and an input image, it returns the predictions, the sorted prediction indices, and saliency maps. Saliency maps are generated with the FM-G-CAM method; more saliency map generation methods will be added in the future. The package also provides a method to superimpose the saliency maps on the input image.
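The package's own superimposition code is not reproduced here. As a rough illustration of what superimposing a saliency map means, the NumPy sketch below alpha-blends a single-channel saliency map onto an RGB image, painting salient pixels more strongly. The function name `overlay_saliency`, the red overlay colour, and the `alpha` default are illustrative assumptions, not the package's API.

```python
import numpy as np


def overlay_saliency(image: np.ndarray, saliency: np.ndarray,
                     color=(255, 0, 0), alpha: float = 0.5) -> np.ndarray:
    """Blend a saliency map onto an RGB image (illustrative helper, not the package API).

    image:    uint8 array of shape (H, W, 3)
    saliency: float array of shape (H, W); any range, rescaled to [0, 1]
    color:    RGB colour used to paint salient regions (assumed choice)
    alpha:    maximum opacity of the overlay
    """
    if image.shape[:2] != saliency.shape:
        raise ValueError("image and saliency map must share height and width")
    # Min-max normalise the saliency map to [0, 1]; a constant map becomes all zeros
    s = saliency.astype(np.float64)
    span = s.max() - s.min()
    s = (s - s.min()) / span if span > 0 else np.zeros_like(s)
    # Per-pixel opacity: stronger saliency means more of the overlay colour
    weight = (alpha * s)[..., None]
    overlay = np.array(color, dtype=np.float64)
    blended = (1.0 - weight) * image.astype(np.float64) + weight * overlay
    # Round back to valid 8-bit pixel values
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
```

The result can be turned into a PIL image with `Image.fromarray(...)` for saving or display.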

Users can also build their own inference engine by extending the XAIInferenceEngine class.

Advanced Tutorials: Coming Soon...


Requirements

  • Python 3.8+
  • PyTorch 2.0+

Installation

Run the following command in your terminal to install the package:

pip install xai-inference-engine

Usage

Follow the example below to use the package: copy the code into a Python script and run it. Make sure the requirements are installed; the example also uses requests and Pillow (PIL). 😊

print("[INFO]: Testing XAIInferenceEngine...")

print("[INFO]: Importing Libraries...")
import requests
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

from xai_inference_engine import XAIInferenceEngine

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[INFO]: Device: {}".format(device))

print("[INFO]: Loading Model...")
# Model
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Model config
# Set model to eval mode
model.eval()
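# Last convolutional layer of the final bottleneck block; the engine computes saliency maps from this layer
last_conv_layer = model.layer4[2].conv3
# Number of top classes and the ImageNet class names (not passed to the engine in this example)
class_count = 5
class_list = weights.meta["categories"]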
img_h = 224

print("[INFO]: Image Preprocessing...")
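# Image Preprocessing: download a sample image that contains a cat and a dog
url = "https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-visualizations/master/input_images/cat_dog.png"
r = requests.get(url, allow_redirects=True)
r.raise_for_status()
# The downloaded file is a PNG, so save it under a .png name and close it properly
with open("cat_dog.png", "wb") as f:
    f.write(r.content)
# Ensure a 3-channel RGB image (PNG files may carry an alpha channel)
img = Image.open("cat_dog.png").convert("RGB")
img = img.resize((img_h, img_h), resample=Image.BICUBIC)
img_tensor = preprocess(img).to(device)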


print("[INFO]: Creating XAIInferenceEngine...")
xai_inferencer = XAIInferenceEngine(
    model=model,
    last_conv_layer=last_conv_layer,
    device=device,
)

print("[INFO]: Running XAIInferenceEngine.predict()...")
preds, sorted_pred_indices, super_imp_img, saliency_maps = xai_inferencer.predict(
    img=img,
    img_tensor=img_tensor,
)

print("[INFO]: Saving Results to the current working directory...")
super_imp_img.save("super_imp_img.jpg")
saliency_maps.save("saliency_maps.jpg")

print("[INFO]: Displaying Results...")
print("        Predictions shape: {}".format(preds.shape))
print("        Sorted Prediction Indices (top 10): {}".format(sorted_pred_indices.cpu().numpy()[:10]))
print("        Saliency Maps: {}".format(saliency_maps))
print("        Superimposed Image: {}".format(super_imp_img))
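The example defines `class_list` and `class_count` but does not use them. If you want readable labels, one option is to apply a softmax to one image's raw class scores and keep the `class_count` highest-scoring classes. The plain-Python sketch below does this; it assumes you have already converted `preds` for a single image into a flat list of floats (for example with `preds.detach().cpu().flatten().tolist()`, assuming `preds` holds raw logits for one image; check the printed shape first). The helper name `top_k_labels` is hypothetical.

```python
import math


def top_k_labels(logits, class_names, k=5):
    """Return the k most likely (class_name, probability) pairs (hypothetical helper).

    logits:      raw model scores for one image, one per class
    class_names: labels in the same order as logits
    """
    # Numerically stable softmax: subtract the max score before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Class indices sorted by descending probability
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(class_names[i], probs[i]) for i in order[:k]]
```

A call such as `top_k_labels(logits, class_list, k=class_count)` would list the five most likely ImageNet classes. If `preds` already contains probabilities, the ranking is the same but you should skip the softmax to keep the original values.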

Results

The following image compares the saliency maps generated by the FM-G-CAM method with those generated by the Grad-CAM method.

[Figure: FM-G-CAM comparison with Grad-CAM]
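For reference, Grad-CAM (the baseline in this comparison) weights each feature map of the chosen convolutional layer by the spatially averaged gradient of the target class score, sums the weighted maps, and keeps only the positive part. The NumPy sketch below computes that map from activations and gradients you have already captured (for instance with PyTorch hooks on `last_conv_layer`). It is plain single-class Grad-CAM, not this package's FM-G-CAM implementation; the function name `grad_cam` is illustrative, and upsampling the map to the input resolution is left out.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Compute a Grad-CAM map for one image and one target class.

    activations: feature maps A of the chosen conv layer, shape (K, h, w)
    gradients:   d(score_c)/dA for the same layer, shape (K, h, w)
    Returns an (h, w) map scaled to [0, 1].
    """
    # Channel weights: global-average-pool the gradients over the spatial dims
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of the K feature maps, followed by ReLU
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (h, w)
    # Scale to [0, 1] for display; an all-zero map stays zero
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```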

Download files

Download the file for your platform.

Source Distribution

xai_inference_engine-0.1.3.tar.gz (9.2 kB)

Uploaded Source

Built Distribution

xai_inference_engine-0.1.3-py3-none-any.whl (8.4 kB)

Uploaded Python 3

File details

Details for the file xai_inference_engine-0.1.3.tar.gz.

File metadata

  • Download URL: xai_inference_engine-0.1.3.tar.gz
  • Size: 9.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.5

File hashes

Hashes for xai_inference_engine-0.1.3.tar.gz
Algorithm Hash digest
SHA256 cd5b7e1424e2c2108dff8836a8eed26723cf30ed3132bf81ae9cbf3e0aca0523
MD5 abbb3a3a3d8ec04453493c38bb8c1603
BLAKE2b-256 231fed12a84b5d21dcb0459079a86d4c28be4f5d4d316e77634db1be7df05afa
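To check a downloaded file against the digests listed here, you can hash it locally with Python's standard `hashlib` module and compare. A minimal sketch, assuming the archive sits in the current directory under its published file name; the helper names `sha256_of_file` and `verify_download` are illustrative.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected_sha256: str) -> bool:
    """Compare a file's SHA256 digest with a published one (case-insensitive)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()


# Published SHA256 digest of the source distribution listed above
EXPECTED_SDIST_SHA256 = "cd5b7e1424e2c2108dff8836a8eed26723cf30ed3132bf81ae9cbf3e0aca0523"
```

For example, `verify_download("xai_inference_engine-0.1.3.tar.gz", EXPECTED_SDIST_SHA256)` should return `True` for an intact download. pip can also enforce digests itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file, installed with `--require-hashes`).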

File details

Details for the file xai_inference_engine-0.1.3-py3-none-any.whl.

File hashes

Hashes for xai_inference_engine-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 5ff91f4f8dbcf5abeefdacd0cc6f2573e96c55a5fd55bf5946af7af66e825e60
MD5 91d138e93316c5eab46d9632e85744c1
BLAKE2b-256 3ce19ddbb0f0d4b1b6a7a418a54559f87c75e4cf282efa5105dae1381e0621a7