A pip package for XAI inference
XAI Inference Engine
This package is a wrapper for running inference with CNN-based PyTorch models. Given a trained PyTorch CNN model, it returns the predictions, the sorted prediction indices and the saliency maps. Saliency maps are generated with the FM-G-CAM method; more saliency map generation methods will be added in the future. The package also provides a method to superimpose the saliency maps on the input image.
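The package's FM-G-CAM implementation is not shown on this page. As a point of reference, the sketch below computes a plain Grad-CAM map for one class (the method the Results section compares against) and blends it onto an image. It assumes the target layer's activations and the gradients of the class score have already been captured as NumPy arrays (for example with PyTorch hooks); `grad_cam`, `upsample_nearest`, `superimpose` and the red-overlay colouring are illustrative choices, not the package's API.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Plain Grad-CAM map for a single class (illustrative, not FM-G-CAM).

    activations: (K, H, W) feature maps from the target conv layer.
    gradients:   (K, H, W) gradients of the class score w.r.t. those maps.
    Returns an (H, W) map scaled to [0, 1].
    """
    # Channel weights: global-average-pool the gradients over the spatial axes.
    weights = gradients.mean(axis=(1, 2))  # (K,)
    # Weighted sum of feature maps; ReLU keeps only positive evidence.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


def upsample_nearest(cam: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D map to (out_h, out_w)."""
    rows = np.arange(out_h) * cam.shape[0] // out_h
    cols = np.arange(out_w) * cam.shape[1] // out_w
    return cam[rows[:, None], cols[None, :]]


def superimpose(image: np.ndarray, cam: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a [0, 1] map onto an (H, W, 3) uint8 image as a red overlay."""
    h, w, _ = image.shape
    heat = upsample_nearest(cam, h, w)
    overlay = np.zeros_like(image, dtype=np.float64)
    overlay[..., 0] = heat * 255.0  # red channel encodes saliency
    blended = (1 - alpha) * image.astype(np.float64) + alpha * overlay
    return np.clip(blended, 0, 255).astype(np.uint8)


# Usage with random stand-in data (a ResNet-50 layer4 output is 2048 x 7 x 7).
rng = np.random.default_rng(0)
acts = rng.random((4, 7, 7))
grads = rng.standard_normal((4, 7, 7))
cam = grad_cam(acts, grads)
img = np.full((224, 224, 3), 128, dtype=np.uint8)
out = superimpose(img, cam)
print(cam.shape, out.shape, out.dtype)  # prints: (7, 7) (224, 224, 3) uint8
```

Nearest-neighbour upsampling keeps the sketch dependency-free; a real pipeline would typically use bilinear interpolation and a colour map.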
Users can also use the package to create their own inference engine by extending the XAIInferenceEngine class.
Advanced Tutorials: Coming Soon...
Requirements
- Python 3.8+
- PyTorch 2.0+
Installation
Execute the following command in your terminal to install the package.
```bash
pip install xai-inference-engine
```
Usage
Follow the example below to use the package. Copy and paste the code into a Python script and run it. Make sure you have the requirements installed; the example also uses `torchvision`, `requests` and `Pillow`. 😊
```python
print("[INFO]: Testing XAIInferenceEngine...")
print("[INFO]: Importing Libraries...")
import requests
import torch
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights

from xai_inference_engine import XAIInferenceEngine

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[INFO]: Device: {}".format(device))

print("[INFO]: Loading Model...")
# Model: use one weights object for both the model and its preprocessing
# transforms so the two always match
weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights).to(device)
preprocess = weights.transforms()

# Model config
# Set model to eval mode (dropout off, BatchNorm uses running statistics)
model.eval()
# Last convolutional layer of ResNet-50, passed to the engine for the saliency maps
last_conv_layer = model.layer4[2].conv3
class_count = 5  # not used further in this example
class_list = weights.meta["categories"]  # ImageNet class names; not used further here
img_h = 224  # side length (pixels) the input image is resized to

print("[INFO]: Image Preprocessing...")
# Image Preprocessing: download a sample image (a PNG file) and save it locally
url = "https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-visualizations/master/input_images/cat_dog.png"
r = requests.get(url, allow_redirects=True, timeout=30)
r.raise_for_status()
with open("cat_dog.png", "wb") as f:
    f.write(r.content)
# Convert to RGB in case the PNG carries an alpha channel
img = Image.open("cat_dog.png").convert("RGB")
img = img.resize((img_h, img_h), resample=Image.BICUBIC)
img_tensor = preprocess(img).to(device)

print("[INFO]: Creating XAIInferenceEngine...")
xai_inferencer = XAIInferenceEngine(
    model=model,
    last_conv_layer=last_conv_layer,
    device=device,
)

print("[INFO]: Running XAIInferenceEngine.predict()...")
preds, sorted_pred_indices, super_imp_img, saliency_maps = xai_inferencer.predict(
    img=img,
    img_tensor=img_tensor,
)

print("[INFO]: Saving Results to the root folder...")
super_imp_img.save("super_imp_img.jpg")
saliency_maps.save("saliency_maps.jpg")

print("[INFO]: Displaying Results...")
print("    Predictions shape: {}".format(preds.shape))
print("    Sorted Prediction Indices (top 10): {}".format(sorted_pred_indices.cpu().numpy()[:10]))
print("    Saliency Maps: {}".format(saliency_maps))
print("    Superimposed Image: {}".format(super_imp_img))
```
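The example defines `class_count` and `class_list` but does not use them. Assuming `preds` holds raw logits for a single image, a minimal NumPy sketch of turning such scores into the top `class_count` labels could look like this; `top_k_predictions` is a hypothetical helper, not part of the package.

```python
import numpy as np


def top_k_predictions(logits, class_names, k=5):
    """Return the k most probable (index, name, probability) tuples (hypothetical helper)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max logit before exponentiating.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # argsort is ascending; reverse it to get indices by descending probability.
    order = np.argsort(probs)[::-1][:k]
    return [(int(i), class_names[i], float(probs[i])) for i in order]


names = ["cat", "dog", "fox", "owl"]
for idx, name, p in top_k_predictions([1.0, 3.0, 0.5, 2.0], names, k=2):
    print(idx, name, round(p, 3))
# prints:
# 1 dog 0.631
# 3 owl 0.232
```

If `preds` is a PyTorch tensor of shape `(1, 1000)`, a call such as `top_k_predictions(preds.squeeze(0).detach().cpu().numpy(), class_list, k=class_count)` should map it to ImageNet class names; check the actual shape `predict()` returns before relying on this.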
Results
The following image shows a comparison between the saliency maps generated by the FM-G-CAM method and those generated by the Grad-CAM method.
Author
Download files
File details
Details for the file xai_inference_engine-0.1.3.tar.gz.
File metadata
- Download URL: xai_inference_engine-0.1.3.tar.gz
- Upload date:
- Size: 9.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cd5b7e1424e2c2108dff8836a8eed26723cf30ed3132bf81ae9cbf3e0aca0523 |
| MD5 | abbb3a3a3d8ec04453493c38bb8c1603 |
| BLAKE2b-256 | 231fed12a84b5d21dcb0459079a86d4c28be4f5d4d316e77634db1be7df05afa |
File details
Details for the file xai_inference_engine-0.1.3-py3-none-any.whl.
File metadata
- Download URL: xai_inference_engine-0.1.3-py3-none-any.whl
- Upload date:
- Size: 8.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5ff91f4f8dbcf5abeefdacd0cc6f2573e96c55a5fd55bf5946af7af66e825e60 |
| MD5 | 91d138e93316c5eab46d9632e85744c1 |
| BLAKE2b-256 | 3ce19ddbb0f0d4b1b6a7a418a54559f87c75e4cf282efa5105dae1381e0621a7 |