Some custom GNN layers for PyTorch Geometric

## Project description

# custom-gnn-layers

This is a little collection of the custom graph convolutional and pooling layers I've made for various projects. Everything here is built on the PyTorch Geometric library and can be used like a regular PyTorch module.

## Installation

Requires Python 3.X, PyTorch 1.2.X, and PyTorch Geometric

You can install these layers with pip:

```bash
pip install gnn_layers
```

## Convolutional Layers

### EdgeAttentionConv

`EdgeAttentionConv` is an edge-conditioned filter with an attention mechanism. It's the same as `NNConv`, except that an attention coefficient for each message is calculated from the edge features. The idea is that messages from some neighbors may be more important than others, depending on their connection with the root node. Node embeddings are updated like so:

where $W_r$ and $W_g$ are trainable weight matrices, and $h_e$ is a neural network (e.g. an MLP). $W_r$ is used to transform the root node features and $W_g$ is used to calculate an attention coefficient.

**Parameters:**

- **in_channels** (*int*): Size of each input node embedding.
- **out_channels** (*int*): Size of each output node embedding.
- **edge_nn** (*torch.nn.Module*): A neural network $h_e$ that maps edge features `edge_attr` of shape `[-1, num_edge_features]` to shape `[-1, in_channels * out_channels]`.
- **root_weight** (*bool, optional*): If set to `False`, the layer will not add the transformed root node features to the output. (default: `True`)
- **bias** (*bool, optional*): If set to `False`, the layer will not learn an additive bias. (default: `True`)
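
The update rule can be traced in plain Python for a single forward pass. This is a hypothetical illustration, not the package's implementation. Like `NNConv`, it reshapes each edge's $h_e$ output into an `in_channels x out_channels` matrix that transforms the neighbor's features and sums the resulting messages. The attention step is an **assumption**: the coefficient is taken to be a sigmoid of $W_g$ applied to the raw edge features, and `EdgeAttentionConv` may compute it differently. All function names here are made up for the sketch.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]
Matrix = List[List[float]]


def reshape(flat: Vector, rows: int, cols: int) -> Matrix:
    """Turn a flat vector of length rows * cols into a rows x cols matrix (row-major)."""
    if len(flat) != rows * cols:
        raise ValueError(f"expected {rows * cols} values, got {len(flat)}")
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def vec_mat(v: Vector, m: Matrix) -> Vector:
    """Row vector times matrix: (1 x n) @ (n x k) -> (1 x k)."""
    return [sum(v[n] * m[n][k] for n in range(len(v))) for k in range(len(m[0]))]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def edge_attention_update(
    x: Matrix,                            # node features, [num_nodes][in_channels]
    edge_index: List[List[int]],          # [[sources...], [targets...]]
    edge_attr: Matrix,                    # edge features, [num_edges][num_edge_features]
    edge_nn: Callable[[Vector], Vector],  # h_e: edge features -> in_channels * out_channels values
    out_channels: int,
    w_gate: Vector,                       # HYPOTHETICAL W_g: one weight per edge feature
    w_root: Optional[Matrix] = None,      # W_r, [in_channels][out_channels]; None mimics root_weight=False
    bias: Optional[Vector] = None,        # None mimics bias=False
) -> Matrix:
    in_channels = len(x[0])
    out = [[0.0] * out_channels for _ in x]

    # Each edge carries a message from its source j (row 0) to its target i (row 1).
    for e, (j, i) in enumerate(zip(edge_index[0], edge_index[1])):
        theta = reshape(edge_nn(edge_attr[e]), in_channels, out_channels)
        # ASSUMPTION: attention coefficient = sigmoid(W_g . e_ij)
        alpha = sigmoid(sum(w * a for w, a in zip(w_gate, edge_attr[e])))
        message = vec_mat(x[j], theta)
        for k in range(out_channels):
            out[i][k] += alpha * message[k]  # sum ("add") aggregation

    for i, x_i in enumerate(x):
        if w_root is not None:  # add the transformed root node features
            root = vec_mat(x_i, w_root)
            for k in range(out_channels):
                out[i][k] += root[k]
        if bias is not None:
            for k in range(out_channels):
                out[i][k] += bias[k]
    return out


# Same graph as the example, but with float edge features and out_channels=2
x = [[-1.0], [0.0], [1.0]]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
edge_attr = [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
out = edge_attention_update(
    x, edge_index, edge_attr,
    edge_nn=lambda e: [e[0] + 2 * e[1], 1.0],  # toy h_e with 1 * 2 outputs
    out_channels=2,
    w_gate=[0.0, 0.0],                          # sigmoid(0) = 0.5 for every message
    w_root=[[1.0, 0.0]],
)
print(out)  # [[-1.0, 0.0], [0.0, 0.0], [1.0, 0.0]]
```

With a zero gate every message is halved, so node 1's two incoming messages (from the nodes with features -1 and 1) cancel, and each output row reduces to its transformed root features.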

**Example:**

```python
import torch
from gnn_layers import EdgeAttentionConv

# Convolutional layer
conv = EdgeAttentionConv(
    in_channels=1,
    out_channels=4,
    edge_nn=torch.nn.Linear(2, 1 * 4),  # maps 2 edge features to in_channels * out_channels
)

# Your input graph
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Edge features must be floating point so that torch.nn.Linear can process them
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)

# Your output graph
x = conv(x, edge_index, edge_attr)  # Shape is now [3, 4]
```

### ImageConv

`ImageConv` is an edge-conditioned filter for graphs where the node features are n-dimensional matrices, such as 2D or 3D images, rather than vectors. The filter applies a (non-graph) convolution, such as `torch.nn.Conv2d` or `torch.nn.Conv3d`, to transform the node features. Node embeddings are updated like so:

where $\varphi_r$ and $\varphi_m$ are convolutional layers, and $W_e$ is a weight matrix.

**Parameters:**

- **in_channels** (*int*): Number of channels in the input node image.
- **out_channels** (*int*): Number of channels in the output node image.
- **image_dims** (*tuple*): Dimensions of the input node image as a tuple, e.g. for a 4x4 image, set to `(4, 4)`.
- **kernel_size** (*tuple*): Size of the convolving kernel.
- **num_edge_attr** (*int*): Number of edge features.
- **bias** (*bool, optional*): If set to `False`, the layer will not learn an additive bias. (default: `True`)
- **aggr** (*str, optional*): The aggregation scheme to use (`"add"`, `"mean"`, `"max"`). (default: `"add"`)
- `**kwargs` (*optional*): Additional arguments for `torch.nn.Conv1d`, `torch.nn.Conv2d`, or `torch.nn.Conv3d`.
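
The `aggr` parameter controls how the messages arriving at each node are combined. As a rough illustration rather than the library's code, the hypothetical `aggregate` helper below applies the three schemes element-wise to message vectors; an image-shaped message would be flattened first. Nodes that receive no messages simply get zeros in this sketch.

```python
from typing import Dict, List


def aggregate(messages: List[List[float]], targets: List[int], num_nodes: int, aggr: str = "add") -> List[List[float]]:
    """Combine per-edge message vectors at the nodes that receive them.

    messages[e] is the (flattened) message carried by edge e, and targets[e] is its receiving node.
    """
    if aggr not in ("add", "mean", "max"):
        raise ValueError(f"unknown aggregation scheme: {aggr!r}")
    dim = len(messages[0]) if messages else 0

    # Group incoming messages by receiving node.
    inbox: Dict[int, List[List[float]]] = {i: [] for i in range(num_nodes)}
    for msg, i in zip(messages, targets):
        inbox[i].append(msg)

    out = []
    for i in range(num_nodes):
        received = inbox[i]
        if not received:
            out.append([0.0] * dim)  # this sketch gives nodes without incoming edges zeros
            continue
        columns = list(zip(*received))  # one tuple per feature position
        if aggr == "add":
            out.append([sum(c) for c in columns])
        elif aggr == "mean":
            out.append([sum(c) / len(c) for c in columns])
        else:  # "max"
            out.append([max(c) for c in columns])
    return out


# Targets taken from the example graph's edge_index[1] = [1, 0, 2, 1]
messages = [[1.0, -2.0], [3.0, 0.5], [4.0, 4.0], [5.0, -1.0]]
targets = [1, 0, 2, 1]
for scheme in ("add", "mean", "max"):
    print(scheme, aggregate(messages, targets, num_nodes=3, aggr=scheme))
# add [[3.0, 0.5], [6.0, -3.0], [4.0, 4.0]]
# mean [[3.0, 0.5], [3.0, -1.5], [4.0, 4.0]]
# max [[3.0, 0.5], [5.0, -1.0], [4.0, 4.0]]
```

Only node 1 receives more than one message here, so it is the only row where the three schemes disagree.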

**Example:**

```python
import torch
from gnn_layers import ImageConv

# Convolutional layer
conv = ImageConv(
    in_channels=1,
    out_channels=4,
    image_dims=(8, 8),
    kernel_size=(2, 2),
    num_edge_attr=2,
)

# Your input graph: 3 nodes, each a 1-channel 8x8 image
x = torch.randn((3, 1, 8, 8), dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Edge features should be floating point to match the layer's weights
edge_attr = torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)

# Your output graph
x = conv(x, edge_index, edge_attr)  # Shape is now [3, 4, 7, 7]
```
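
The `7 x 7` spatial size in the example's output shape follows from the output-size formula in PyTorch's `torch.nn.Conv2d` documentation, $\lfloor (H + 2p - d(k - 1) - 1) / s \rfloor + 1$. This assumes `ImageConv` passes `kernel_size` and any `**kwargs` such as `stride`, `padding`, or `dilation` straight through to its convolutions without adding padding of its own. The helper `conv_output_dims` is hypothetical and handles numeric padding only.

```python
from typing import Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def _per_dim(value: IntOrSeq, ndim: int) -> Tuple[int, ...]:
    """Broadcast an int to every spatial dimension, as torch.nn.ConvNd does."""
    if isinstance(value, int):
        return (value,) * ndim
    if len(value) != ndim:
        raise ValueError(f"expected {ndim} values, got {len(value)}")
    return tuple(value)


def conv_output_dims(
    image_dims: Sequence[int],
    kernel_size: IntOrSeq,
    stride: IntOrSeq = 1,
    padding: IntOrSeq = 0,
    dilation: IntOrSeq = 1,
) -> Tuple[int, ...]:
    """Spatial size of each node image after a torch.nn.ConvNd layer."""
    ndim = len(image_dims)
    k, s, p, d = (_per_dim(v, ndim) for v in (kernel_size, stride, padding, dilation))
    return tuple(
        (size + 2 * p[i] - d[i] * (k[i] - 1) - 1) // s[i] + 1
        for i, size in enumerate(image_dims)
    )


print(conv_output_dims((8, 8), (2, 2)))  # (7, 7)
```

Under that assumption, passing e.g. `padding=1` with a 3x3 kernel through `**kwargs` would keep 8x8 images at 8x8.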


## Download files


| Filename | Size | File type | Python version |
|---|---|---|---|
| gnn_layers-1.0.2.tar.gz | 5.0 kB | Source | None |