
TorchXAI is a PyTorch-based toolkit for evaluating machine learning models using explainability techniques.

Project description

TorchXAI

TorchXAI is a lightweight PyTorch toolkit for evaluating machine learning models using explainability techniques. It wraps Captum attribution methods and adds multi-target attribution — explain multiple output classes in a single forward pass — plus ready-to-use metrics for quantifying explanation quality.

  • Captum-compatible — works alongside the Captum explainers you already use
  • Multi-target — compute attributions for all targets at once, not one at a time
  • Batch & scalable — built for dataset-scale evaluation across many inputs and explainers

Installation

pip install torchxai-tools

The PyPI distribution is named torchxai-tools; the import name is torchxai.

from torchxai.explainers import SaliencyExplainer   # import name is torchxai

Quick start

Generating explanations

import torch
import torch.nn as nn
from torchxai.explainers import SaliencyExplainer
from torchxai.data_types import SingleTargetAcrossBatch

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3))
model.eval()
inputs = torch.randn(1, 10)

# Single target
explainer = SaliencyExplainer(model)
attrs = explainer.explain(inputs=inputs, target=SingleTargetAcrossBatch(index=0))
print(attrs.shape)   # (1, 10)

# All three classes in one call
explainer_mt = SaliencyExplainer(model, multi_target=True)
targets = [SingleTargetAcrossBatch(index=i) for i in range(3)]
attrs_list = explainer_mt.explain(inputs=inputs, target=targets)
print(len(attrs_list), attrs_list[0].shape)   # 3, (1, 10)
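This README does not describe how TorchXAI implements multi-target attribution internally. The plain-PyTorch sketch below only illustrates the idea: the saliency map for each class is one row of the model's input Jacobian, so every class can be covered from a single forward pass instead of one `explain` call per target. Saliency is taken here as the absolute input gradient (Captum's `Saliency` default, `abs=True`). The helper names `saliency_one_target` and `saliency_all_targets` are hypothetical and are not part of TorchXAI.

```python
import torch
import torch.nn as nn

# Same toy model shape as in the quick start: 10 input features, 3 output classes
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3))
model.eval()
inputs = torch.randn(1, 10)


def saliency_one_target(model: nn.Module, x: torch.Tensor, target: int) -> torch.Tensor:
    """Absolute input gradient for a single output class (one forward/backward per call)."""
    x = x.clone().requires_grad_(True)
    out = model(x)[:, target].sum()  # summing over the batch keeps per-sample gradients separate
    (grad,) = torch.autograd.grad(out, x)
    return grad.abs()


def saliency_all_targets(model: nn.Module, x: torch.Tensor) -> list[torch.Tensor]:
    """Absolute input gradients for every output class from one Jacobian computation."""
    # jac has shape (batch, num_classes, batch, num_features); the cross-sample
    # blocks are zero because samples do not interact in this model.
    jac = torch.autograd.functional.jacobian(model, x)
    batch_idx = torch.arange(x.shape[0])
    # Keep the diagonal blocks: per_sample has shape (batch, num_classes, num_features)
    per_sample = jac[batch_idx, :, batch_idx, :]
    return [per_sample[:, c, :].abs() for c in range(per_sample.shape[1])]


single = [saliency_one_target(model, inputs, c) for c in range(3)]
multi = saliency_all_targets(model, inputs)
print(len(multi), multi[0].shape)  # 3 torch.Size([1, 10])
print(all(torch.allclose(a, b) for a, b in zip(single, multi)))  # True
```

Both routes give the same maps; the Jacobian route simply reuses one forward pass for all classes, which is the kind of saving the multi-target mode is aimed at.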

With a baseline (IntegratedGradients, DeepLift, …)

from torchxai.explainers import IntegratedGradientsExplainer

baseline = torch.zeros_like(inputs)
explainer = IntegratedGradientsExplainer(model)
attrs = explainer.explain(
    inputs=inputs,
    baselines=baseline,
    target=SingleTargetAcrossBatch(index=0),
)
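The `baselines` argument sets the reference input x' from which a path method such as Integrated Gradients starts; the example above uses all zeros. To show what the baseline controls, here is a from-scratch sketch of the standard Integrated Gradients formula in plain PyTorch, using a midpoint Riemann sum. It is not TorchXAI's or Captum's implementation (Captum's `IntegratedGradients` defaults to a Gauss-Legendre rule with `n_steps=50`). The function name and the `steps` parameter are illustrative, and a smooth `Tanh` variant of the quick-start model is assumed so that the sum converges quickly.

```python
import torch
import torch.nn as nn


def integrated_gradients(model, x, baseline, target, steps=200):
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    IG_i = (x_i - x'_i) * integral over a in [0, 1] of dF_target/dx_i at x' + a * (x - x')
    """
    # Midpoints of `steps` equal sub-intervals of [0, 1]
    alphas = (torch.arange(steps, dtype=x.dtype) + 0.5) / steps
    total_grad = torch.zeros_like(x)
    for a in alphas:
        # A point on the straight-line path from the baseline to the input
        point = (baseline + a * (x - baseline)).requires_grad_(True)
        out = model(point)[:, target].sum()  # sum over the batch; samples stay independent
        (grad,) = torch.autograd.grad(out, point)
        total_grad += grad
    # Average gradient along the path, scaled by the input-baseline difference
    return (x - baseline) * total_grad / steps


torch.manual_seed(0)
# Smooth (Tanh) variant of the quick-start model
model = nn.Sequential(nn.Linear(10, 5), nn.Tanh(), nn.Linear(5, 3)).eval()
x = torch.randn(2, 10)
baseline = torch.zeros_like(x)
attrs = integrated_gradients(model, x, baseline, target=0)

# Completeness check: attributions should sum to F(x) - F(baseline)
with torch.no_grad():
    delta = model(x)[:, 0] - model(baseline)[:, 0]
print((attrs.sum(dim=1) - delta).abs().max())  # small
```

Changing `baseline` changes both the path and the quantity being explained, F(x) - F(x'), which is why baseline-based explainers take it as an explicit argument.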

Evaluating explanation quality

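import torch
from captum.attr import Saliency
from torchxai.metrics.axiomatic import completeness

net = ...   # your model; it must produce at least 4 output classes for target=3
saliency = Saliency(net)
inputs = torch.randn(2, 3, 32, 32, requires_grad=True)
baselines = torch.zeros(2, 3, 32, 32)   # all-zero reference input, same shape as inputs

# Saliency maps for class 3
attribution = saliency.attribute(inputs, target=3)
# Completeness score of those saliency maps
score = completeness(net, inputs, attribution, baselines)
print("Completeness:", score)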
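Completeness asks whether the attributions add up to the change in the target output between the baseline and the input: the sum of A_i(x) should equal F(x) - F(x'). The exact score TorchXAI returns (normalization, reduction over the batch, how the target is chosen) is not specified here, so the hypothetical helper `completeness_gap` below only illustrates the underlying check as a per-example absolute difference. On a purely linear model with a zero baseline, Input×Gradient satisfies the property up to floating-point rounding, while saliency (absolute gradient) carries no such guarantee.

```python
import torch
import torch.nn as nn


def completeness_gap(model, inputs, attributions, baselines, target):
    """Per-example |sum(attributions) - (F_target(inputs) - F_target(baselines))|.

    Hypothetical helper for illustration; not TorchXAI's `completeness`.
    """
    with torch.no_grad():
        delta = model(inputs)[:, target] - model(baselines)[:, target]
    # Sum over every non-batch dimension (features, or channels and pixels for images)
    attr_sum = attributions.reshape(attributions.shape[0], -1).sum(dim=1)
    return (attr_sum - delta).abs()


def input_gradient(model, inputs, target):
    """Gradient of the target output with respect to the inputs."""
    x = inputs.clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(model(x)[:, target].sum(), x)
    return grad


torch.manual_seed(0)
linear = nn.Linear(10, 4)  # a purely linear model: F(x) = Wx + b
x = torch.randn(2, 10)
zeros = torch.zeros_like(x)
grad = input_gradient(linear, x, target=3)

# Input x Gradient with a zero baseline sums to Wx = F(x) - F(0), up to rounding
print(completeness_gap(linear, x, x * grad, zeros, target=3))  # near zero
# Saliency (absolute gradient) is not expected to satisfy completeness
print(completeness_gap(linear, x, grad.abs(), zeros, target=3))  # generally non-zero
```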

Supported explainers

| Explainer | Requires baseline | Notes |
| --- | --- | --- |
| SaliencyExplainer | | |
| InputXGradientExplainer | | |
| GuidedBackpropExplainer | | Not compatible with transformers |
| RandomExplainer | | Random attributions; a reference point for sanity checks |
| IntegratedGradientsExplainer | | |
| DeepLiftExplainer | | Not compatible with transformers |
| InputXBaselineGradientExplainer | | |
| DeepLiftShapExplainer | ✓ (distribution of baselines) | Not compatible with transformers |
| GradientShapExplainer | ✓ (distribution of baselines) | |
| FeatureAblationExplainer | | Optional feature_mask |
| LimeExplainer | | Optional feature_mask |
| KernelShapExplainer | | Optional feature_mask |
| OcclusionExplainer | | Requires sliding_window_shapes |
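Several perturbation-based explainers take an optional `feature_mask`. In Captum's `FeatureAblation`, features that share the same integer id in the mask are ablated together, every feature in a group receives the group's total change in the target output, and without a mask each feature is its own group. The plain-PyTorch sketch below illustrates that idea for a single example and a scalar baseline; the function `ablation_attributions` is hypothetical, and TorchXAI's own handling (batching, baselines, output format) may differ.

```python
import torch
import torch.nn as nn


def ablation_attributions(model, x, target, feature_mask=None, baseline=0.0):
    """Replace each feature group with `baseline` in turn and record the drop in the target output.

    x: tensor of shape (1, num_features); feature_mask: integer tensor of the same
    shape whose equal values mark features that are ablated together.
    """
    if feature_mask is None:
        # No mask: every feature forms its own group
        feature_mask = torch.arange(x.numel()).reshape(x.shape)
    attrs = torch.zeros_like(x)
    with torch.no_grad():
        original = model(x)[:, target]
        for group in feature_mask.unique():
            in_group = feature_mask == group
            perturbed = torch.where(in_group, torch.full_like(x, baseline), x)
            drop = original - model(perturbed)[:, target]
            # Every feature in the group receives the group's full output difference
            attrs[in_group] = drop.item()
    return attrs


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(6, 4), nn.ReLU(), nn.Linear(4, 2)).eval()
x = torch.randn(1, 6)
per_feature = ablation_attributions(model, x, target=1)
mask = torch.tensor([[0, 0, 0, 1, 1, 1]])  # two groups of three features
per_group = ablation_attributions(model, x, target=1, feature_mask=mask)
print(per_group)  # the first three entries are equal, as are the last three
```

Grouping trades resolution for cost: the model runs once per group rather than once per feature, which matters for inputs such as images where individual pixels are many and weakly informative.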

Documentation

Full documentation, including a per-explainer API reference and end-to-end examples (image classification, BERT sequence classification, NER), is available at:

saifullah3396.github.io/torchxai

License

MIT — see LICENSE.txt.



Download files


Source Distribution

torchxai_tools-0.1.2.tar.gz (3.7 MB, source)

Built Distribution


torchxai_tools-0.1.2-py3-none-any.whl (195.5 kB, Python 3)

File details

Details for the file torchxai_tools-0.1.2.tar.gz.

File metadata

  • Download URL: torchxai_tools-0.1.2.tar.gz
  • Size: 3.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.15 (uv publish, run in CI on Ubuntu 24.04)

File hashes

Hashes for torchxai_tools-0.1.2.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 289b70099bbc7e73e9f2c50a1b957d400098bb58469fadf5597f4e2a6a356de4 |
| MD5 | 872b88f9b43df11018c6c5d68420de5c |
| BLAKE2b-256 | 594491d143c1bbd01e9a4e4114bb98d8b767ecb44a3dbf3ff69218a8264e8e79 |
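The published digests let you confirm that a downloaded file is byte-for-byte the one that was uploaded. A minimal sketch with Python's standard `hashlib`, assuming the source distribution was saved to the current working directory under its original name (the helper `sha256_of` is illustrative):

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 published for the source distribution
EXPECTED_SDIST_SHA256 = "289b70099bbc7e73e9f2c50a1b957d400098bb58469fadf5597f4e2a6a356de4"

# Assumes the sdist was downloaded into the current working directory
sdist = Path("torchxai_tools-0.1.2.tar.gz")
if sdist.exists():
    print("OK" if sha256_of(sdist) == EXPECTED_SDIST_SHA256 else "MISMATCH")
```

pip can enforce the same check at install time in hash-checking mode: a requirements line such as `torchxai-tools==0.1.2 --hash=sha256:289b70099bbc7e73e9f2c50a1b957d400098bb58469fadf5597f4e2a6a356de4 --hash=sha256:5af0e0a596b64f2599b7fe60ceae825f492bbb869264693ef2dac321df62dd5c` lists both the sdist and wheel digests, so pip accepts whichever of the two files it selects and rejects anything else.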


File details

Details for the file torchxai_tools-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: torchxai_tools-0.1.2-py3-none-any.whl
  • Size: 195.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv 0.11.15 (uv publish, run in CI on Ubuntu 24.04)

File hashes

Hashes for torchxai_tools-0.1.2-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 5af0e0a596b64f2599b7fe60ceae825f492bbb869264693ef2dac321df62dd5c |
| MD5 | e41cb718c445bf09c229877ff99059ee |
| BLAKE2b-256 | 2bd6117168271e799e80c781471b13093b07602b3846db336bdccdc92b78d8a5 |

