Hierarchical GNN for graphs with community structure
CommunityNet
CommunityNet is a hierarchical Graph Neural Network (GNN) designed for graph datasets with community structure (e.g. social networks, molecules). It encodes information at both the within-community level and the inter-community level.
Installation
You can download CommunityNet from PyPI:
$ pip install communitynet
Usage
Before instantiating CommunityNet, you must define a "base" GNN and an "output" GNN. The base GNN creates a vector embedding of each community in an input graph. These embeddings are used as node features in an "inter-community" graph, where each node represents a community and each edge is the mean of the edges between two communities. This graph is passed to the output GNN to make a prediction. Both GNNs can be built from the GraphNet and MLP PyTorch modules supplied by the library. For example, to construct the CommunityNet shown in the diagram above, you can do the following:
import torch.nn as nn
from communitynet import GraphNet, MLP, CommunityNet
# Example numbers (arbitrary)
num_node_features = 4
num_edge_features = 2
base_gnn = GraphNet(in_channels=num_node_features, out_channels=8,
                    num_edge_features=num_edge_features)
output_gnn = nn.Sequential(
    GraphNet(in_channels=8, out_channels=4, num_edge_features=num_edge_features),
    MLP(in_channels=4, out_channels=1)
)
community_net = CommunityNet(base_gnn, output_gnn, num_communities=3)
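The inter-community graph described above can be sketched in plain Python. The function below is hypothetical and not part of communitynet; it assumes that each edge carries a feature vector and that "the mean of the edges between two communities" means the element-wise mean of the feature vectors of every edge running from one community to the other. Edges whose endpoints lie in the same community are skipped.

```python
from collections import defaultdict


def build_inter_community_graph(communities, edge_index, edge_attr):
    """Hypothetical sketch: collapse communities into nodes, average crossing edges.

    communities: list of sets of node indices (same format as data.communities)
    edge_index:  [sources, targets], two parallel lists of node indices
    edge_attr:   one feature list per edge, aligned with edge_index
    Returns (inter_edge_index, inter_edge_attr) indexed by community.
    """
    # Map every node to the index of the community that contains it.
    node_to_comm = {}
    for c, members in enumerate(communities):
        for node in members:
            node_to_comm[node] = c

    # Group the feature vectors of edges that cross from one community to another.
    grouped = defaultdict(list)
    for src, dst, feats in zip(edge_index[0], edge_index[1], edge_attr):
        cs, cd = node_to_comm[src], node_to_comm[dst]
        if cs != cd:  # within-community edges do not appear in this graph
            grouped[(cs, cd)].append(feats)

    # One directed edge per (source, target) community pair, with mean features.
    inter_src, inter_dst, inter_attr = [], [], []
    for (cs, cd), feats_list in sorted(grouped.items()):
        inter_src.append(cs)
        inter_dst.append(cd)
        n = len(feats_list)
        inter_attr.append([sum(col) / n for col in zip(*feats_list)])
    return [inter_src, inter_dst], inter_attr
```

For example, with communities [{0, 1}, {2, 3}] and crossing edges 1 -> 2 with features [2.0, 4.0], 0 -> 3 with [4.0, 2.0] and 2 -> 0 with [1.0, 1.0], the result is edges [[0, 1], [1, 0]] with features [[3.0, 3.0], [1.0, 1.0]].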
GraphNet and MLP both have additional hyperparameters (e.g. hidden layers, dropout) which are described in the reference below. The CommunityNet class itself derives from torch.nn.Module, so it can be trained like any other PyTorch model.
Each graph you submit to CommunityNet must be an instance of torch_geometric.data.Data with an additional communities attribute. data.communities should hold a list of communities, where each community is a set of node indices. For example:
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 0, 2, 1, 2, 2, 3, 3, 4, 3, 5, 4, 5],
                           [1, 0, 2, 0, 2, 1, 3, 2, 4, 3, 5, 3, 5, 4]],
                          dtype=torch.long)
x = torch.tensor([[-1], [0], [1], [0.5], [0.75], [-0.25]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
data.communities = [{0, 1, 2}, {3, 4, 5}]
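Before the base GNN can embed a community, that community's nodes and edges have to be pulled out of the full graph. A plain-Python sketch of that step follows; community_subgraphs is a hypothetical helper, and relabelling nodes to 0..k-1 inside each community is our assumption about what a per-community GNN needs, not documented library behavior.

```python
def community_subgraphs(communities, edge_index):
    """Hypothetical helper: return the induced subgraph of each community.

    Each result is (global_node_ids, [local_sources, local_targets]), where
    local indices follow the sorted order of the community's node ids.
    """
    subgraphs = []
    for members in communities:
        nodes = sorted(members)
        # Local index for each global node index in this community.
        local = {node: i for i, node in enumerate(nodes)}
        src, dst = [], []
        for s, d in zip(edge_index[0], edge_index[1]):
            # Keep an edge only if both endpoints belong to this community.
            if s in local and d in local:
                src.append(local[s])
                dst.append(local[d])
        subgraphs.append((nodes, [src, dst]))
    return subgraphs
```

Called as community_subgraphs(data.communities, data.edge_index.tolist()) on the example above, it keeps the six edges inside each of the two triangles and drops the two edges between nodes 2 and 3, which connect the communities instead.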
Note that every graph in your dataset must have the same number of communities, matching the num_communities argument passed to CommunityNet.
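Because of this requirement, it can be worth validating a dataset before training. The helper below is a hypothetical sketch, not part of communitynet; it only assumes that each graph exposes a communities attribute as described above.

```python
def check_num_communities(dataset, num_communities):
    """Hypothetical check: raise ValueError if any graph has the wrong count."""
    for i, graph in enumerate(dataset):
        found = len(graph.communities)
        if found != num_communities:
            raise ValueError(
                f"graph {i} has {found} communities, expected {num_communities}"
            )
```

Running it once over the dataset with the same value passed to CommunityNet(num_communities=...) reports the first offending graph by its position.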
Reference
GraphNet
PyTorch module that implements a GNN. It uses NNConv (an edge-conditioned convolutional operator) as its filter and global pooling to convert a graph into a vector embedding.
Parameters:
in_channels (int): Number of node features
out_channels (int): Number of output features
num_edge_features (int): Number of edge features
hidden_channels (list, optional (default=[])): List of hidden state sizes; the length of the list is the number of layers
use_pooling (bool, optional (default=False)): Whether or not to use top-k pooling
dropout_prob (float, optional (default=0.0)): Dropout probability applied to each GNN layer
global_pooling (str, optional (default="mean")): Global pooling mode; options are "mean", "add", and "max"
activation (torch.nn.Module, optional (default=None)): Activation function used for NNConv
edge_nn_kwargs (dict, optional (default={})): Dictionary of parameters for the MLP used to process edge features in NNConv
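The three global_pooling modes differ only in how the per-node embeddings are combined into a single graph vector. The sketch below illustrates them on plain lists; PyTorch Geometric provides tensor versions as global_mean_pool, global_add_pool and global_max_pool in torch_geometric.nn, though whether GraphNet calls those exact functions is an assumption.

```python
def global_pool(node_embeddings, mode="mean"):
    """Collapse a list of per-node feature vectors into one graph vector.

    Illustrates the "mean", "add" and "max" options, applied feature by
    feature across all nodes of a single graph.
    """
    columns = list(zip(*node_embeddings))  # one tuple per feature dimension
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    if mode == "add":
        return [sum(col) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling mode: {mode!r}")
```

For node embeddings [[1.0, -2.0], [3.0, 4.0], [2.0, 1.0]], the three modes give [2.0, 1.0], [6.0, 3.0] and [3.0, 4.0] respectively. "add" preserves information about graph size, while "mean" and "max" do not.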
MLP
PyTorch module that implements a multi-layer perceptron. This can be used in an output GNN to convert a graph embedding into a prediction (e.g. a classification/regression).
Parameters:
in_channels (int): Number of input features
out_channels (int): Number of output features
hidden_channels (list, optional (default=[])): List of hidden state sizes; the length of the list is the number of layers
h_activation (torch.nn.Module, optional (default=None)): Hidden activation function
out_activation (torch.nn.Module, optional (default=None)): Output activation function
CommunityNet
PyTorch module that implements a hierarchical GNN.
Parameters:
base_gnn (torch.nn.Module): Base GNN used to process each community
output_gnn (torch.nn.Module): Output GNN used to process the inter-community graph and produce a prediction
num_communities (int): Number of communities in each input graph
num_jobs (int, optional (default=1)): Number of jobs (CPU cores) to distribute the community embedding work across