pyg-hyper-bench

Python 3.12+ PyTorch 2.9+ License: MIT

Comprehensive benchmarking framework for hypergraph learning with statistical evaluation and multi-run support.

Overview

pyg-hyper-bench provides a standardized framework for evaluating hypergraph neural networks with:

  • 📊 Statistical Evaluation: Multi-run evaluation with mean, std, and 95% confidence intervals
  • 🎯 Standardized Protocols: Node classification, link prediction, clustering, SSL evaluation
  • 🔄 Reproducibility: Seed management for consistent results
  • 🧪 Comprehensive Testing: 55+ tests with real datasets
  • 🚀 Easy to Use: Simple API for single-run and multi-run evaluation

Architecture

pyg-hyper-bench follows a clean separation of concerns:

pyg-hyper-bench/
├── protocols/             # Evaluation protocols
│   ├── base.py            # BenchmarkProtocol (abstract base)
│   ├── node_classification.py
│   ├── link_prediction.py
│   ├── clustering.py
│   └── ssl_linear_evaluation.py  # SSL linear evaluation
└── evaluators/            # Evaluation engines
    ├── single_run.py      # Single-run evaluator
    └── multi_run.py       # Multi-run with statistics

Design Principles:

  • pyg-hyper-data: Datasets + Split utilities (data layer)
  • pyg-hyper-bench: Evaluation protocols + Evaluators (evaluation layer)

Installation

Requirements

  • Python 3.12+
  • uv (recommended) or pip
  • PyTorch 2.9+ with CUDA 12.6 (optional, for GPU acceleration)
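
Install from PyPI

The released package is available on PyPI and can be installed directly:

# Install the latest published release
pip install pyg-hyper-bench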

Install from source

# Clone the repository
git clone https://github.com/nishide-dev/pyg-hyper-bench.git
cd pyg-hyper-bench

# Create virtual environment and install dependencies
uv venv
uv sync

# Activate the virtual environment
source .venv/bin/activate  # Linux/macOS
# or
.venv\Scripts\activate  # Windows

Quick Start

Single-Run Evaluation

import torch

from pyg_hyper_bench import SingleRunEvaluator, NodeClassificationProtocol
from pyg_hyper_data.datasets import CoraCocitation

# Load dataset
dataset = CoraCocitation()

# Create protocol
protocol = NodeClassificationProtocol(
    split_type="transductive",
    stratified=True,
    seed=42
)

# Create evaluator
evaluator = SingleRunEvaluator(dataset, protocol, device="cpu")

# Get data splits
split = evaluator.get_split()
train_mask = split["train_mask"]
val_mask = split["val_mask"]
test_mask = split["test_mask"]
data = split["data"]

# Train your model
model = YourHypergraphModel(...)
# ... training code ...

# Evaluate
model.eval()
with torch.no_grad():
    predictions = model(data.x, data.hyperedge_index)
    test_metrics = evaluator.evaluate(
        predictions[test_mask],
        data.y[test_mask],
        split="test"
    )

print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
print(f"Test F1-macro: {test_metrics['f1_macro']:.4f}")

Multi-Run Evaluation with Statistics

import torch
import torch.nn as nn

from pyg_hyper_bench import MultiRunEvaluator, NodeClassificationProtocol
from pyg_hyper_data.datasets import CoraCocitation

# Load dataset
dataset = CoraCocitation()

# Create protocol
protocol = NodeClassificationProtocol(
    split_type="transductive",
    stratified=True,
    seed=42
)

# Create multi-run evaluator
evaluator = MultiRunEvaluator(
    dataset=dataset,
    protocol=protocol,
    n_runs=10,  # Run 10 times with different seeds
    device="cpu",
    verbose=True
)

# Model factory (creates fresh model for each run)
def model_fn(seed):
    torch.manual_seed(seed)
    return YourHypergraphModel(
        in_channels=dataset.num_node_features,
        out_channels=dataset.num_classes
    )

# Training function
def train_fn(model, data, train_mask, val_mask):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(100):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.hyperedge_index)
        loss = criterion(out[train_mask], data.y[train_mask])
        loss.backward()
        optimizer.step()

    return model

# Run evaluation
results = evaluator.run_evaluation(
    model_fn=model_fn,
    train_fn=train_fn,
    splits=["val", "test"]
)

# Print results with statistics
print(results["test"].summary_table(split="test"))

Output:

## Test Results (n=10 runs)

| Metric | Mean ± Std | 95% CI |
|--------|------------|--------|
| accuracy | 0.8550 ± 0.0081 | [0.8305, 0.8796] |
| f1_macro | 0.8434 ± 0.0076 | [0.8204, 0.8665] |
| f1_micro | 0.8550 ± 0.0081 | [0.8305, 0.8796] |

Features

Evaluation Protocols

NodeClassificationProtocol

Standard protocol for node classification tasks:

protocol = NodeClassificationProtocol(
    train_ratio=0.6,      # 60% training
    val_ratio=0.2,        # 20% validation
    test_ratio=0.2,       # 20% testing
    split_type="transductive",  # or "inductive"
    stratified=True,      # Maintain class balance
    seed=42
)

Features:

  • Transductive or inductive splits
  • Stratified or random sampling
  • Metrics: accuracy, F1-macro, F1-micro
  • Based on standard benchmarking setups (HyperGCN, AllSet, ED-HNN)
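
For intuition, a stratified transductive split samples each class independently so that every split preserves the label distribution. A minimal sketch of the idea with NumPy (illustrative, not the protocol's internal code):

import numpy as np

def stratified_masks(y, train_ratio=0.6, val_ratio=0.2, seed=42):
    """Shuffle within each class so train/val/test keep the class balance."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    train = np.zeros(len(y), dtype=bool)
    val = np.zeros(len(y), dtype=bool)
    test = np.zeros(len(y), dtype=bool)
    for c in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == c))
        n_train = int(train_ratio * len(idx))
        n_val = int(val_ratio * len(idx))
        train[idx[:n_train]] = True
        val[idx[n_train:n_train + n_val]] = True
        test[idx[n_train + n_val:]] = True
    return train, val, test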

LinkPredictionProtocol

Hyperedge link prediction protocol:

protocol = LinkPredictionProtocol(
    train_ratio=0.7,      # 70% training
    val_ratio=0.1,        # 10% validation
    test_ratio=0.2,       # 20% testing
    negative_sampling_ratio=1.0,  # 1:1 positive:negative ratio
    seed=42
)

# Split data
split = protocol.split_data(data)
train_data = split["train_data"]  # Graph for training
val_pos = split["val_pos_edge"]   # Positive validation samples
val_neg = split["val_neg_edge"]   # Negative validation samples

# Train model and get scores
pos_scores = model.predict(val_pos)
neg_scores = model.predict(val_neg)

# Evaluate
metrics = protocol.evaluate(pos_scores, neg_scores)
print(f"AUC: {metrics['auc']:.4f}")
print(f"MRR: {metrics['mrr']:.4f}")
print(f"Hits@10: {metrics['hits@10']:.4f}")

Features:

  • Binary classification: real hyperedge vs random node set
  • Metrics: AUC, AP, MRR, Hits@10/50/100
  • Configurable negative sampling ratio
  • Based on HyperGCN link prediction task
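
For intuition, the ranking metrics above can be computed from positive and negative scores along these lines (an illustrative sketch, not the protocol's implementation; pos_scores and neg_scores are assumed to be 1-D tensors):

import torch
from sklearn.metrics import average_precision_score, roc_auc_score

def ranking_metrics(pos_scores, neg_scores, k=10):
    """AUC/AP from binary labels; MRR and Hits@k by ranking each positive against all negatives."""
    scores = torch.cat([pos_scores, neg_scores]).cpu().numpy()
    labels = [1] * len(pos_scores) + [0] * len(neg_scores)
    # Rank of each positive among all negatives (1 = best)
    ranks = (neg_scores.view(1, -1) >= pos_scores.view(-1, 1)).sum(dim=1) + 1
    return {
        "auc": roc_auc_score(labels, scores),
        "ap": average_precision_score(labels, scores),
        "mrr": (1.0 / ranks.float()).mean().item(),
        f"hits@{k}": (ranks <= k).float().mean().item(),
    }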

ClusteringProtocol

Unsupervised clustering evaluation:

protocol = ClusteringProtocol(
    seed=42,
    n_clusters=7  # or None for auto-detection from labels
)

# Train model to learn embeddings (unsupervised)
embeddings = model.encode(data)  # [num_nodes, embedding_dim]

# Evaluate clustering quality
metrics = protocol.evaluate(embeddings, data.y)
print(f"NMI: {metrics['nmi']:.4f}")
print(f"ARI: {metrics['ari']:.4f}")
print(f"AMI: {metrics['ami']:.4f}")

Features:

  • Unsupervised evaluation (labels only for evaluation, not training)
  • K-Means clustering on learned embeddings
  • Metrics: NMI (Normalized Mutual Information), ARI (Adjusted Rand Index), AMI (Adjusted Mutual Information)
  • Auto-detection of number of clusters from labels
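
Conceptually, the protocol runs K-Means on the learned embeddings and compares the resulting assignments to the ground-truth labels. A minimal sketch of that pipeline with scikit-learn (illustrative, not the protocol's internals; assumes CPU tensors or arrays):

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import (
    adjusted_mutual_info_score,
    adjusted_rand_score,
    normalized_mutual_info_score,
)

def clustering_scores(embeddings, labels, n_clusters=None, seed=42):
    """Cluster embeddings with K-Means and score against ground-truth labels."""
    emb = np.asarray(embeddings)
    y = np.asarray(labels)
    if n_clusters is None:
        n_clusters = len(np.unique(y))  # auto-detect from labels
    pred = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(emb)
    return {
        "nmi": normalized_mutual_info_score(y, pred),
        "ari": adjusted_rand_score(y, pred),
        "ami": adjusted_mutual_info_score(y, pred),
    }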

SSLLinearEvaluationProtocol

Linear evaluation protocol for self-supervised learning (SSL) methods:

from pyg_hyper_bench import SSLLinearEvaluationProtocol

# Create protocol for node classification task
protocol = SSLLinearEvaluationProtocol(
    task="node_classification",  # or "hyperedge_prediction"
    classifier_type="logistic_regression",  # or "mlp"
    classifier_epochs=200,
    seed=42
)

# Split data (SSL pre-training does NOT use labels)
split = protocol.split_data(data)

# Get frozen embeddings from SSL model (trained separately)
model.eval()
with torch.no_grad():
    embeddings = model.get_embeddings(data)

# Linear evaluation (train linear classifier on frozen embeddings)
metrics = protocol.evaluate(
    embeddings=embeddings,
    labels=data.y,
    train_mask=split["train_mask"],
    val_mask=split["val_mask"],
    test_mask=split["test_mask"],
)

print(f"Test Accuracy: {metrics['test_accuracy']:.4f}")

Features:

  • Two evaluation tasks:
    • task="node_classification": Multi-class node classification
    • task="hyperedge_prediction": Binary node-hyperedge membership prediction
  • Selectable classifiers: Logistic Regression (sklearn) or MLP (PyTorch)
  • Frozen embeddings: Evaluates representation quality without fine-tuning
  • Metrics:
    • Node classification: accuracy, F1-macro, F1-micro
    • Hyperedge prediction: AUC, AP (Average Precision)
  • Based on TriCL (AAAI'23) and HypeBoy (KDD'23) evaluation protocols

Important: SSL pre-training is done separately (e.g., in pyg-hyper-ssl). This protocol only evaluates the learned representations.
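
The core of linear evaluation is fitting a simple classifier on the frozen embeddings. A minimal sketch of the node-classification case with scikit-learn (illustrative; the protocol's exact internals may differ):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def linear_probe(embeddings, labels, train_mask, test_mask, seed=42):
    """Fit a logistic-regression probe on frozen embeddings and score the test split."""
    X = np.asarray(embeddings)
    y = np.asarray(labels)
    train = np.asarray(train_mask)
    test = np.asarray(test_mask)
    clf = LogisticRegression(max_iter=200, random_state=seed)
    clf.fit(X[train], y[train])
    return accuracy_score(y[test], clf.predict(X[test]))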

Evaluators

SingleRunEvaluator

For single-run evaluation:

evaluator = SingleRunEvaluator(dataset, protocol, device="cpu")
split = evaluator.get_split()
metrics = evaluator.evaluate(predictions, targets, split="test")

MultiRunEvaluator

For multi-run evaluation with statistical aggregation:

evaluator = MultiRunEvaluator(
    dataset=dataset,
    protocol=protocol,
    n_runs=10,
    seeds=[42, 43, 44, ...],  # Optional custom seeds
    device="cpu",
    verbose=True
)

results = evaluator.run_evaluation(model_fn, train_fn, splits=["val", "test"])

Statistical Measures:

  • Mean across all runs
  • Standard deviation
  • 95% confidence interval (using t-distribution)
  • Raw results from each run
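
For reference, the confidence interval is the standard t-based interval over the per-run scores. A minimal sketch of the aggregation with SciPy (illustrative; the library's exact implementation may differ):

import numpy as np
from scipy import stats

def summarize_runs(values, confidence=0.95):
    """Mean, sample std, and t-based confidence interval across runs."""
    values = np.asarray(values, dtype=float)
    mean = values.mean()
    std = values.std(ddof=1)            # sample standard deviation
    sem = std / np.sqrt(len(values))    # standard error of the mean
    low, high = stats.t.interval(confidence, len(values) - 1, loc=mean, scale=sem)
    return mean, std, (low, high)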

Reproducibility

All evaluators support seed management for reproducibility:

# Same seed = same results
protocol1 = NodeClassificationProtocol(seed=42)
protocol2 = NodeClassificationProtocol(seed=42)
split1 = protocol1.split_data(data)
split2 = protocol2.split_data(data)
# split1 == split2 ✓

# Multi-run with custom seeds
evaluator = MultiRunEvaluator(
    dataset=dataset,
    protocol=protocol,
    n_runs=5,
    seeds=[0, 10, 20, 30, 40]  # Custom seeds
)

Verified Performance

Integration tests with real datasets and models:

Dataset: CoraCocitation (1,434 nodes, 1,579 hyperedges, 7 classes)

Model: Simple Hypergraph GNN (64 hidden dimensions, 10 epochs)

Results:

  • Single run: 83.22% test accuracy
  • Multi-run (n=3): 85.50% ± 0.81% test accuracy (95% CI: [83.05%, 87.96%])

See tests/test_integration.py for full integration tests.

Development

Pre-commit hooks

This project uses pre-commit hooks to ensure code quality:

# Install pre-commit hooks
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg

# Run hooks manually on all files
uv run pre-commit run --all-files

# Run hooks on staged files (happens automatically on git commit)
git commit -m "Your message"

Hooks:

  • ruff check --fix: Auto-fix linting issues
  • ruff format: Format code
  • ty check: Type checking (runs on entire project)

Running tests

# Run all tests
uv run pytest

# Run with verbose output
uv run pytest -v

# Run specific test file
uv run pytest tests/test_integration.py

# Run with coverage
uv run pytest --cov=src --cov-report=html

Test Coverage:

  • Unit tests: 10 tests for MultiRunEvaluator, 12 tests for LinkPrediction, 12 tests for Clustering, 13 tests for SSL
  • Integration tests: 6 tests with real datasets and models
  • Total: 55 tests (53 passed, 2 skipped)

Code quality

# Format code
uv run ruff format .

# Lint code
uv run ruff check .

# Type checking
uv run ty check

Adding dependencies

# Add runtime dependency
uv add <package-name>

# Add development dependency
uv add --dev <package-name>

# Update all dependencies
uv lock --upgrade

Project Structure

pyg-hyper-bench/
├── src/pyg_hyper_bench/
│   ├── __init__.py
│   ├── protocols/              # Evaluation protocols
│   │   ├── __init__.py
│   │   ├── base.py            # BenchmarkProtocol (abstract)
│   │   ├── node_classification.py
│   │   ├── link_prediction.py
│   │   ├── clustering.py
│   │   └── ssl_linear_evaluation.py  # SSL linear evaluation
│   └── evaluators/             # Evaluation engines
│       ├── __init__.py
│       ├── single_run.py      # Single-run evaluator
│       └── multi_run.py       # Multi-run evaluator
├── tests/
│   ├── test_multi_run_evaluator.py  # Unit tests
│   ├── test_link_prediction.py      # Link prediction tests
│   ├── test_clustering.py           # Clustering tests
│   ├── test_ssl_linear_evaluation.py # SSL evaluation tests
│   └── test_integration.py          # Integration tests
├── docs/
│   └── DESIGN.md              # Detailed design document
├── pyproject.toml             # Project configuration
└── README.md                  # This file

Dependencies

Core:

  • pyg-hyper-data: Dataset and data utilities
  • PyTorch: Deep learning framework
  • torch-scatter: Scatter operations for hypergraph aggregation
  • NumPy: Numerical computing
  • SciPy: Statistical functions
  • pandas: Data manipulation
  • scikit-learn: Machine learning utilities (classifiers, metrics)
  • tqdm: Progress bars

Development:

  • pytest: Testing framework
  • ruff: Linter and formatter
  • ty: Type checker

Related Projects

  • pyg-hyper-data: Datasets and split utilities (the data layer this package builds on)
  • pyg-hyper-ssl: Self-supervised pre-training for hypergraphs (evaluated here via SSLLinearEvaluationProtocol)

Citation

If you use this package in your research, please cite:

@software{pyg_hyper_bench,
  title = {pyg-hyper-bench: Benchmarking Framework for Hypergraph Learning},
  author = {Nishide},
  year = {2025},
  url = {https://github.com/nishide-dev/pyg-hyper-bench}
}

License

MIT License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Acknowledgments

This project follows best practices from:

  • TriCL (AAAI'23): Multi-seed evaluation with mean ± std reporting, LogisticRegression for linear evaluation
  • HyperGCL: Logger-based multi-run tracking
  • HypeBoy (KDD'23): 20-split evaluation with statistical aggregation, MLP for linear evaluation, hyperedge prediction task

Generated with ❤️ for reproducible hypergraph learning research
