TorchXAI

TorchXAI is a lightweight PyTorch toolkit for evaluating machine learning models with explainability techniques. It wraps Captum attribution methods, adds multi-target attribution (explaining multiple output classes in a single forward pass), and provides ready-to-use metrics for quantifying explanation quality.

  • Captum-compatible — works alongside the Captum explainers you already use
  • Multi-target — compute attributions for all targets at once, not one at a time
  • Batch & scalable — built for dataset-scale evaluation across many inputs and explainers

Installation

pip install torchxai-tools

The PyPI distribution is named torchxai-tools; the import name is torchxai.

from torchxai.explainers import SaliencyExplainer   # import name is torchxai
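
Because the distribution name and the import name differ, it can help to check both when an import fails. The following is a minimal sketch using only the standard library; the helper name `check_install` is hypothetical and not part of TorchXAI.

```python
from importlib import metadata, util


def check_install(dist_name="torchxai-tools", import_name="torchxai"):
    """Return (installed version or None, whether the module can be found)."""
    try:
        # Looks up the *distribution* name, i.e. what was passed to pip.
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    try:
        # Looks up the *import* name, i.e. what appears after `import`.
        importable = util.find_spec(import_name) is not None
    except (ImportError, ValueError):
        importable = False
    return version, importable


if __name__ == "__main__":
    version, importable = check_install()
    print(f"torchxai-tools version: {version}; 'torchxai' importable: {importable}")
```

If the version is `None`, the distribution is not installed in the active environment; if the version is present but the module is not importable, the environment itself is likely misconfigured.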

Quick start

Generating explanations

import torch
import torch.nn as nn
from torchxai.explainers import SaliencyExplainer
from torchxai.data_types import SingleTargetAcrossBatch

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3))
model.eval()
inputs = torch.randn(1, 10)

# Single target
explainer = SaliencyExplainer(model)
attrs = explainer.explain(inputs=inputs, target=SingleTargetAcrossBatch(index=0))
print(attrs.shape)   # (1, 10)

# All three classes in one call
explainer_mt = SaliencyExplainer(model, multi_target=True)
targets = [SingleTargetAcrossBatch(index=i) for i in range(3)]
attrs_list = explainer_mt.explain(inputs=inputs, target=targets)
print(len(attrs_list), attrs_list[0].shape)   # 3, (1, 10)
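
To illustrate why several targets can share one forward pass, here is a framework-free sketch for a tiny Linear → ReLU → Linear network. It is an illustration of the idea only, under the assumption of a plain ReLU MLP; TorchXAI's own implementation may work differently. All names (`forward`, `all_target_gradients`, the weights) are hypothetical.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def forward(W1, b1, W2, b2, x):
    """Run Linear -> ReLU -> Linear and keep the pre-activation values."""
    hidden_pre = [h + b for h, b in zip(matvec(W1, x), b1)]
    hidden = [max(0.0, h) for h in hidden_pre]
    out = [o + b for o, b in zip(matvec(W2, hidden), b2)]
    return hidden_pre, out


def all_target_gradients(W1, b1, W2, b2, x):
    """Gradient of every output w.r.t. the input, from one forward pass."""
    hidden_pre, _ = forward(W1, b1, W2, b2, x)
    # ReLU passes gradient only where its input was positive.
    mask = [1.0 if h > 0 else 0.0 for h in hidden_pre]
    # Row k is d(out[k]) / d(x): sum over hidden units j of W2[k][j] * mask[j] * W1[j][i].
    return [
        [sum(W2[k][j] * mask[j] * W1[j][i] for j in range(len(mask)))
         for i in range(len(x))]
        for k in range(len(W2))
    ]


W1 = [[1.0, -1.0], [0.5, 2.0], [-1.0, 0.0]]   # 2 inputs -> 3 hidden units
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 2.0, 3.0], [0.0, -1.0, 1.0]]      # 3 hidden units -> 2 outputs
b2 = [0.0, 0.0]
x = [1.0, 2.0]

grads = all_target_gradients(W1, b1, W2, b2, x)
saliency = [[abs(g) for g in row] for row in grads]  # one attribution per target
print(len(saliency), saliency)  # 2 targets: [[1.0, 4.0], [0.5, 2.0]]
```

The forward pass and the ReLU mask are computed once and reused for every target, which is the work a per-target loop would otherwise repeat.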

With a baseline (IntegratedGradients, DeepLift, …)

from torchxai.explainers import IntegratedGradientsExplainer

baseline = torch.zeros_like(inputs)
explainer = IntegratedGradientsExplainer(model)
attrs = explainer.explain(
    inputs=inputs,
    baselines=baseline,
    target=SingleTargetAcrossBatch(index=0),
)
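
To make the role of the baseline concrete, here is a framework-free sketch of what Integrated Gradients computes: gradients averaged along the straight path from the baseline to the input, scaled by the input-baseline difference. It uses a midpoint Riemann sum and finite-difference gradients, both simplifications chosen for readability; it is not the Captum or TorchXAI implementation, and `toy_model`, `numeric_grad`, and `integrated_gradients` are hypothetical names.

```python
def numeric_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def integrated_gradients(f, x, baseline, n_steps=50):
    """Approximate IG with a midpoint Riemann sum along the straight path."""
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = numeric_grad(f, point)
        total = [t + g for t, g in zip(total, grad)]
    # Average gradient along the path, scaled by (input - baseline).
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]


def toy_model(x):
    """A small differentiable scalar function standing in for a model output."""
    return x[0] ** 2 + 3.0 * x[1] + x[0] * x[1]


x = [1.0, 2.0]
baseline = [0.0, 0.0]
attrs = integrated_gradients(toy_model, x, baseline)
print(attrs)                                       # approximately [2.0, 7.0]
print(sum(attrs), toy_model(x) - toy_model(baseline))  # both approximately 9.0
```

Changing the baseline changes the attributions, because they describe the difference between the model output at the input and at the baseline.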

Evaluating explanation quality

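import torch
from captum.attr import Saliency
from torchxai.metrics.axiomatic import completeness

net = ...   # your model; needs at least 4 output classes for target=3
saliency = Saliency(net)
inputs = torch.randn(2, 3, 32, 32, requires_grad=True)   # batch of two 3x32x32 inputs
baselines = torch.zeros(2, 3, 32, 32)                     # all-zero reference inputs

# Attribute class index 3, then score how well the attributions satisfy completeness
attribution = saliency.attribute(inputs, target=3)
score = completeness(net, inputs, attribution, baselines)
print("Completeness:", score)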
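
For intuition, the completeness axiom states that attributions should sum to the difference between the model output at the input and at the baseline. The sketch below assumes the score is the absolute gap between those two quantities; the exact definition and return type of `torchxai.metrics.axiomatic.completeness` may differ, and `completeness_gap` and `linear_model` are hypothetical names.

```python
def completeness_gap(f, x, attributions, baseline):
    """Absolute difference between sum(attributions) and f(x) - f(baseline)."""
    return abs(sum(attributions) - (f(x) - f(baseline)))


weights = [0.5, -2.0, 1.0]
bias = 0.25


def linear_model(x):
    """A linear scalar model: dot(weights, x) + bias."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


x = [2.0, 1.0, -3.0]
baseline = [0.0, 0.0, 0.0]

# Gradient x (input - baseline) is exactly complete for a linear model.
grad_times_delta = [w * (xi - b) for w, xi, b in zip(weights, x, baseline)]
# Absolute gradients (saliency-style) ignore the input values and are not complete here.
abs_grad = [abs(w) for w in weights]

print(completeness_gap(linear_model, x, grad_times_delta, baseline))  # 0.0
print(completeness_gap(linear_model, x, abs_grad, baseline))          # 7.5
```

A gap near zero means the attributions account for the full change in output; plain gradient-magnitude methods are not designed to satisfy this axiom, so a nonzero gap for them is expected.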

Supported explainers

| Explainer | Requires baseline | Notes |
| --- | --- | --- |
| SaliencyExplainer | | |
| InputXGradientExplainer | | |
| GuidedBackpropExplainer | | Not compatible with transformers |
| RandomExplainer | | Baseline for sanity-checking |
| IntegratedGradientsExplainer | | |
| DeepLiftExplainer | | Not compatible with transformers |
| InputXBaselineGradientExplainer | | |
| DeepLiftShapExplainer | ✓ distribution | Not compatible with transformers |
| GradientShapExplainer | ✓ distribution | |
| FeatureAblationExplainer | | Optional feature_mask |
| LimeExplainer | | Optional feature_mask |
| KernelShapExplainer | | Optional feature_mask |
| OcclusionExplainer | | Requires sliding_window_shapes |

Documentation

The full documentation, including a per-explainer API reference and end-to-end examples (image classification, BERT sequence classification, NER), is available at:

saifullah3396.github.io/torchxai

License

MIT — see LICENSE.txt.
