SHAP and LLM-based explanations for model interpretability
SHAPXplain
SHAPXplain combines SHAP (SHapley Additive exPlanations) with Large Language Models (LLMs) to provide natural language explanations of machine learning model predictions. The package helps bridge the gap between technical SHAP values and human-understandable insights.
Features
- Natural Language Explanations: Convert complex SHAP values into clear, actionable explanations using LLMs.
- Flexible LLM Integration: Works with any LLM via the pydantic-ai interface.
- Structured Outputs: Get standardized explanation formats including summaries, detailed analysis, and recommendations.
- Asynchronous API: Process explanations in parallel with async/await support.
- Robust Error Handling: Built-in retry logic with exponential backoff for API reliability.
- Batch Processing: Handle multiple predictions efficiently with both sync and async methods.
- Confidence Levels: Understand the reliability of explanations.
- Feature Interaction Analysis: Identify and explain how features work together.
- Data Contracts: Provide domain-specific context to enhance explanation quality.
Installation
You can install SHAPXplain using pip:
pip install shapxplain
Or with Poetry:
poetry add shapxplain
API Key Setup
SHAPXplain uses LLMs through the pydantic-ai interface. You'll need to set up API keys for your preferred provider:
# For OpenAI
import os
os.environ["OPENAI_API_KEY"] = "your-api-key"
# Or using a .env file
from dotenv import load_dotenv
load_dotenv() # Will load OPENAI_API_KEY from .env file
Supported providers via pydantic-ai include:
- OpenAI (environment variable: OPENAI_API_KEY)
- Anthropic (environment variable: ANTHROPIC_API_KEY)
- DeepSeek (environment variable: DEEPSEEK_API_KEY)
- Others as supported by pydantic-ai
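Before creating an agent, it can be useful to check which provider keys are actually configured. The following is a minimal sketch built only from the environment variable names listed above; the helper name `available_providers` and the `PROVIDER_ENV_VARS` mapping are hypothetical and are not part of SHAPXplain or pydantic-ai.

```python
import os

# Environment variable names taken from the provider list above.
PROVIDER_ENV_VARS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "deepseek": "DEEPSEEK_API_KEY",
}


def available_providers(env=None):
    """Return the provider names whose API key variable is set and non-empty.

    `env` defaults to os.environ; passing a plain dict makes the helper easy to test.
    """
    env = os.environ if env is None else env
    return [name for name, var in PROVIDER_ENV_VARS.items() if env.get(var)]


if __name__ == "__main__":
    found = available_providers()
    print("Providers with keys configured:", found or "none")
```

Running a check like this at startup gives a clearer message than a failed LLM call later on.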
Quick Start
Here's a complete example using the Iris dataset:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from shap import TreeExplainer
from shapxplain import ShapLLMExplainer
from pydantic_ai import Agent
# Load data and train model
data = load_iris()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=42)
model.fit(X, y)
# Generate SHAP values
explainer = TreeExplainer(model)
shap_values = explainer.shap_values(X)
# Create LLM agent and SHAPXplain explainer
llm_agent = Agent(model="openai:gpt-4o") # Or your preferred LLM
llm_explainer = ShapLLMExplainer(
    model=model,
    llm_agent=llm_agent,
    feature_names=data.feature_names,
    significance_threshold=0.1
)
# Explain a single prediction
data_point = X[0]
prediction_probs = model.predict_proba(data_point.reshape(1, -1))[0]
predicted_class_idx = model.predict(data_point.reshape(1, -1))[0]
prediction_class = data.target_names[predicted_class_idx]
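# Get class-specific SHAP values.
# This indexing assumes shap_values is a list with one (n_samples, n_features) array per class.
# Newer SHAP releases may instead return a single array of shape (n_samples, n_features, n_classes);
# in that case use shap_values[0, :, predicted_class_idx].
class_shap_values = shap_values[predicted_class_idx][0]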
# Generate explanation
explanation = llm_explainer.explain(
    shap_values=class_shap_values,
    data_point=data_point,
    prediction=prediction_probs[predicted_class_idx],
    prediction_class=prediction_class
)
# Access different parts of the explanation
print("Summary:", explanation.summary)
print("\nDetailed Explanation:", explanation.detailed_explanation)
print("\nRecommendations:", explanation.recommendations)
print("\nConfidence Level:", explanation.confidence_level)
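The class-specific indexing in the example depends on how the SHAP explainer lays out its output for multi-class models, and that layout can differ between SHAP versions. The sketch below is a hypothetical helper (`select_class_shap` is not part of SHAP or SHAPXplain) that accepts either a list of per-class arrays or a single 3-D array and returns the vector for one sample and one class. It relies only on duck typing, so it works with NumPy arrays without importing NumPy itself.

```python
def select_class_shap(shap_values, sample_idx, class_idx):
    """Return the SHAP vector for one sample and one class.

    Two layouts are assumed (check what your SHAP version returns):
    - a list of per-class arrays, each shaped (n_samples, n_features)
    - one array shaped (n_samples, n_features, n_classes)
    """
    if isinstance(shap_values, list):
        # List layout: pick the class first, then the sample row.
        return shap_values[class_idx][sample_idx]
    if getattr(shap_values, "ndim", None) == 3:
        # 3-D array layout: sample on axis 0, class on the last axis.
        return shap_values[sample_idx, :, class_idx]
    raise ValueError("Unrecognised SHAP values layout")
```

With this helper, the line that builds `class_shap_values` could be written as `select_class_shap(shap_values, 0, predicted_class_idx)` regardless of which layout is returned.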
Explanation Structure
The package provides structured explanations with the following components:
class SHAPExplanationResponse:
    summary: str  # Brief overview of key drivers
    detailed_explanation: str  # Comprehensive analysis
    recommendations: List[str]  # Actionable insights
    confidence_level: str  # high/medium/low
    feature_interactions: Dict[str, str]  # How features work together
    features: List[SHAPFeatureContribution]  # Detailed feature impacts
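To give a feel for what per-feature impact data might look like, here is a small sketch that ranks features by absolute SHAP value and flags those at or above a threshold. It is an illustration only: `FeatureImpact` and `rank_features` are hypothetical names, the real `SHAPFeatureContribution` fields are not shown here, and treating `significance_threshold` as a cut-off on the absolute SHAP value is an assumption.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class FeatureImpact:
    """Hypothetical stand-in for a per-feature record (not SHAPFeatureContribution)."""
    name: str
    shap_value: float
    significant: bool


def rank_features(
    names: Sequence[str],
    shap_values: Sequence[float],
    threshold: float = 0.1,
) -> List[FeatureImpact]:
    """Pair names with SHAP values and sort by absolute impact, largest first."""
    impacts = [
        FeatureImpact(name, float(value), abs(value) >= threshold)
        for name, value in zip(names, shap_values)
    ]
    return sorted(impacts, key=lambda item: abs(item.shap_value), reverse=True)


if __name__ == "__main__":
    for item in rank_features(["sepal length", "petal width"], [0.02, -0.31]):
        print(item)
```

Sorting by absolute value keeps strongly negative contributions near the top, which is usually what a reader of an explanation wants to see first.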
Batch Processing
Synchronous Batch Processing
# `predictions` is assumed to hold one model output per row of X, computed beforehand
batch_response = llm_explainer.explain_batch(
    shap_values_batch=shap_values,
    data_points=X,
    predictions=predictions,
    batch_size=5,  # Optional: control batch size
    additional_context={
        "dataset": "Iris",
        "feature_descriptions": {...}
    }
)
# Access batch results
for response in batch_response.responses:
    print(response.summary)
# Get batch insights
print("Batch Insights:", batch_response.batch_insights)
print("Summary Statistics:", batch_response.summary_statistics)
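A `batch_size` parameter is commonly used to split the inputs into fixed-size chunks that are processed one after another. Whether `explain_batch` does exactly this internally is not stated here, so the following is only a generic sketch of the idea; `chunked` is a hypothetical helper.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunked(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `items`, each with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


if __name__ == "__main__":
    for batch in chunked(list(range(12)), 5):
        print(len(batch), batch)
```

Smaller chunks keep each round of LLM requests short, at the cost of more rounds overall.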
Asynchronous Batch Processing
For large batches, the asynchronous method lets explanations be processed in parallel, which can shorten total run time:
import asyncio
async def process_batch():
    batch_response = await llm_explainer.explain_batch_async(
        shap_values_batch=shap_values,
        data_points=X,
        predictions=predictions,
        additional_context={
            "dataset": "Iris",
            "feature_descriptions": {...}
        }
    )
    # Process results asynchronously
    return batch_response
# Run the async function
batch_results = asyncio.run(process_batch())
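The benefit of the async path comes from having several requests in flight at once. The sketch below shows the general pattern with `asyncio.gather` and a semaphore that caps concurrency. It does not call SHAPXplain: `fake_explain` is a hypothetical stand-in for an LLM request, and `max_concurrency` is an illustrative parameter, not a documented option of `explain_batch_async`.

```python
import asyncio
from typing import List


async def fake_explain(index: int) -> str:
    """Stand-in for an LLM call; sleeps briefly instead of contacting an API."""
    await asyncio.sleep(0.01)
    return f"explanation-{index}"


async def explain_all(n_items: int, max_concurrency: int = 3) -> List[str]:
    """Run all requests concurrently, with at most `max_concurrency` active at a time."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(index: int) -> str:
        async with semaphore:
            return await fake_explain(index)

    # gather returns results in the order the awaitables were passed in.
    return await asyncio.gather(*(bounded(i) for i in range(n_items)))


if __name__ == "__main__":
    print(asyncio.run(explain_all(5)))
```

Capping concurrency is a common way to stay within provider rate limits while still overlapping network waits.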
Advanced Usage
Data Contracts for Enhanced Explanations
One of SHAPXplain's most powerful features is the ability to provide domain-specific context through the additional_context parameter, effectively creating a "data contract" that guides the LLM:
# `prediction` is assumed to be the model output for this data point (e.g. the predicted probability)
explanation = llm_explainer.explain(
    shap_values=class_shap_values,
    data_point=data_point,
    prediction=prediction,
    additional_context={
        "domain": "medical_diagnosis",
        "feature_descriptions": {
            "glucose": "Blood glucose level in mg/dL. Normal range: 70-99 mg/dL fasting",
            "blood_pressure": "Systolic blood pressure in mmHg. Normal range: <120 mmHg",
            "bmi": "Body Mass Index. Normal range: 18.5-24.9"
        },
        "reference_ranges": {
            "glucose": {"low": "<70", "normal": "70-99", "prediabetes": "100-125", "diabetes": ">=126"},
            "blood_pressure": {"normal": "<120", "elevated": "120-129", "stage1": "130-139", "stage2": ">=140"}
        },
        "measurement_units": {
            "glucose": "mg/dL",
            "blood_pressure": "mmHg",
            "bmi": "kg/m²"
        },
        "patient_context": "65-year-old male with family history of type 2 diabetes"
    }
)
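Because the context dictionary is free-form, a typo in a feature name can go unnoticed and leave the LLM without the intended description. The sketch below is a hypothetical pre-flight check (`check_context` is not part of SHAPXplain) that compares the keys used in the example above against the model's feature names. The section names it inspects are simply those from the example, not a schema defined by the package.

```python
from typing import Dict, List, Sequence

# Sections from the example above whose keys are expected to be feature names.
FEATURE_KEYED_SECTIONS = ("feature_descriptions", "reference_ranges", "measurement_units")


def check_context(feature_names: Sequence[str], additional_context: Dict) -> List[str]:
    """Return a list of problems found; an empty list means the keys line up."""
    problems = []
    known = set(feature_names)
    # Flag keys that do not correspond to any model feature.
    for section in FEATURE_KEYED_SECTIONS:
        for key in additional_context.get(section, {}):
            if key not in known:
                problems.append(f"{section}: unknown feature '{key}'")
    # Flag model features that have no description at all.
    described = set(additional_context.get("feature_descriptions", {}))
    for name in sorted(known - described):
        problems.append(f"feature_descriptions: no entry for '{name}'")
    return problems


if __name__ == "__main__":
    context = {"feature_descriptions": {"glucose": "Blood glucose level in mg/dL"}}
    print(check_context(["glucose", "bmi"], context))
```

Calling a check like this before `explain` surfaces mismatches early instead of in the generated text.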
Error Handling
try:
    explanation = llm_explainer.explain(
        shap_values=class_shap_values,
        data_point=data_point,
        prediction=prediction
    )
except ValueError as e:
    print(f"Input validation error: {e}")
except RuntimeError as e:
    print(f"LLM query error: {e}")
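The package is described as having built-in retry logic with exponential backoff. For callers who want an extra layer of control around their own calls, the general technique can be sketched as below. This is an independent illustration and not the package's implementation: `retry_with_backoff` and its parameters are hypothetical. Only `RuntimeError` is retried, on the assumption that a `ValueError` from input validation will not be fixed by trying again.

```python
import random
import time


def retry_with_backoff(func, max_attempts=4, base_delay=0.5, sleep=time.sleep):
    """Call func(); on RuntimeError wait base_delay * 2**attempt (plus jitter) and retry.

    The last failure is re-raised once max_attempts calls have been made.
    `sleep` is injectable so the waiting can be replaced in tests.
    """
    for attempt in range(max_attempts):
        try:
            return func()
        except RuntimeError:
            if attempt == max_attempts - 1:
                raise
            delay = base_delay * (2 ** attempt)
            # Up to 10% random jitter so parallel callers do not retry in lockstep.
            sleep(delay + random.uniform(0, delay * 0.1))
```

A call could then be wrapped as `retry_with_backoff(lambda: llm_explainer.explain(...))`, with the arguments filled in as in the example above.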
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Documentation
For detailed documentation, tutorials, and API reference, visit: https://shapxplain.readthedocs.io/
Development
To set up the development environment:
# Clone the repo
git clone https://github.com/mpearmain/shapxplain.git
cd shapxplain
# Install with development dependencies
poetry install --with dev
Run tests:
poetry run pytest
Format code:
poetry run black src tests
poetry run ruff check --fix src tests
License
This project is licensed under the MIT License - see the LICENSE file for details.