
pyg-hyper-ssl

Self-supervised learning methods for hypergraphs built on PyTorch Geometric.

Tests · Python 3.12+ · PyTorch 2.0+ · Code style: ruff

Overview

pyg-hyper-ssl provides state-of-the-art self-supervised learning (SSL) methods for hypergraphs. Built on top of pyg-hyper-nn and pyg-hyper-data, this library implements cutting-edge SSL algorithms from recent research papers.

Key Features

  • 🎯 State-of-the-art SSL Methods: TriCL (AAAI'23), HyperGCL (NeurIPS'22), HypeBoy (ICLR'24), and more coming soon
  • 🧩 Modular Design: Extensible base classes for methods, augmentations, and losses
  • 🔄 Rich Augmentations: Structural (edge drop) and attribute (feature mask) augmentations
  • 🚀 Production Ready: Comprehensive tests (83% coverage), type hints, and documentation
  • 🔗 Seamless Integration: Works with all 19 models from pyg-hyper-nn
  • ⚡ Optimized: Built on PyTorch Geometric for efficient graph operations

Installation

Prerequisites

This package requires PyTorch Geometric to be installed. Install it first:

pip install torch torch-geometric

For GPU support with CUDA 12.6:

pip install torch --index-url https://download.pytorch.org/whl/cu126
pip install torch-geometric

From PyPI (Recommended)

pip install pyg-hyper-ssl

From Source

git clone https://github.com/nishide-dev/pyg-hyper-ssl.git
cd pyg-hyper-ssl
uv sync  # or pip install -e .

Quick Start

TriCL: Tri-directional Contrastive Learning

import torch
from pyg_hyper_data.datasets import CoraCocitation
from pyg_hyper_ssl.methods.contrastive import TriCL, TriCLEncoder
from pyg_hyper_ssl.augmentations import EdgeDrop, FeatureMask

# Load dataset
dataset = CoraCocitation()
data = dataset[0]

# Create TriCL model
encoder = TriCLEncoder(
    in_dim=data.num_node_features,
    edge_dim=128,
    node_dim=256,
    num_layers=2
)
model = TriCL(
    encoder=encoder,
    proj_dim=256,
    node_tau=0.5,
    edge_tau=0.5,
    membership_tau=0.1
)

# Create augmentations
aug1 = EdgeDrop(drop_prob=0.2)
aug2 = FeatureMask(mask_prob=0.3)

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(100):
    # Apply augmentations
    data_aug1 = aug1(data)
    data_aug2 = aug2(data)

    # Compute loss
    loss = model.train_step(data_aug1, data_aug2)

    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}")

# Get embeddings for downstream tasks
model.eval()
embeddings = model.get_embeddings(data)
print(f"Embeddings shape: {embeddings.shape}")
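Once frozen, the embeddings can feed any downstream classifier. As a dependency-free illustration (the embeddings, labels, and the `nearest_centroid_predict` helper below are hypothetical toy data, not part of the library), a nearest-class-centroid probe looks like this:

```python
# Illustrative nearest-centroid probe for frozen SSL embeddings.
# All names and data below are a hypothetical toy example, not library output.

def nearest_centroid_predict(train_emb, train_labels, test_emb):
    """Assign each test embedding to the class with the closest mean embedding."""
    # Compute one centroid (mean vector) per class
    groups = {}
    for emb, y in zip(train_emb, train_labels):
        groups.setdefault(y, []).append(emb)
    centroids = {
        y: [sum(col) / len(rows) for col in zip(*rows)]
        for y, rows in groups.items()
    }
    # Predict by squared Euclidean distance to each centroid
    preds = []
    for emb in test_emb:
        preds.append(min(
            centroids,
            key=lambda y: sum((a - b) ** 2 for a, b in zip(emb, centroids[y])),
        ))
    return preds

train_emb = [[0.0, 0.1], [0.1, 0.0], [1.0, 0.9], [0.9, 1.0]]
train_labels = [0, 0, 1, 1]
test_emb = [[0.05, 0.05], [0.95, 0.95]]
preds = nearest_centroid_predict(train_emb, train_labels, test_emb)
print(preds)  # → [0, 1]
```

In practice you would replace the toy arrays with `model.get_embeddings(data)` converted to your classifier's input format (e.g. scikit-learn's `LogisticRegression` for a standard linear probe).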

Using with pyg-hyper-nn Models

from pyg_hyper_nn.models import HGNN, UniGNN, HyperGCN
from pyg_hyper_ssl.encoders import EncoderWrapper
from pyg_hyper_ssl.methods.base import BaseSSLMethod

# Wrap any pyg-hyper-nn model for SSL
encoder = EncoderWrapper(
    model_class=HGNN,  # or UniGNN, HyperGCN, etc.
    in_channels=128,
    hidden_channels=256,
    num_layers=2,
    use_projection=True  # Add projection head for contrastive learning
)

# Use in your custom SSL method
class MySSLMethod(BaseSSLMethod):
    def forward(self, data, data_aug):
        h1, z1 = encoder.forward_from_data(data)
        h2, z2 = encoder.forward_from_data(data_aug)
        return (h1, z1), (h2, z2)

Implementation Accuracy

All implementations are verified against official reference implementations.

We have carefully compared our implementations with the original papers' reference code to ensure accuracy. See IMPLEMENTATION_ACCURACY.md for detailed verification results.

Key verification points:

  • ✅ FeatureMask: Dimension-wise masking (matches reference exactly)
  • ✅ EdgeDrop: Sparse matrix approach (matches reference, with improvements)
  • ✅ TriCL: All three loss levels verified (31 tests)
  • ✅ HyperGCL: InfoNCE loss verified (20 tests)
  • ✅ HypeBoy: Two-stage generative SSL verified (20 tests)
  • ✅ SE-HSSL: Fairness-aware components verified (29 tests)

Total: 119 tests, including 34 accuracy-verification tests

Implemented Methods

TriCL (AAAI 2023)

Tri-directional Contrastive Learning for Hypergraphs

TriCL performs contrastive learning at three levels:

  1. Node-level: Contrast node embeddings across augmented views
  2. Group-level: Contrast hyperedge embeddings across views
  3. Membership-level: Contrast node-hyperedge relationships

from pyg_hyper_ssl.methods.contrastive import TriCL, TriCLEncoder

encoder = TriCLEncoder(in_dim=128, edge_dim=256, node_dim=512, num_layers=3)
model = TriCL(
    encoder=encoder,
    proj_dim=512,
    lambda_n=1.0,    # Node-level loss weight
    lambda_e=1.0,    # Group-level loss weight
    lambda_m=1.0,    # Membership-level loss weight
)

Reference: Huang et al. "Contrastive Learning Meets Homophily: Two Birds with One Stone" AAAI 2023.

HyperGCL (NeurIPS 2022)

Contrastive Learning for Hypergraphs with Fabricated Augmentations

HyperGCL performs node-level contrastive learning using InfoNCE loss. It works with any hypergraph encoder and uses fabricated augmentations (edge drop, feature mask, etc.).

from pyg_hyper_nn.models import HGNN
from pyg_hyper_ssl.encoders import EncoderWrapper
from pyg_hyper_ssl.methods.contrastive import HyperGCL

# Wrap any pyg-hyper-nn model
encoder = EncoderWrapper(
    model_class=HGNN,
    in_channels=128,
    hidden_channels=256,
    num_layers=2
)

model = HyperGCL(
    encoder=encoder,
    proj_hidden=256,
    proj_out=256,
    tau=0.5,  # Temperature for InfoNCE loss
)

Reference: Wei et al. "Augmentations in Hypergraph Contrastive Learning: Fabricated and Generative" NeurIPS 2022.

HypeBoy (ICLR 2024)

Generative Self-Supervised Learning on Hypergraphs

HypeBoy performs two-stage generative SSL:

  1. Feature Reconstruction: Mask and reconstruct node features using a decoder
  2. Hyperedge Filling: Predict missing nodes in hyperedges via contrastive loss

from pyg_hyper_nn.models import HGNN
from pyg_hyper_ssl.encoders import EncoderWrapper
from pyg_hyper_ssl.methods.generative import HypeBoy, HypeBoyDecoder

# Encoder for both stages
encoder = EncoderWrapper(
    model_class=HGNN,
    in_channels=128,
    hidden_channels=64,
    num_layers=2
)

# Decoder for feature reconstruction
decoder = HypeBoyDecoder(
    encoder=encoder,
    in_dim=64,
    out_dim=128,
    hidden_dim=64,
    num_layers=2
)

model = HypeBoy(
    encoder=encoder,
    decoder=decoder,
    feature_recon_epochs=300,      # Stage 1 epochs
    hyperedge_fill_epochs=200,     # Stage 2 epochs
    feature_mask_prob=0.5,         # Feature masking probability
    edge_drop_prob_stage2=0.9      # Edge dropping probability (stage 2)
)

Reference: Kim et al. "HypeBoy: Generative Self-Supervised Representation Learning on Hypergraphs" ICLR 2024.

Augmentations

Structural Augmentations

from pyg_hyper_ssl.augmentations import EdgeDrop

# Randomly drop hyperedges
aug = EdgeDrop(drop_prob=0.2)  # Drop 20% of hyperedges
data_aug = aug(data)
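Under the hood, edge dropping amounts to sampling a keep-mask over hyperedge IDs and filtering the incidence pairs. A minimal dependency-free sketch of that idea (the parallel node-ID/edge-ID list format mirrors the PyG `hyperedge_index` convention, but the `edge_drop` function itself is illustrative, not the library implementation):

```python
import random

def edge_drop(hyperedge_index, num_edges, drop_prob=0.2, seed=None):
    """Drop each hyperedge independently with probability drop_prob (illustrative)."""
    rng = random.Random(seed)
    # Sample which hyperedge IDs survive
    keep = {e for e in range(num_edges) if rng.random() >= drop_prob}
    node_ids, edge_ids = hyperedge_index
    # Filter out incidence pairs whose hyperedge was dropped
    kept_pairs = [(n, e) for n, e in zip(node_ids, edge_ids) if e in keep]
    return [list(t) for t in zip(*kept_pairs)] if kept_pairs else [[], []]

# Toy incidence structure: edge 0 = {0, 1}, edge 1 = {1, 2}
hyperedge_index = [[0, 1, 1, 2], [0, 0, 1, 1]]
print(edge_drop(hyperedge_index, num_edges=2, drop_prob=0.5, seed=0))
```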

Attribute Augmentations

from pyg_hyper_ssl.augmentations import FeatureMask

# Randomly mask node features
aug = FeatureMask(mask_prob=0.3)  # Mask 30% of features
data_aug = aug(data)
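As noted in the accuracy-verification section, masking is dimension-wise: a sampled subset of feature columns is zeroed for every node, rather than masking entries independently. A dependency-free sketch of that behavior (illustrative, not the library code):

```python
import random

def feature_mask(x, mask_prob=0.3, seed=None):
    """Zero each feature dimension (column) for ALL nodes with prob mask_prob (illustrative)."""
    rng = random.Random(seed)
    num_feats = len(x[0])
    # One Bernoulli draw per feature dimension, shared by every node
    keep = [rng.random() >= mask_prob for _ in range(num_feats)]
    return [[v if keep[j] else 0.0 for j, v in enumerate(row)] for row in x]

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
masked = feature_mask(x, mask_prob=0.5, seed=1)
# Each column is either zeroed in every row or left untouched in every row
print(masked)
```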

Composition

from pyg_hyper_ssl.augmentations import ComposedAugmentation, RandomChoice

# Sequential composition
aug = ComposedAugmentation([
    EdgeDrop(drop_prob=0.2),
    FeatureMask(mask_prob=0.3)
])

# Random choice
aug = RandomChoice([
    EdgeDrop(drop_prob=0.2),
    FeatureMask(mask_prob=0.3)
])
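The composition semantics are simple: sequential composition applies every augmentation in order, while random choice samples exactly one per call. An illustrative, dependency-free sketch of those semantics (the `Composed` and `Choice` classes below are stand-ins, not the library classes):

```python
import random

class Composed:
    """Apply every augmentation in sequence (illustrative stand-in)."""
    def __init__(self, augs):
        self.augs = augs

    def __call__(self, data):
        for aug in self.augs:
            data = aug(data)
        return data

class Choice:
    """Apply exactly one randomly chosen augmentation per call (illustrative stand-in)."""
    def __init__(self, augs, seed=None):
        self.augs = augs
        self.rng = random.Random(seed)

    def __call__(self, data):
        return self.rng.choice(self.augs)(data)

double = lambda xs: [2 * v for v in xs]
negate = lambda xs: [-v for v in xs]
print(Composed([double, negate])([1, 2]))  # → [-2, -4]
```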

Loss Functions

Contrastive Losses

from pyg_hyper_ssl.losses import InfoNCE, NTXent, CosineSimilarityLoss

# InfoNCE loss (SimCLR-style)
loss_fn = InfoNCE(temperature=0.5)

# NT-Xent (alias for InfoNCE)
loss_fn = NTXent(temperature=0.5)

# Simple cosine similarity
loss_fn = CosineSimilarityLoss()
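For intuition: InfoNCE treats `z1[i]`/`z2[i]` as a positive pair and every other cross-view pair as a negative, scoring pairs by temperature-scaled cosine similarity. A dependency-free sketch of the computation (illustrative; the library's `InfoNCE` operates on tensors):

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def info_nce(z1, z2, temperature=0.5):
    """Average -log softmax probability assigned to each positive pair."""
    n = len(z1)
    total = 0.0
    for i in range(n):
        # Scores of z1[i] against every candidate embedding in the other view
        scores = [math.exp(cosine(z1[i], z2[j]) / temperature) for j in range(n)]
        total += -math.log(scores[i] / sum(scores))
    return total / n

aligned = [[1.0, 0.0], [0.0, 1.0]]
shuffled = [[0.0, 1.0], [1.0, 0.0]]
# Loss is lower when the two views agree pair-by-pair
print(info_nce(aligned, aligned), info_nce(aligned, shuffled))
```

Lowering the temperature sharpens the softmax, penalizing hard negatives more aggressively.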

Fairness-Aware Losses

from pyg_hyper_ssl.losses import CCALoss, orthogonal_projection, balance_hyperedges

# CCA Loss for fairness-aware SSL (SE-HSSL)
cca_loss = CCALoss(lambda_decorr=0.005)
loss = cca_loss(z1, z2)  # Maximize correlation between views

# Orthogonal projection for debiasing
debias_x = orthogonal_projection(x, sens_idx=0)  # Remove bias from sensitive attribute

# Balance hyperedge group representation
balanced_edge_index = balance_hyperedges(
    hyperedge_index,
    node_groups=[0, 0, 1, 1, 0],  # Binary group labels
    beta=1.0  # Balance strength
)
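The CCA objective combines an invariance term (standardized views should match) with a decorrelation term (each view's feature covariance should approach the identity), weighted by `lambda_decorr`. A dependency-free sketch of a CCA-SSG-style loss (illustrative; the library's `CCALoss` operates on tensors and may differ in detail):

```python
import math

def standardize(z):
    """Zero-mean, unit-variance per feature column."""
    n = len(z)
    out = []
    for col in zip(*z):
        mu = sum(col) / n
        sd = math.sqrt(sum((v - mu) ** 2 for v in col) / n) or 1.0
        out.append([(v - mu) / sd for v in col])
    return [list(row) for row in zip(*out)]

def cca_loss(z1, z2, lambda_decorr=0.005):
    """Invariance between views plus decorrelation within each view (illustrative)."""
    z1, z2 = standardize(z1), standardize(z2)
    n, d = len(z1), len(z1[0])
    # Invariance: the two standardized views should produce matching embeddings
    inv = sum((a - b) ** 2 for r1, r2 in zip(z1, z2) for a, b in zip(r1, r2)) / n
    # Decorrelation: each view's feature covariance should be near the identity
    dec = 0.0
    for z in (z1, z2):
        cols = list(zip(*z))
        for i in range(d):
            for j in range(d):
                cov = sum(cols[i][k] * cols[j][k] for k in range(n)) / n
                dec += (cov - (1.0 if i == j else 0.0)) ** 2
    return inv + lambda_decorr * dec

z = [[1.0, 2.0], [3.0, 5.0], [5.0, 7.0]]
# Identical views: invariance term vanishes, only decorrelation remains
print(cca_loss(z, z), cca_loss(z, [z[2], z[0], z[1]]))
```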

Composite Losses

from pyg_hyper_ssl.losses import CompositeLoss, InfoNCE, CosineSimilarityLoss

# Combine multiple losses with weights
composite = CompositeLoss([
    (InfoNCE(temperature=0.5), 1.0),     # Weight 1.0
    (CosineSimilarityLoss(), 0.5),       # Weight 0.5
])

Extending pyg-hyper-ssl

Custom SSL Method

from pyg_hyper_ssl.methods.base import BaseSSLMethod
import torch

class MySSLMethod(BaseSSLMethod):
    def forward(self, data, data_aug):
        # Your encoding logic
        z1 = self.encoder(data.x, data.hyperedge_index)
        z2 = self.encoder(data_aug.x, data_aug.hyperedge_index)
        return z1, z2

    def compute_loss(self, z1, z2, **kwargs):
        # Your loss computation
        return torch.nn.functional.mse_loss(z1, z2)

Custom Augmentation

from pyg_hyper_ssl.augmentations.base import BaseAugmentation

class MyAugmentation(BaseAugmentation):
    def __init__(self, param=0.5):
        super().__init__(param=param)
        self.param = param

    def __call__(self, data):
        # Your augmentation logic
        augmented_data = data.clone()
        # Modify augmented_data...
        return augmented_data

Custom Loss Function

from pyg_hyper_ssl.losses.base import BaseLoss

class MyLoss(BaseLoss):
    def forward(self, z1, z2, **kwargs):
        # Your loss computation
        return (z1 - z2).pow(2).mean()

Architecture

pyg-hyper-ssl/
├── methods/
│   ├── base.py                    # BaseSSLMethod
│   └── contrastive/
│       ├── tricl.py               # TriCL implementation
│       ├── tricl_encoder.py       # TriCL encoder
│       └── tricl_layer.py         # TriCL convolution layer
├── augmentations/
│   ├── base.py                    # Base augmentation classes
│   ├── structural/
│   │   └── edge_drop.py           # Edge dropping
│   └── attribute/
│       └── feature_mask.py        # Feature masking
├── losses/
│   ├── base.py                    # Base loss classes
│   └── contrastive.py             # InfoNCE, NT-Xent
└── encoders/
    └── wrapper.py                 # Encoder wrapper for pyg-hyper-nn

Development

Setup

# Clone and install
git clone https://github.com/nishide-dev/pyg-hyper-ssl.git
cd pyg-hyper-ssl
uv sync

# Install pre-commit hooks
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg

# Run tests
uv run pytest tests/ -v

# Run with coverage
uv run pytest tests/ --cov=src/pyg_hyper_ssl --cov-report=term-missing

Pre-commit hooks

This project uses pre-commit hooks to ensure code quality:

# Run hooks manually on all files
uv run pre-commit run --all-files

# Run hooks on staged files (happens automatically on git commit)
git commit -m "Your message"

Hooks:

  • ruff lint --fix: Auto-fix linting issues
  • ruff format: Format code
  • ty check: Type checking (runs on entire project)

Testing

# All tests
uv run pytest

# Specific test file
uv run pytest tests/test_tricl.py -v

# Test with output
uv run pytest tests/test_tricl.py -v -s

Code Quality

# Format code
uv run ruff format .

# Lint code
uv run ruff check .

# Fix auto-fixable issues
uv run ruff check --fix .

# Type checking
uv run ty check

Dependencies

  • Runtime:

    • torch >= 2.0.0
    • torch-geometric >= 2.4.0
    • torch-scatter >= 2.1.0
    • pyg-hyper-nn >= 0.1.0
    • pyg-hyper-data >= 0.1.0
    • hydra-core >= 1.3.0
    • scikit-learn >= 1.0.0
  • Development:

    • pytest >= 8.0
    • pytest-cov >= 4.1
    • ruff >= 0.6
    • ty (type checker)

Citation

If you use this library in your research, please cite:

@software{pyg_hyper_ssl,
  title = {pyg-hyper-ssl: Self-supervised Learning for Hypergraphs},
  author = {nishide-dev},
  year = {2026},
  url = {https://github.com/nishide-dev/pyg-hyper-ssl}
}

And cite the original papers for the methods you use:

@inproceedings{huang2023tricl,
  title={Contrastive Learning Meets Homophily: Two Birds with One Stone},
  author={Huang, Xiaojun and others},
  booktitle={AAAI},
  year={2023}
}

Roadmap

  • [x] TriCL (AAAI'23)
  • [x] HyperGCL (NeurIPS'22)
  • [x] HypeBoy (ICLR'24)
  • [ ] Additional structural augmentations (NodeDrop, EdgePerturb, Subgraph)
  • [ ] Additional attribute augmentations (FeatureNoise, FeatureShuffle)
  • [ ] Pre-trained model zoo
  • [ ] Comprehensive benchmarks

Related Projects

  • pyg-hyper-nn: Hypergraph neural network models (the encoders this library wraps)
  • pyg-hyper-data: Hypergraph datasets such as CoraCocitation

License

MIT License - see LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'feat: add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Acknowledgments

Built with:

  • PyTorch and PyTorch Geometric
  • pyg-hyper-nn and pyg-hyper-data

Made with โค๏ธ for hypergraph self-supervised learning
