TorchXAI
TorchXAI is a lightweight PyTorch toolkit for evaluating machine learning models using explainability techniques. It wraps Captum attribution methods and adds multi-target attribution — explain multiple output classes in a single forward pass — plus ready-to-use metrics for quantifying explanation quality.
- Captum-compatible — works alongside the Captum explainers you already use
- Multi-target — compute attributions for all targets at once, not one at a time
- Batch & scalable — built for dataset-scale evaluation across many inputs and explainers
Installation
```shell
pip install torchxai-tools
```

The PyPI distribution is named `torchxai-tools`; the import name is `torchxai`.

```python
from torchxai.explainers import SaliencyExplainer  # import name is torchxai
```
Quick start
Generating explanations
```python
import torch
import torch.nn as nn
from torchxai.explainers import SaliencyExplainer, IntegratedGradientsExplainer
from torchxai.data_types import SingleTargetAcrossBatch

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3))
model.eval()
inputs = torch.randn(1, 10)

# Single target
explainer = SaliencyExplainer(model)
attrs = explainer.explain(inputs=inputs, target=SingleTargetAcrossBatch(index=0))
print(attrs.shape)  # (1, 10)

# All three classes in one call
explainer_mt = SaliencyExplainer(model, multi_target=True)
targets = [SingleTargetAcrossBatch(index=i) for i in range(3)]
attrs_list = explainer_mt.explain(inputs=inputs, target=targets)
print(len(attrs_list), attrs_list[0].shape)  # 3, (1, 10)
```
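The multi-target call above returns one attribution per requested class. One way to picture why that can be cheaper than looping over targets: the forward activations are shared, and only the backward step differs per target. The plain-Python sketch below illustrates this for saliency (absolute input gradient) on a hypothetical two-layer ReLU network without biases; `multi_target_saliency` and its hand-written backward pass are assumptions for illustration, not TorchXAI's implementation.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def multi_target_saliency(
    w1: Matrix, w2: Matrix, x: Vector, targets: List[int]
) -> List[Vector]:
    """Saliency |dy_t/dx| for each target t of y = W2 @ relu(W1 @ x).

    Hypothetical helper: the forward pass runs once and its ReLU gate is
    reused for every target; only the backward step is repeated.
    """
    z = matvec(w1, x)  # hidden pre-activations (shared forward pass)
    gate = [1.0 if zi > 0 else 0.0 for zi in z]  # ReLU derivative
    hidden = range(len(z))
    results = []
    for t in targets:
        # dy_t/dz_j = W2[t][j] * relu'(z_j)
        delta = [w2[t][j] * gate[j] for j in hidden]
        # dy_t/dx_i = sum_j delta_j * W1[j][i]
        grad = [sum(delta[j] * w1[j][i] for j in hidden) for i in range(len(x))]
        results.append([abs(g) for g in grad])
    return results


# Toy network: 2 inputs -> 3 hidden units -> 2 output classes
w1 = [[1.0, -1.0], [0.5, 2.0], [-1.0, 0.0]]
w2 = [[1.0, 0.0, 2.0], [0.0, -1.0, 1.0]]
print(multi_target_saliency(w1, w2, [2.0, 1.0], targets=[0, 1]))
# [[1.0, 1.0], [0.5, 2.0]]
```

Biases are left out to keep the backward pass short; adding them changes `z` but not the structure of the per-target loop.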
With a baseline (IntegratedGradients, DeepLift, …)
```python
from torchxai.explainers import IntegratedGradientsExplainer

baseline = torch.zeros_like(inputs)
explainer = IntegratedGradientsExplainer(model)
attrs = explainer.explain(
    inputs=inputs,
    baselines=baseline,
    target=SingleTargetAcrossBatch(index=0),
)
```
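Integrated Gradients attributes `(x_i - b_i)` times the average gradient along the straight path from the baseline `b` to the input `x`. The framework-free sketch below approximates that path integral with a midpoint Riemann sum on a hypothetical function whose gradient is written by hand. Library implementations may use a different quadrature rule and step count, so treat the helper name `integrated_gradients` and its defaults as assumptions.

```python
from typing import Callable, List

Vector = List[float]


def integrated_gradients(
    grad_f: Callable[[Vector], Vector],
    x: Vector,
    baseline: Vector,
    steps: int = 50,
) -> Vector:
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    attr_i = (x_i - b_i) * average of df/dx_i over points b + alpha * (x - b),
    with alpha taken at the midpoints of `steps` equal sub-intervals of [0, 1].
    """
    diff = [xi - bi for xi, bi in zip(x, baseline)]
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [bi + alpha * di for bi, di in zip(baseline, diff)]
        for i, g in enumerate(grad_f(point)):
            total[i] += g
    return [di * gi / steps for di, gi in zip(diff, total)]


# Hypothetical toy model f(x) = x0^2 + 3 * x1 and its hand-written gradient
def f(v: Vector) -> float:
    return v[0] ** 2 + 3.0 * v[1]


def grad_f(v: Vector) -> Vector:
    return [2.0 * v[0], 3.0]


x, b = [2.0, 1.0], [0.0, 0.0]
attrs = integrated_gradients(grad_f, x, b)
print(attrs)                    # approximately [4.0, 3.0]
print(sum(attrs), f(x) - f(b))  # both approximately 7.0
```

Because this toy gradient is linear along the path, the midpoint sum is exact up to rounding; for curved gradients the approximation error shrinks as `steps` grows.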
Evaluating explanation quality
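```python
import torch
import torch.nn as nn
from captum.attr import Saliency
from torchxai.metrics.axiomatic import completeness

# Stand-in classifier for 3x32x32 images with 10 classes; replace with your model
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
net.eval()

saliency = Saliency(net)
inputs = torch.randn(2, 3, 32, 32, requires_grad=True)
baselines = torch.zeros(2, 3, 32, 32)
attribution = saliency.attribute(inputs, target=3)
score = completeness(net, inputs, attribution, baselines)
print("Completeness:", score)
```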
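Completeness asks that attributions sum to the change in model output between the input and the baseline, `sum_i a_i = f(x) - f(b)`. One common way to turn this into a score is the absolute gap between the two sides. The plain-Python sketch below uses that formulation on a hypothetical toy function; the exact definition, batching, and reduction used by `torchxai.metrics.axiomatic.completeness` may differ, and `completeness_gap` is an illustrative name.

```python
from typing import Callable, List

Vector = List[float]


def completeness_gap(
    f: Callable[[Vector], float],
    x: Vector,
    baseline: Vector,
    attributions: Vector,
) -> float:
    """Return |sum(attributions) - (f(x) - f(baseline))|.

    0.0 means the attributions exactly account for the output change.
    """
    return abs(sum(attributions) - (f(x) - f(baseline)))


def f(v: Vector) -> float:
    # Hypothetical toy model: f(x) = x0^2 + 3 * x1
    return v[0] ** 2 + 3.0 * v[1]


x, b = [3.0, 1.0], [0.0, 0.0]  # f(x) - f(b) = 12.0

exact_ig = [9.0, 3.0]       # closed-form Integrated Gradients for this f
saliency = [6.0, 3.0]       # |grad f(x)| = [|2 * x0|, 3]
grad_x_input = [18.0, 3.0]  # grad f(x) * x

print(completeness_gap(f, x, b, exact_ig))      # 0.0
print(completeness_gap(f, x, b, saliency))      # 3.0
print(completeness_gap(f, x, b, grad_x_input))  # 9.0
```

This is why gradient-only methods such as Saliency generally score poorly on completeness, while Integrated Gradients satisfies it up to the error of its path-integral approximation.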
Supported explainers
| Explainer | Requires baseline | Notes |
|---|---|---|
| `SaliencyExplainer` | ✗ | |
| `InputXGradientExplainer` | ✗ | |
| `GuidedBackpropExplainer` | ✗ | Not compatible with transformers |
| `RandomExplainer` | ✗ | Reference explainer for sanity checks |
| `IntegratedGradientsExplainer` | ✓ | |
| `DeepLiftExplainer` | ✓ | Not compatible with transformers |
| `InputXBaselineGradientExplainer` | ✓ | |
| `DeepLiftShapExplainer` | ✓ (distribution) | Not compatible with transformers |
| `GradientShapExplainer` | ✓ (distribution) | |
| `FeatureAblationExplainer` | ✗ | Optional `feature_mask` |
| `LimeExplainer` | ✗ | Optional `feature_mask` |
| `KernelShapExplainer` | ✗ | Optional `feature_mask` |
| `OcclusionExplainer` | ✗ | Requires `sliding_window_shapes` |
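The perturbation-based explainers in this table (feature ablation, LIME, Kernel SHAP, occlusion) query the model on modified inputs rather than using gradients, and a `feature_mask` decides which input features are perturbed together. The plain-Python sketch below illustrates two of those ideas on a hypothetical 1-D linear model: grouping features with a mask so they are ablated as a unit, and sliding an occlusion window whose per-feature results are averaged over overlapping windows. The helper names, the scalar `baseline` fill value, and the averaging rule are assumptions for illustration, not TorchXAI or Captum internals.

```python
from typing import Callable, List, Optional

Vector = List[float]
Model = Callable[[Vector], float]


def feature_ablation(
    f: Model, x: Vector, feature_mask: Optional[List[int]] = None, baseline: float = 0.0
) -> Vector:
    """Ablate one group of features at a time and record the output drop.

    feature_mask[i] is the group id of feature i; all features in a group are
    replaced together and share the resulting attribution. With no mask,
    every feature forms its own group.
    """
    mask = feature_mask if feature_mask is not None else list(range(len(x)))
    out = f(x)
    attrs = [0.0] * len(x)
    for group in sorted(set(mask)):
        ablated = [baseline if m == group else xi for xi, m in zip(x, mask)]
        drop = out - f(ablated)
        for i, m in enumerate(mask):
            if m == group:
                attrs[i] = drop
    return attrs


def occlusion_1d(
    f: Model, x: Vector, window: int, stride: int, baseline: float = 0.0
) -> Vector:
    """Slide a window over x, occlude it, and average each feature's output drops.

    Features never covered by a full window keep attribution 0.0.
    """
    out = f(x)
    totals = [0.0] * len(x)
    counts = [0] * len(x)
    for start in range(0, len(x) - window + 1, stride):
        covered = range(start, start + window)
        occluded = list(x)
        for i in covered:
            occluded[i] = baseline
        drop = out - f(occluded)
        for i in covered:
            totals[i] += drop
            counts[i] += 1
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]


weights = [1.0, 2.0, 3.0, 4.0]


def linear(v: Vector) -> float:
    # Hypothetical toy model: weighted sum of four features
    return sum(w * vi for w, vi in zip(weights, v))


x = [1.0, 1.0, 1.0, 1.0]
print(feature_ablation(linear, x))                             # [1.0, 2.0, 3.0, 4.0]
print(feature_ablation(linear, x, feature_mask=[0, 0, 1, 1]))  # [3.0, 3.0, 7.0, 7.0]
print(occlusion_1d(linear, x, window=2, stride=1))             # [3.0, 4.0, 6.0, 7.0]
```

Coarser masks trade resolution for fewer model calls: the two-group mask above needs two perturbed forward passes instead of four.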
Documentation
Full documentation, including a per-explainer API reference and end-to-end examples (image classification, BERT sequence classification, NER), is available at:
saifullah3396.github.io/torchxai
License
MIT — see LICENSE.txt.