
Some custom GNN layers for PyTorch Geometric


custom-gnn-layers

This is a little collection of the custom graph convolutional and pooling layers I've made for various projects. Everything here is built on the PyTorch Geometric library and can be used like a regular PyTorch module.

Installation

Requires Python 3.x, PyTorch 1.2.x, and PyTorch Geometric.

You can install these layers with pip:

$ pip install gnn_layers

Convolutional Layers

EdgeAttentionConv

EdgeAttentionConv is an edge-conditioned filter with an attention mechanism. It's the same as NNConv, except an attention coefficient for each message is calculated from the edge features. The idea is that messages from some neighbors may be more important than others, depending on their connection with the root node. Node embeddings are updated like so:

where W_r and W_g are trainable weight matrices, and h_e is a neural network (e.g. an MLP). W_r transforms the root node features, and W_g is used to compute an attention coefficient.
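The update for a single node can be sketched in plain Python to make the roles of W_r, W_g and h_e concrete. This is a framework-free illustration under stated assumptions, not the layer's actual code. It assumes the attention coefficient is a sigmoid of W_g applied to the edge features, that the flat output of h_e is reshaped row-major into an in_channels × out_channels matrix (as PyTorch Geometric's NNConv does), that messages are summed, and that the bias is omitted. The function `edge_attention_update` and all weights in the usage lines are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple


def edge_attention_update(
    i: int,
    x: Sequence[Sequence[float]],
    edges: Sequence[Tuple[int, int, Sequence[float]]],
    w_root: Sequence[Sequence[float]],
    w_gate: Sequence[float],
    edge_nn: Callable[[Sequence[float]], Sequence[float]],
    in_channels: int,
    out_channels: int,
    root_weight: bool = True,
) -> List[float]:
    """Updated embedding of node i under an ASSUMED edge-attention rule.

    edges holds (source j, target i, edge features e_ij), mirroring the two
    rows of edge_index plus the matching row of edge_attr.
    w_root is W_r stored as out_channels rows of length in_channels.
    w_gate is a single row of W_g, so W_g e_ij is a scalar (assumption).
    """
    out = [0.0] * out_channels
    for j, target, e_ij in edges:
        if target != i:
            continue  # only messages flowing into node i
        # Attention coefficient from the edge features alone (sigmoid is assumed).
        score = sum(w * e for w, e in zip(w_gate, e_ij))
        alpha = 1.0 / (1.0 + math.exp(-score))
        # h_e(e_ij) reshaped row-major into an in_channels x out_channels matrix.
        flat = edge_nn(e_ij)
        theta = [flat[k * out_channels:(k + 1) * out_channels] for k in range(in_channels)]
        # Weighted message alpha * (x_j^T theta), summed into the output.
        for c in range(out_channels):
            out[c] += alpha * sum(x[j][k] * theta[k][c] for k in range(in_channels))
    if root_weight:
        # Add the transformed root node features W_r x_i.
        for c in range(out_channels):
            out[c] += sum(w * v for w, v in zip(w_root[c], x[i]))
    return out


# The graph from the example below: a path 0 - 1 - 2 with edges in both directions.
x = [[-1.0], [0.0], [1.0]]
edges = [(0, 1, [0.0, 1.0]), (1, 0, [1.0, 0.0]), (1, 2, [1.0, 0.0]), (2, 1, [0.0, 1.0])]
w_root = [[1.0], [2.0], [3.0], [4.0]]  # hypothetical W_r (4 x 1)


def edge_nn(e: Sequence[float]) -> List[float]:
    """Hypothetical h_e producing in_channels * out_channels = 1 * 4 values."""
    return [float(sum(e))] * 4


print(edge_attention_update(0, x, edges, w_root, [0.0, 0.0], edge_nn, 1, 4))
# [-1.0, -2.0, -3.0, -4.0]
```

With W_g set to zero every coefficient is 0.5, so each neighbor contributes half its transformed features. A trained W_g would instead scale neighbors differently according to their edge features.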

Parameters:

  • in_channels (int): Size of each input node embedding.
  • out_channels (int): Size of each output node embedding.
  • edge_nn (torch.nn.Module): A neural network h_e that maps edge features edge_attr of shape [-1, num_edge_features] to shape [-1, in_channels * out_channels].
  • root_weight (bool, optional): If set to False, the layer will not add the transformed root node features to the output. (default: True)
  • bias (bool, optional): If set to False, the layer will not learn an additive bias. (default: True)

Example:

import torch
from gnn_layers import EdgeAttentionConv

# Convolutional layer
conv = EdgeAttentionConv(
  in_channels=1,
  out_channels=4,
  edge_nn=torch.nn.Linear(2, 1 * 4),
)

# Your input graph
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)  # float, since edge_nn is a Linear layer

# Your output graph
x = conv(x, edge_index, edge_attr) # Shape is now [3, 4]
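In PyTorch Geometric, edge_index is a COO matrix of shape [2, num_edges]: the first row holds source nodes and the second row holds target nodes, so an undirected edge must appear once in each direction, as in the example above. A small helper for building that matrix from a list of undirected edges could look like this; `undirected_edge_index` is a hypothetical name, not part of gnn_layers, and its plain-list output would be wrapped with `torch.tensor(..., dtype=torch.long)`.

```python
from typing import List, Sequence, Tuple


def undirected_edge_index(edges: Sequence[Tuple[int, int]]) -> List[List[int]]:
    """Build a [2, 2 * len(edges)] COO edge_index listing every edge in both directions.

    Row 0 holds source nodes and row 1 holds target nodes, the layout
    PyTorch Geometric expects for edge_index.
    """
    sources: List[int] = []
    targets: List[int] = []
    for u, v in edges:
        sources += [u, v]  # u -> v, then v -> u
        targets += [v, u]
    return [sources, targets]


print(undirected_edge_index([(0, 1), (1, 2)]))  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

The rows of edge_attr must follow the same order, one row per directed edge. The example above gives the two directions of an edge different features, which this helper leaves up to you.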

ImageConv

ImageConv is an edge-conditioned filter for graphs where the node features are n-dimensional matrices, such as 2D or 3D images, rather than vectors. The filter applies a (non-graph) convolution, e.g. torch.nn.Conv2d or torch.nn.Conv3d, to transform the node features. Node embeddings are updated like so:

where φ_r and φ_m are convolutional layers and W_e is a weight matrix.

Parameters:

  • in_channels (int): Number of channels in the input node image.
  • out_channels (int): Number of channels in the output node image.
  • image_dims (tuple): Dimensions of the input node image as a tuple, e.g. for a 4x4 image, set to (4, 4).
  • kernel_size (tuple): Size of the convolving kernel.
  • num_edge_attr (int): Number of edge features.
  • bias (bool, optional): If set to False, the layer will not learn an additive bias. (default: True)
  • aggr (str, optional): The aggregation scheme to use ("add", "mean", "max"). (default: "add")
  • **kwargs (optional): Additional arguments for torch.nn.Conv1d, torch.nn.Conv2d, or torch.nn.Conv3d.
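The aggr parameter decides how a node combines the messages arriving from its neighbors. The sketch below shows what "add", "mean", and "max" compute element by element, with each message flattened to a list of floats; for image-valued messages the same reduction runs independently over every channel and pixel, so the result has the same shape as a single message. It is a framework-free illustration, not ImageConv's code (the aggregation is presumably delegated to PyTorch Geometric's MessagePassing base class), and `aggregate` is a hypothetical name.

```python
from typing import List, Sequence


def aggregate(messages: Sequence[Sequence[float]], aggr: str = "add") -> List[float]:
    """Combine equally sized neighbor messages element-wise."""
    if aggr not in ("add", "mean", "max"):
        raise ValueError(f"unknown aggregation scheme: {aggr!r}")
    if not messages:
        raise ValueError("need at least one message")
    columns = list(zip(*messages))  # one tuple per element position
    if aggr == "add":
        return [sum(col) for col in columns]
    if aggr == "mean":
        return [sum(col) / len(col) for col in columns]
    return [max(col) for col in columns]


msgs = [[1.0, -2.0], [3.0, 4.0]]
print(aggregate(msgs, "add"))   # [4.0, 2.0]
print(aggregate(msgs, "mean"))  # [2.0, 1.0]
print(aggregate(msgs, "max"))   # [3.0, 4.0]
```

"add" lets high-degree nodes accumulate larger activations, while "mean" and "max" keep the output scale independent of the number of neighbors.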

Example:

import torch
from gnn_layers import ImageConv

# Convolutional layer
conv = ImageConv(
  in_channels=1,
  out_channels=4,
  image_dims=(8, 8),
  kernel_size=(2, 2),
  num_edge_attr=2
)

# Your input graph
x = torch.randn((3, 1, 8, 8), dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)  # floating point, like the node features

# Your output graph
x = conv(x, edge_index, edge_attr) # Shape is now [3, 4, 7, 7]
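The [3, 4, 7, 7] output shape follows from the output-size formula in the torch.nn.Conv2d documentation, applied to each spatial dimension: out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1). An 8x8 image with a 2x2 kernel and the defaults (stride 1, padding 0, dilation 1) gives 7x7. The helper below, `conv_output_dims` (hypothetical, not part of gnn_layers), predicts the output image size when choosing image_dims and kernel_size, assuming any stride, padding, or dilation you pass through **kwargs reaches the underlying convolution unchanged.

```python
from typing import Sequence, Tuple


def conv_output_dims(
    image_dims: Sequence[int],
    kernel_size: Sequence[int],
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[int, ...]:
    """Spatial output size of torch.nn.ConvNd, computed per dimension."""
    return tuple(
        (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for size, k in zip(image_dims, kernel_size)
    )


print(conv_output_dims((8, 8), (2, 2)))             # (7, 7), as in the example
print(conv_output_dims((8, 8), (2, 2), padding=1))  # (9, 9)
print(conv_output_dims((8, 8), (2, 2), stride=2))   # (4, 4)
```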


Files for gnn-layers, version 1.0.2:

  • gnn_layers-1.0.2.tar.gz (5.0 kB), source distribution
