
SHAP and LLM-based explanations for model interpretability

SHAPXplain

SHAPXplain combines SHAP (SHapley Additive exPlanations) with Large Language Models (LLMs) to provide natural language explanations of machine learning model predictions. The package helps bridge the gap between technical SHAP values and human-understandable insights.

Features

  • Natural Language Explanations: Convert complex SHAP values into clear, actionable explanations using LLMs.
  • Flexible LLM Integration: Works with any LLM via the pydantic-ai interface.
  • Structured Outputs: Get standardized explanation formats including summaries, detailed analysis, and recommendations.
  • Asynchronous API: Process explanations in parallel with async/await support.
  • Robust Error Handling: Built-in retry logic with exponential backoff for API reliability.
  • Batch Processing: Handle multiple predictions efficiently with both sync and async methods.
  • Confidence Levels: Each explanation reports a high, medium, or low confidence level to indicate how reliable it is.
  • Feature Interaction Analysis: Identify and explain how features work together.
  • Data Contracts: Provide domain-specific context to enhance explanation quality.

Installation

You can install SHAPXplain using pip:

pip install shapxplain

Or with Poetry:

poetry add shapxplain

API Key Setup

SHAPXplain uses LLMs through the pydantic-ai interface. You'll need to set up API keys for your preferred provider:

# For OpenAI
import os
os.environ["OPENAI_API_KEY"] = "your-api-key"

# Or using a .env file
from dotenv import load_dotenv
load_dotenv()  # Will load OPENAI_API_KEY from .env file

Supported providers via pydantic-ai include:

  • OpenAI (environment variable: OPENAI_API_KEY)
  • Anthropic (environment variable: ANTHROPIC_API_KEY)
  • DeepSeek (environment variable: DEEPSEEK_API_KEY)
  • Others as supported by pydantic-ai
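
Before creating an agent, it can help to check which of these keys is actually set. The snippet below is an illustrative sketch, not part of SHAPXplain or pydantic-ai: the helper name configured_providers and the provider labels are hypothetical, and only the environment variable names come from the list above.

```python
import os

# Environment variable names as listed above; the provider labels are
# just dictionary keys chosen for this sketch.
PROVIDER_ENV_VARS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "deepseek": "DEEPSEEK_API_KEY",
}


def configured_providers(env=None):
    """Return the providers whose API key variable is set and non-empty."""
    env = os.environ if env is None else env
    return [name for name, var in PROVIDER_ENV_VARS.items() if env.get(var)]


found = configured_providers()
print("Configured providers:", ", ".join(found) if found else "none")
```

Passing a plain dict instead of os.environ makes the check easy to exercise in tests without touching real credentials.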

Quick Start

Here's a complete example using the Iris dataset:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from shap import TreeExplainer
from shapxplain import ShapLLMExplainer
from pydantic_ai import Agent

# Load data and train model
data = load_iris()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=42)
model.fit(X, y)

# Generate SHAP values
explainer = TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Create LLM agent and SHAPXplain explainer
llm_agent = Agent(model="openai:gpt-4o")  # Or your preferred LLM
llm_explainer = ShapLLMExplainer(
    model=model,
    llm_agent=llm_agent,
    feature_names=data.feature_names,
    significance_threshold=0.1
)

# Explain a single prediction
data_point = X[0]
prediction_probs = model.predict_proba(data_point.reshape(1, -1))[0]
predicted_class_idx = model.predict(data_point.reshape(1, -1))[0]
prediction_class = data.target_names[predicted_class_idx]

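# Get class-specific SHAP values. Older SHAP releases return a list with one
# (n_samples, n_features) array per class; newer releases return a single
# (n_samples, n_features, n_classes) array.
if isinstance(shap_values, list):
    class_shap_values = shap_values[predicted_class_idx][0]
else:
    class_shap_values = shap_values[0, :, predicted_class_idx]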

# Generate explanation
explanation = llm_explainer.explain(
    shap_values=class_shap_values,
    data_point=data_point,
    prediction=prediction_probs[predicted_class_idx],
    prediction_class=prediction_class
)

# Access different parts of the explanation
print("Summary:", explanation.summary)
print("\nDetailed Explanation:", explanation.detailed_explanation)
print("\nRecommendations:", explanation.recommendations)
print("\nConfidence Level:", explanation.confidence_level)
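
The significance_threshold argument used above is not described further in this README. One plausible reading is that features whose absolute SHAP value falls below the threshold are treated as minor. The sketch below illustrates that reading only; the function split_by_threshold is hypothetical, the numbers are made up for illustration, and SHAPXplain's actual rule may differ.

```python
def split_by_threshold(feature_names, shap_values, threshold=0.1):
    """Rank features by |SHAP value| and split them at the threshold."""
    ranked = sorted(
        zip(feature_names, shap_values),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    significant = [(n, v) for n, v in ranked if abs(v) >= threshold]
    minor = [(n, v) for n, v in ranked if abs(v) < threshold]
    return significant, minor


# Illustrative numbers only, not real SHAP output.
names = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)",
]
values = [0.02, -0.01, 0.31, 0.28]
significant, minor = split_by_threshold(names, values, threshold=0.1)
print("Significant:", significant)
print("Minor:", minor)
```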

Explanation Structure

The package provides structured explanations with the following components:

class SHAPExplanationResponse:
    summary: str  # Brief overview of key drivers
    detailed_explanation: str  # Comprehensive analysis
    recommendations: List[str]  # Actionable insights
    confidence_level: str  # high/medium/low
    feature_interactions: Dict[str, str]  # How features work together
    features: List[SHAPFeatureContribution]  # Detailed feature impacts
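
To show how these fields might be consumed, here is a small sketch that formats an explanation as a plain-text report. It uses stand-in dataclasses so it runs without the package: ExplanationSketch mirrors the fields listed above, while FeatureContribution and its field names are assumptions, since the real SHAPFeatureContribution definition is not shown here.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FeatureContribution:
    # Hypothetical stand-in for SHAPFeatureContribution; field names are assumed.
    feature_name: str
    shap_value: float


@dataclass
class ExplanationSketch:
    # Mirrors the fields listed above.
    summary: str
    detailed_explanation: str
    recommendations: List[str]
    confidence_level: str
    feature_interactions: Dict[str, str] = field(default_factory=dict)
    features: List[FeatureContribution] = field(default_factory=list)


def render_report(explanation):
    """Format an explanation-like object as a plain-text report."""
    lines = [
        f"Summary: {explanation.summary}",
        f"Confidence: {explanation.confidence_level}",
        "Recommendations:",
    ]
    lines += [f"  - {item}" for item in explanation.recommendations]
    for pair, note in explanation.feature_interactions.items():
        lines.append(f"Interaction ({pair}): {note}")
    return "\n".join(lines)
```

Because render_report only reads attributes, it should also accept the real response object, provided the attribute names match the listing above.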

Batch Processing

Synchronous Batch Processing

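# Assumes shap_values, X, and predictions (one model output per row of X)
# have already been computed.
batch_response = llm_explainer.explain_batch(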
    shap_values_batch=shap_values,
    data_points=X,
    predictions=predictions,
    batch_size=5,  # Optional: control batch size
    additional_context={
        "dataset": "Iris",
        "feature_descriptions": {...}
    }
)

# Access batch results
for response in batch_response.responses:
    print(response.summary)

# Get batch insights
print("Batch Insights:", batch_response.batch_insights)
print("Summary Statistics:", batch_response.summary_statistics)
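
How SHAPXplain applies batch_size internally is not stated here. As a rough illustration of the general idea, the sketch below splits a sequence into consecutive groups of at most batch_size items. The helper iter_batches is hypothetical and not part of the package.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Twelve rows with batch_size=5 give groups of 5, 5, and 2.
row_indices = list(range(12))
batches = list(iter_batches(row_indices, 5))
print([len(batch) for batch in batches])
```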

Asynchronous Batch Processing

With large batches, the asynchronous variant lets explanations be processed in parallel instead of one after another:

import asyncio

async def process_batch():
    batch_response = await llm_explainer.explain_batch_async(
        shap_values_batch=shap_values,
        data_points=X,
        predictions=predictions,
        additional_context={
            "dataset": "Iris",
            "feature_descriptions": {...}
        }
    )
    
    # The awaited call has completed; return the batch response
    return batch_response

# Run the async function
batch_results = asyncio.run(process_batch())
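
The pattern behind this kind of parallelism can be sketched with asyncio alone. The example below is not SHAPXplain's implementation: explain_one is a hypothetical stand-in for a single LLM-backed call, and max_concurrency is an assumed knob for limiting how many calls are in flight at once.

```python
import asyncio


async def explain_one(index):
    """Hypothetical stand-in for a single LLM-backed explanation call."""
    await asyncio.sleep(0.01)  # simulates network latency
    return f"explanation for row {index}"


async def explain_all(indices, max_concurrency=5):
    """Run explain_one for every index, at most max_concurrency at a time."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(index):
        async with semaphore:
            return await explain_one(index)

    # gather returns results in the same order as its inputs
    return await asyncio.gather(*(guarded(i) for i in indices))


results = asyncio.run(explain_all(range(10)))
print(len(results), "explanations generated")
```

Capping concurrency with a semaphore is one common way to stay within provider rate limits while still overlapping network waits.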

Advanced Usage

Data Contracts for Enhanced Explanations

The additional_context parameter lets you supply domain-specific context such as feature descriptions, reference ranges, and measurement units. This acts as a "data contract" that guides how the LLM interprets each feature:

explanation = llm_explainer.explain(
    shap_values=class_shap_values,
    data_point=data_point,
    prediction=prediction,  # model output for data_point, e.g. the predicted-class probability
    additional_context={
        "domain": "medical_diagnosis",
        "feature_descriptions": {
            "glucose": "Blood glucose level in mg/dL. Normal range: 70-99 mg/dL fasting",
            "blood_pressure": "Systolic blood pressure in mmHg. Normal range: <120 mmHg",
            "bmi": "Body Mass Index. Normal range: 18.5-24.9"
        },
        "reference_ranges": {
            "glucose": {"low": "<70", "normal": "70-99", "prediabetes": "100-125", "diabetes": ">=126"},
            "blood_pressure": {"normal": "<120", "elevated": "120-129", "stage1": "130-139", "stage2": ">=140"}
        },
        "measurement_units": {
            "glucose": "mg/dL",
            "blood_pressure": "mmHg",
            "bmi": "kg/m²"
        },
        "patient_context": "65-year-old male with family history of type 2 diabetes"
    }
)
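
A data contract is only useful if it covers the features the model actually uses. The sketch below checks a context dictionary for features that have no description. The helper missing_descriptions is hypothetical and not part of SHAPXplain; it only assumes the "feature_descriptions" key shown above.

```python
def missing_descriptions(feature_names, additional_context):
    """Return feature names that lack an entry in feature_descriptions."""
    described = additional_context.get("feature_descriptions", {})
    return [name for name in feature_names if name not in described]


context = {
    "domain": "medical_diagnosis",
    "feature_descriptions": {
        "glucose": "Blood glucose level in mg/dL",
        "blood_pressure": "Systolic blood pressure in mmHg",
        "bmi": "Body Mass Index",
    },
}

# "age" is an extra feature added here to show an uncovered case.
gaps = missing_descriptions(["glucose", "blood_pressure", "bmi", "age"], context)
print("Features without a description:", gaps)
```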

Error Handling

try:
    explanation = llm_explainer.explain(
        shap_values=class_shap_values,
        data_point=data_point,
        prediction=prediction
    )
except ValueError as e:
    print(f"Input validation error: {e}")
except RuntimeError as e:
    print(f"LLM query error: {e}")
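
The package already retries failed LLM calls with exponential backoff, as noted under Features. For callers who want an additional layer around their own code, or who simply want to see the pattern, here is a generic sketch. It is not SHAPXplain's implementation: call_with_backoff and its parameters are hypothetical, and it assumes, following the example above, that RuntimeError signals a transient LLM failure while ValueError signals bad input that should not be retried.

```python
import random
import time


def call_with_backoff(func, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call func(), retrying on RuntimeError with exponential backoff.

    ValueError is deliberately not caught: invalid input will not
    succeed on a retry.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return func()
        except RuntimeError:
            if attempt == max_attempts - 1:
                raise  # out of attempts, surface the last error
            # Delay doubles each attempt (1x, 2x, 4x the base delay),
            # plus up to 10% random jitter.
            delay = base_delay * (2 ** attempt)
            sleep(delay + random.uniform(0, 0.1 * delay))
```

A call such as call_with_backoff(lambda: llm_explainer.explain(...)) would wrap a single explanation request. The sleep parameter is injectable so the delay schedule can be tested without waiting.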

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Documentation

For detailed documentation, tutorials, and API reference, visit: https://shapxplain.readthedocs.io/

Development

To set up the development environment:

# Clone the repo
git clone https://github.com/mpearmain/shapxplain.git
cd shapxplain

# Install with development dependencies
poetry install --with dev

Run tests:

poetry run pytest

Format code:

poetry run black src tests
poetry run ruff check --fix src tests

License

This project is licensed under the MIT License - see the LICENSE file for details.
