
MEGAN: Multi Explanation Graph Attention Network

Project description


Architecture Overview

👩‍🏫 MEGAN: Multi Explanation Graph Attention Network

Abstract. Explainable artificial intelligence (XAI) methods are expected to improve trust during human-AI interactions, provide tools for model analysis and extend human understanding of complex problems. Attention-based models are an important subclass of XAI methods, partly due to their full differentiability and the potential to improve explanations by means of explanation-supervised training. We propose the novel multi-explanation graph attention network (MEGAN). Our graph regression and classification model features multiple explanation channels, which can be chosen independently of the task specifications. We first validate our model on a synthetic graph regression dataset, where our model produces single-channel explanations with quality similar to GNNExplainer. Furthermore, we demonstrate the advantages of multi-channel explanations on one synthetic and two real-world datasets: The prediction of water solubility of molecular graphs and sentiment classification of movie reviews. We find that our model produces explanations consistent with human intuition, opening the way to learning from our model in less well-understood tasks.


📦 Package Dependencies

  • The package is designed to run in an environment 3.8 <= python <= 3.13.

  • A graphics card with CUDA support (cuDNN) is recommended for model training.

  • A Linux operating system is recommended for development.

📦 Installation by Package

The package is also published as a library on PyPI and can be installed like this:

uv pip install graph_attention_student

📦 Installation from Source

Clone the repository from GitHub:

git clone https://github.com/aimat-lab/graph_attention_student

Then run an editable pip install from the main folder:

cd graph_attention_student
uv pip install -e .

🚀 Quickstart

The fastest way to train a MEGAN model is using the built-in experiment scripts. Prepare a CSV file with SMILES strings and target values, then run:

# Clone and install
git clone https://github.com/aimat-lab/graph_attention_student
cd graph_attention_student
uv pip install -e .

# Train a regression model
python graph_attention_student/experiments/train_model__megan.py \
    --CSV_FILE_PATH='"/path/to/your/data.csv"' \
    --VALUE_COLUMN_NAME='"smiles"' \
    --TARGET_COLUMN_NAMES='["target"]' \
    --DATASET_TYPE='"regression"' \
    --EPOCHS=150

Your CSV should have a smiles column and your target column(s):

smiles,target
CCO,1.23
CCN,2.45
CCC,0.89

Key parameters: CSV_FILE_PATH (path to the data file), TARGET_COLUMN_NAMES (prediction target column(s)), DATASET_TYPE ('regression' or 'classification'). See train_model__megan.py --help for all options.
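Before launching a long training run, it can help to sanity-check the input file. The following is a minimal sketch (not part of the package) that validates the example CSV from above, inlined here so the snippet is self-contained:

```python
import csv
import io

# The example CSV from above, inlined for illustration.
csv_text = """smiles,target
CCO,1.23
CCN,2.45
CCC,0.89
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))

# The experiment script expects a SMILES column and at least one target column.
assert all("smiles" in row and "target" in row for row in rows)

# Target values should all parse as numbers for a regression task.
targets = [float(row["target"]) for row in rows]

print(len(rows), min(targets), max(targets))  # → 3 0.89 2.45
```

For a real dataset, replace the inlined string with `open("/path/to/your/data.csv")` and adjust the column names to match your file.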

📄 Config Files

Instead of passing parameters on the command line, you can create a YAML config file:

# config.yml
extend: train_model__megan.py
parameters:
  CSV_FILE_PATH: /path/to/your/data.csv
  TARGET_COLUMN_NAMES:
    - target
  VALUE_COLUMN_NAME: smiles
  DATASET_TYPE: regression
  EPOCHS: 100
  BATCH_SIZE: 64
  LEARNING_RATE: 0.0001

Then run the experiment with:

pycomex run graph_attention_student/experiments/config.yml

💻 Command Line Interface

For quick predictions, use the megan CLI:

# Train from CSV
megan train dataset.csv

# Make predictions with explanations
# Optionally pass the path to a model checkpoint to use for the prediction.
megan predict "CCO"

Use megan --help for all options.

🤖 Python API

For custom workflows, use the Python API directly:

import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from visual_graph_datasets.processing.molecules import MoleculeProcessing
from graph_attention_student import Megan, SmilesDataset

# Setup
processing = MoleculeProcessing()
dataset = SmilesDataset(
    dataset="data.csv",
    smiles_column='smiles',
    target_columns=['target'],
    processing=processing,
)
loader = DataLoader(dataset, batch_size=64)

# Create model
model = Megan(
    node_dim=processing.get_num_node_attributes(),
    edge_dim=processing.get_num_edge_attributes(),
    units=[64, 64, 64],
    final_units=[64, 32, 1],
    prediction_mode='regression',
    importance_factor=1.0,
    importance_mode='regression',
)

# Train
trainer = pl.Trainer(max_epochs=150, accelerator='auto')
trainer.fit(model, train_dataloaders=loader)
model.eval()
model.save("model.ckpt")

Loading and Using Models:

from graph_attention_student import Megan
from graph_attention_student.torch.advanced import megan_prediction_report

model = Megan.load("model.ckpt")
model.eval()

# Make prediction
results = model.forward_graph(processing.process("CCO"))
print(f"Prediction: {results['graph_output'].item():.3f}")

# Generate explanation PDF
megan_prediction_report(
    value="CCO",
    model=model,
    processing=processing,
    output_path="report.pdf"
)

🔍 Examples

The following cherry-picked examples showcase the explanatory capabilities of the model.

RB-Motifs Dataset

This is a synthetic dataset consisting of randomly generated graphs with nodes of different colors. Some of the graphs contain special sub-graph motifs, which are either blue-heavy or red-heavy structures. The blue-heavy sub-graphs contribute a certain negative value to the overall value of the graph, while red-heavy structures contribute a certain positive value.

This way, every graph has a certain value associated with it, which is between -3 and 3. The network was trained to predict this value for each graph.
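The labeling rule described above can be illustrated with a small sketch. Note that the motif names and contribution values below are hypothetical, chosen only to illustrate the principle; the actual dataset uses its own motifs and contributions:

```python
# Hypothetical sketch of the RB-Motifs labeling rule: each red-heavy motif
# adds a positive contribution, each blue-heavy motif a negative one, and
# the graph label is the sum of contributions, bounded to [-3, 3].

MOTIF_VALUES = {           # assumed example contributions, not the real ones
    "red_triangle": +1.5,
    "red_star": +1.0,
    "blue_triangle": -1.5,
    "blue_star": -1.0,
}

def graph_value(motifs: list[str]) -> float:
    """Sum the contributions of all motifs present in a graph, clipped to [-3, 3]."""
    total = sum(MOTIF_VALUES[m] for m in motifs)
    return max(-3.0, min(3.0, total))

print(graph_value(["red_triangle", "blue_star"]))  # → 0.5
```

The regression target the network learns is exactly this kind of motif-determined value, which is why the ground-truth explanations are the motif sub-graphs themselves.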

Rb-Motifs Example

The example shows, from left to right: (1) the ground truth explanations, (2) a baseline MEGAN model trained only on the prediction task, (3) an explanation-supervised MEGAN model and (4) GNNExplainer explanations for a basic GCN network. While the baseline MEGAN and GNNExplainer focus only on one of the ground truth motifs, the explanation-supervised MEGAN model correctly finds both.

Water Solubility Dataset

This is the AqSolDB dataset, which consists of ~10,000 molecules with measured values for their solubility in water (logS).

The network was trained to predict the solubility value for each molecule.

Solubility Example

Movie Reviews

The MovieReviews dataset is originally a natural language processing dataset from the ERASER benchmark. The task is to classify the sentiment of ~2000 movie reviews collected from the IMDB database into the classes “positive” and “negative”. The dataset was converted into a graph dataset by treating every word as a node and connecting adjacent words by undirected edges within a sliding window of size 2. Words were converted into numeric feature vectors using a pre-trained GloVe model.
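The sliding-window edge construction can be sketched as follows. This is a simplified illustration, interpreting "window of size 2" as connecting each word to the next two words; the actual preprocessing additionally attaches GloVe feature vectors to each node:

```python
def text_to_edges(words: list[str], window: int = 2) -> list[tuple[int, int]]:
    """Connect each word to its neighbors within the sliding window by
    undirected edges, represented here as index pairs (i, j) with i < j."""
    edges = []
    for i in range(len(words)):
        for j in range(i + 1, min(i + window + 1, len(words))):
            edges.append((i, j))
    return edges

words = "a truly great movie".split()
print(text_to_edges(words))  # → [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
```

The resulting graph keeps the local word order as connectivity, so attention over edges can highlight contiguous phrases as explanations.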

Example for a positive review:

Positive Movie Review

Example for a negative review:

Negative Movie Review

Examples show the explanation channel for the “negative” class left and the “positive” class right. Sentences with negative / positive adjectives are appropriately attributed to the corresponding channels.

📖 Referencing

If you use, extend or otherwise mention our work, please cite the paper as follows:

@article{teufel2023megan,
    title={MEGAN: Multi-Explanation Graph Attention Network},
    author={Teufel, Jonas and Torresi, Luca and Reiser, Patrick and Friederich, Pascal},
    journal={xAI 2023},
    year={2023},
    doi={10.1007/978-3-031-44067-0_18},
    url={https://link.springer.com/chapter/10.1007/978-3-031-44067-0_18},
}

Credits

  • PyTorch Lightning provides the high-level training framework that powers the modern MEGAN implementation, offering easy GPU acceleration, distributed training, and experiment management.

  • PyTorch Geometric supplies the fundamental graph neural network building blocks and efficient graph data handling that enable MEGAN’s attention mechanisms and message passing operations.

  • VisualGraphDataset is a library that establishes a dedicated dataset format for graph XAI applications, streamlining the visualization of graph explanations and making them more comparable by packaging canonical graph visualizations directly with the dataset.

  • PyComex is a micro framework which simplifies the setup, processing and management of computational experiments. It is also used to auto-generate the command line interface that can be used to interact with these experiments.

