pyg-hyper-ssl
Self-supervised learning methods for hypergraphs built on PyTorch Geometric.
Overview
pyg-hyper-ssl provides state-of-the-art self-supervised learning (SSL) methods for hypergraphs. Built on top of pyg-hyper-nn and pyg-hyper-data, it implements SSL algorithms from recent research papers.
Key Features
- State-of-the-art SSL Methods: TriCL (AAAI'23), HyperGCL (NeurIPS'22), HypeBoy (ICLR'24), and more coming soon
- Modular Design: Extensible base classes for methods, augmentations, and losses
- Rich Augmentations: Structural (edge drop) and attribute (feature mask) augmentations
- Production Ready: Comprehensive tests (83% coverage), type hints, and documentation
- Seamless Integration: Works with all 19 models from pyg-hyper-nn
- Optimized: Built on PyTorch Geometric for efficient graph operations
Installation
From PyPI (Recommended)
```bash
pip install pyg-hyper-ssl
```
From Source
```bash
git clone https://github.com/nishide-dev/pyg-hyper-ssl.git
cd pyg-hyper-ssl
uv sync  # or pip install -e .
```
Quick Start
TriCL: Tri-directional Contrastive Learning
```python
import torch

from pyg_hyper_data.datasets import CoraCocitation
from pyg_hyper_ssl.methods.contrastive import TriCL, TriCLEncoder
from pyg_hyper_ssl.augmentations import EdgeDrop, FeatureMask

# Load dataset
dataset = CoraCocitation()
data = dataset[0]

# Create TriCL model
encoder = TriCLEncoder(
    in_dim=data.num_node_features,
    edge_dim=128,
    node_dim=256,
    num_layers=2,
)
model = TriCL(
    encoder=encoder,
    proj_dim=256,
    node_tau=0.5,
    edge_tau=0.5,
    membership_tau=0.1,
)

# Create augmentations
aug1 = EdgeDrop(drop_prob=0.2)
aug2 = FeatureMask(mask_prob=0.3)

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
for epoch in range(100):
    # Apply augmentations
    data_aug1 = aug1(data)
    data_aug2 = aug2(data)

    # Compute loss
    loss = model.train_step(data_aug1, data_aug2)

    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}")

# Get embeddings for downstream tasks
model.eval()
embeddings = model.get_embeddings(data)
print(f"Embeddings shape: {embeddings.shape}")
```
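Once training converges, frozen embeddings are typically evaluated with a linear probe (for example, logistic regression fit on `model.get_embeddings(data)`). The sketch below illustrates the idea with stdlib-only stand-ins: toy vectors in place of real embeddings, and a nearest-class-mean classifier in place of a trained linear head.

```python
# Hypothetical linear-probe-style evaluation of frozen SSL embeddings.
# `train_emb`, `train_y`, and `test_emb` are toy stand-ins for real data.

def class_means(embeddings, labels):
    """Average the embedding vectors belonging to each label."""
    sums, counts = {}, {}
    for vec, y in zip(embeddings, labels):
        acc = sums.setdefault(y, [0.0] * len(vec))
        for i, v in enumerate(vec):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}

def predict(vec, means):
    """Assign the label whose class mean is closest in Euclidean distance."""
    def dist2(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(means, key=lambda y: dist2(vec, means[y]))

# Toy frozen embeddings: two well-separated classes in 2-D.
train_emb = [[0.0, 0.1], [0.1, 0.0], [1.0, 0.9], [0.9, 1.0]]
train_y = [0, 0, 1, 1]
means = class_means(train_emb, train_y)

test_emb = [[0.05, 0.05], [0.95, 0.95]]
preds = [predict(v, means) for v in test_emb]
print(preds)  # prints [0, 1]
```

In practice you would freeze the encoder, compute embeddings once, and fit the probe on a labeled split; only the probe's accuracy is reported.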
Using with pyg-hyper-nn Models
```python
from pyg_hyper_nn.models import HGNN, UniGNN, HyperGCN
from pyg_hyper_ssl.encoders import EncoderWrapper
from pyg_hyper_ssl.methods.base import BaseSSLMethod

# Wrap any pyg-hyper-nn model for SSL
encoder = EncoderWrapper(
    model_class=HGNN,  # or UniGNN, HyperGCN, etc.
    in_channels=128,
    hidden_channels=256,
    num_layers=2,
    use_projection=True,  # Add projection head for contrastive learning
)

# Use in your custom SSL method
class MySSLMethod(BaseSSLMethod):
    def forward(self, data, data_aug):
        h1, z1 = encoder.forward_from_data(data)
        h2, z2 = encoder.forward_from_data(data_aug)
        return (h1, z1), (h2, z2)
```
Implementation Accuracy
All implementations are verified against official reference implementations.
We have carefully compared our implementations with the original papers' reference code to ensure accuracy. See IMPLEMENTATION_ACCURACY.md for detailed verification results.
Key verification points:
- FeatureMask: Dimension-wise masking (matches reference exactly)
- EdgeDrop: Sparse matrix approach (matches reference + improvements)
- TriCL: All three loss levels verified (31 tests)
- HyperGCL: InfoNCE loss verified (20 tests)
- HypeBoy: Two-stage generative SSL verified (20 tests)
- SE-HSSL: Fairness-aware components verified (29 tests)
Total: 119 tests, 34 accuracy verification tests
Implemented Methods
TriCL (AAAI 2023)
Tri-directional Contrastive Learning for Hypergraphs
TriCL performs contrastive learning at three levels:
- Node-level: Contrast node embeddings across augmented views
- Group-level: Contrast hyperedge embeddings across views
- Membership-level: Contrast node-hyperedge relationships
```python
from pyg_hyper_ssl.methods.contrastive import TriCL, TriCLEncoder

encoder = TriCLEncoder(in_dim=128, edge_dim=256, node_dim=512, num_layers=3)
model = TriCL(
    encoder=encoder,
    proj_dim=512,
    lambda_n=1.0,  # Node-level loss weight
    lambda_e=1.0,  # Group-level loss weight
    lambda_m=1.0,  # Membership-level loss weight
)
```
Reference: Lee and Shin. "I'm Me, We're Us, and I'm Us: Tri-directional Contrastive Learning on Hypergraphs" AAAI 2023.
HyperGCL (NeurIPS 2022)
Contrastive Learning for Hypergraphs with Fabricated Augmentations
HyperGCL performs node-level contrastive learning using InfoNCE loss. It works with any hypergraph encoder and uses fabricated augmentations (edge drop, feature mask, etc.).
```python
from pyg_hyper_nn.models import HGNN
from pyg_hyper_ssl.encoders import EncoderWrapper
from pyg_hyper_ssl.methods.contrastive import HyperGCL

# Wrap any pyg-hyper-nn model
encoder = EncoderWrapper(
    model_class=HGNN,
    in_channels=128,
    hidden_channels=256,
    num_layers=2,
)
model = HyperGCL(
    encoder=encoder,
    proj_hidden=256,
    proj_out=256,
    tau=0.5,  # Temperature for InfoNCE loss
)
```
Reference: Wei et al. "Augmentations in Hypergraph Contrastive Learning: Fabricated and Generative" NeurIPS 2022.
HypeBoy (ICLR 2024)
Generative Self-Supervised Learning on Hypergraphs
HypeBoy performs two-stage generative SSL:
1. Feature Reconstruction: Mask and reconstruct node features using a decoder
2. Hyperedge Filling: Predict missing nodes in hyperedges via contrastive loss
```python
from pyg_hyper_nn.models import HGNN
from pyg_hyper_ssl.encoders import EncoderWrapper
from pyg_hyper_ssl.methods.generative import HypeBoy, HypeBoyDecoder

# Encoder for both stages
encoder = EncoderWrapper(
    model_class=HGNN,
    in_channels=128,
    hidden_channels=64,
    num_layers=2,
)

# Decoder for feature reconstruction
decoder = HypeBoyDecoder(
    encoder=encoder,
    in_dim=64,
    out_dim=128,
    hidden_dim=64,
    num_layers=2,
)

model = HypeBoy(
    encoder=encoder,
    decoder=decoder,
    feature_recon_epochs=300,   # Stage 1 epochs
    hyperedge_fill_epochs=200,  # Stage 2 epochs
    feature_mask_prob=0.5,      # Feature masking probability
    edge_drop_prob_stage2=0.9,  # Edge dropping probability (stage 2)
)
```
Reference: Kim et al. "HypeBoy: Generative Self-Supervised Representation Learning on Hypergraphs" ICLR 2024.
Augmentations
Structural Augmentations
```python
from pyg_hyper_ssl.augmentations import EdgeDrop

# Randomly drop hyperedges
aug = EdgeDrop(drop_prob=0.2)  # Drop 20% of hyperedges
data_aug = aug(data)
```
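Conceptually, an edge-drop augmentation samples a subset of hyperedges and removes every node incidence belonging to a dropped edge. A minimal stdlib-only sketch (not the library's sparse-matrix implementation), assuming the common `hyperedge_index` layout of paired rows of node ids and hyperedge ids:

```python
import random

def edge_drop(hyperedge_index, drop_prob, seed=None):
    """Drop whole hyperedges: remove every (node, edge) incidence of a dropped edge."""
    rng = random.Random(seed)
    nodes, edges = hyperedge_index
    edge_ids = sorted(set(edges))
    # Decide per hyperedge (not per incidence) whether it survives.
    kept = {e for e in edge_ids if rng.random() >= drop_prob}
    pairs = [(n, e) for n, e in zip(nodes, edges) if e in kept]
    return [list(t) for t in zip(*pairs)] if pairs else [[], []]

# Toy incidence structure: 3 hyperedges over 4 nodes.
hidx = [
    [0, 1, 2, 1, 2, 3, 0, 3],  # node ids
    [0, 0, 0, 1, 1, 1, 2, 2],  # hyperedge ids
]
aug = edge_drop(hidx, drop_prob=0.5, seed=0)
```

Because the decision is made per hyperedge, a surviving edge keeps all of its members, which preserves the group structure that hypergraph SSL methods contrast against.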
Attribute Augmentations
```python
from pyg_hyper_ssl.augmentations import FeatureMask

# Randomly mask node features
aug = FeatureMask(mask_prob=0.3)  # Mask 30% of features
data_aug = aug(data)
```
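As noted in the accuracy section, FeatureMask masks dimension-wise: one mask is sampled over feature columns and the same columns are zeroed for every node, rather than masking entries independently. A stdlib-only sketch of that behavior (plain nested lists stand in for the node feature tensor):

```python
import random

def feature_mask(x, mask_prob, seed=None):
    """Zero entire feature dimensions, shared across all nodes."""
    rng = random.Random(seed)
    num_dims = len(x[0])
    # One keep/drop decision per feature column.
    keep = [rng.random() >= mask_prob for _ in range(num_dims)]
    return [[v if keep[d] else 0.0 for d, v in enumerate(row)] for row in x]

x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
x_aug = feature_mask(x, mask_prob=0.5, seed=0)
# A masked column is zero in every row simultaneously.
```

Element-wise masking would instead sample an independent decision for each (node, dimension) entry; the reference TriCL code uses the dimension-wise variant shown here.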
Composition
```python
from pyg_hyper_ssl.augmentations import ComposedAugmentation, RandomChoice

# Sequential composition
aug = ComposedAugmentation([
    EdgeDrop(drop_prob=0.2),
    FeatureMask(mask_prob=0.3),
])

# Random choice
aug = RandomChoice([
    EdgeDrop(drop_prob=0.2),
    FeatureMask(mask_prob=0.3),
])
```
Loss Functions
Contrastive Losses
```python
from pyg_hyper_ssl.losses import InfoNCE, NTXent, CosineSimilarityLoss

# InfoNCE loss (SimCLR-style)
loss_fn = InfoNCE(temperature=0.5)

# NT-Xent (alias for InfoNCE)
loss_fn = NTXent(temperature=0.5)

# Simple cosine similarity
loss_fn = CosineSimilarityLoss()
```
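For intuition: InfoNCE treats each cross-view pair (z1[i], z2[i]) as a positive and every other z2[j] as a negative, then applies a temperature-scaled softmax cross-entropy. A stdlib-only sketch on plain lists (the library's losses operate on tensors and support batching):

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def info_nce(z1, z2, temperature=0.5):
    """Average -log softmax probability of the positive pair per anchor."""
    loss = 0.0
    for i, anchor in enumerate(z1):
        logits = [cosine(anchor, z) / temperature for z in z2]
        log_denom = math.log(sum(math.exp(l) for l in logits))
        loss += -(logits[i] - log_denom)  # -log p(positive | anchor)
    return loss / len(z1)

z1 = [[1.0, 0.0], [0.0, 1.0]]
z2 = [[1.0, 0.1], [0.1, 1.0]]
loss = info_nce(z1, z2, temperature=0.5)
```

Lower temperature sharpens the softmax, penalizing hard negatives more strongly, which is why `tau` is a key hyperparameter in HyperGCL and TriCL above.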
Fairness-Aware Losses
```python
from pyg_hyper_ssl.losses import CCALoss, orthogonal_projection, balance_hyperedges

# CCA loss for fairness-aware SSL (SE-HSSL)
cca_loss = CCALoss(lambda_decorr=0.005)
loss = cca_loss(z1, z2)  # Maximize correlation between views

# Orthogonal projection for debiasing
debias_x = orthogonal_projection(x, sens_idx=0)  # Remove bias from sensitive attribute

# Balance hyperedge group representation
balanced_edge_index = balance_hyperedges(
    hyperedge_index,
    node_groups=[0, 0, 1, 1, 0],  # Binary group labels
    beta=1.0,  # Balance strength
)
```
Composite Losses
```python
from pyg_hyper_ssl.losses import CompositeLoss, InfoNCE, CosineSimilarityLoss

# Combine multiple losses with weights
composite = CompositeLoss([
    (InfoNCE(temperature=0.5), 1.0),  # Weight 1.0
    (CosineSimilarityLoss(), 0.5),    # Weight 0.5
])
```
Extending pyg-hyper-ssl
Custom SSL Method
```python
import torch

from pyg_hyper_ssl.methods.base import BaseSSLMethod

class MySSLMethod(BaseSSLMethod):
    def forward(self, data, data_aug):
        # Your encoding logic
        z1 = self.encoder(data.x, data.hyperedge_index)
        z2 = self.encoder(data_aug.x, data_aug.hyperedge_index)
        return z1, z2

    def compute_loss(self, z1, z2, **kwargs):
        # Your loss computation
        return torch.nn.functional.mse_loss(z1, z2)
```
Custom Augmentation
```python
from pyg_hyper_ssl.augmentations.base import BaseAugmentation

class MyAugmentation(BaseAugmentation):
    def __init__(self, param=0.5):
        super().__init__(param=param)
        self.param = param

    def __call__(self, data):
        # Your augmentation logic
        augmented_data = data.clone()
        # Modify augmented_data...
        return augmented_data
```
Custom Loss Function
```python
from pyg_hyper_ssl.losses.base import BaseLoss

class MyLoss(BaseLoss):
    def forward(self, z1, z2, **kwargs):
        # Your loss computation
        return (z1 - z2).pow(2).mean()
```
Architecture
```
pyg-hyper-ssl/
├── methods/
│   ├── base.py                 # BaseSSLMethod
│   └── contrastive/
│       ├── tricl.py            # TriCL implementation
│       ├── tricl_encoder.py    # TriCL encoder
│       └── tricl_layer.py      # TriCL convolution layer
├── augmentations/
│   ├── base.py                 # Base augmentation classes
│   ├── structural/
│   │   └── edge_drop.py        # Edge dropping
│   └── attribute/
│       └── feature_mask.py     # Feature masking
├── losses/
│   ├── base.py                 # Base loss classes
│   └── contrastive.py          # InfoNCE, NT-Xent
└── encoders/
    └── wrapper.py              # Encoder wrapper for pyg-hyper-nn
```
Development
Setup
```bash
# Clone and install
git clone https://github.com/nishide-dev/pyg-hyper-ssl.git
cd pyg-hyper-ssl
uv sync

# Install pre-commit hooks
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg

# Run tests
uv run pytest tests/ -v

# Run with coverage
uv run pytest tests/ --cov=src/pyg_hyper_ssl --cov-report=term-missing
```
Pre-commit hooks
This project uses pre-commit hooks to ensure code quality:
```bash
# Run hooks manually on all files
uv run pre-commit run --all-files

# Run hooks on staged files (happens automatically on git commit)
git commit -m "Your message"
```
Hooks:
- `ruff lint --fix`: Auto-fix linting issues
- `ruff format`: Format code
- `ty check`: Type checking (runs on the entire project)
Testing
```bash
# All tests
uv run pytest

# Specific test file
uv run pytest tests/test_tricl.py -v

# Test with output
uv run pytest tests/test_tricl.py -v -s
```
Code Quality
```bash
# Format code
uv run ruff format .

# Lint code
uv run ruff check .

# Fix auto-fixable issues
uv run ruff check --fix .

# Type checking
uv run ty check
```
Dependencies
- Runtime:
  - torch >= 2.0.0
  - torch-geometric >= 2.4.0
  - torch-scatter >= 2.1.0
  - pyg-hyper-nn >= 0.1.0
  - pyg-hyper-data >= 0.1.0
  - hydra-core >= 1.3.0
  - scikit-learn >= 1.0.0
- Development:
  - pytest >= 8.0
  - pytest-cov >= 4.1
  - ruff >= 0.6
  - ty (type checker)
Citation
If you use this library in your research, please cite:
```bibtex
@software{pyg_hyper_ssl,
  title = {pyg-hyper-ssl: Self-supervised Learning for Hypergraphs},
  author = {nishide-dev},
  year = {2026},
  url = {https://github.com/nishide-dev/pyg-hyper-ssl}
}
```
And cite the original papers for the methods you use:
```bibtex
@inproceedings{lee2023tricl,
  title={I'm Me, We're Us, and I'm Us: Tri-directional Contrastive Learning on Hypergraphs},
  author={Lee, Dongjin and Shin, Kijung},
  booktitle={AAAI},
  year={2023}
}
```
Roadmap
- TriCL (AAAI'23)
- HyperGCL (NeurIPS'22)
- HypeBoy (ICLR'24)
- Additional structural augmentations (NodeDrop, EdgePerturb, Subgraph)
- Additional attribute augmentations (FeatureNoise, FeatureShuffle)
- Pre-trained model zoo
- Comprehensive benchmarks
Related Projects
- pyg-hyper-data: Hypergraph datasets and evaluation protocols
- pyg-hyper-nn: Hypergraph neural network models
- pyg-hyper-bench: Benchmarking framework (coming soon)
License
MIT License - see LICENSE file for details.
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'feat: add amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
Acknowledgments
Built with:
- PyTorch
- PyTorch Geometric
- uv - Fast Python package manager
- ruff - Fast Python linter and formatter
Made with ❤️ for hypergraph self-supervised learning