Graph convolutional operator that uses a CNN as a filter
MatrixConv
MatrixConv is a graph convolutional filter for graphs whose node features are n-dimensional matrices (e.g. 2D or 3D images), as in a scene graph. The filter applies a (non-graph) convolution, i.e. torch.nn.Conv{1/2/3}d, to transform the node features. Node embeddings are updated like so:
Where φ_r and φ_m are CNNs (torch.nn.Conv{1/2/3}d), and W_e is a weight matrix.
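To make the "non-graph convolution on node features" idea concrete, here is a minimal sketch that uses only plain PyTorch. It assumes node features are stored as [num_nodes, in_channels, *matrix_dims], so the node axis plays the role of the batch axis of torch.nn.Conv2d. The tensor sizes and the names `cnn` and `h` are chosen for illustration only and do not come from the package.

```python
import torch

# Three nodes, each carrying a 2-channel 3x5 matrix as its "features".
# Assumed layout: [num_nodes, in_channels, *matrix_dims].
x = torch.randn(3, 2, 3, 5)

# A plain (non-graph) 2D convolution. The node axis is treated as the batch axis,
# so every node's matrix is transformed independently with shared weights.
cnn = torch.nn.Conv2d(in_channels=2, out_channels=4, kernel_size=(2, 3))

h = cnn(x)
print(h.shape)  # torch.Size([3, 4, 2, 3])
```

Each node ends up with a 4-channel 2x3 matrix: the spatial size shrinks because this sketch uses no padding.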
Installation
This module can be installed with pip:
$ pip install matrix_conv
Usage
MatrixConv is built on PyTorch Geometric and derives from the MessagePassing module. It expects an input graph in which each node's "features" are a matrix (1D, 2D, or 3D). Like NNConv, MatrixConv also incorporates any available edge features when collecting messages from a node's neighbors.
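The following is a rough sketch of what edge-conditioned message passing over matrix-valued node features could look like, written in plain PyTorch. It is not taken from the package's source: the way edge features enter the message (a linear map `W_e` producing a per-channel scale), the sum aggregation, and all variable names are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

num_nodes, in_channels, out_channels, num_edge_attr = 3, 1, 4, 3
x = torch.randn(num_nodes, in_channels, 5, 5)           # one 5x5 matrix per node
edge_index = torch.tensor([[0, 1, 1, 2],                # source nodes
                           [1, 0, 2, 1]], dtype=torch.long)  # target nodes
edge_attr = torch.randn(edge_index.size(1), num_edge_attr)

# Hypothetical components (assumed, not the package's actual layers):
phi_r = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)  # root CNN
phi_m = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)  # message CNN
W_e = torch.nn.Linear(num_edge_attr, in_channels, bias=False)      # edge weights

src, dst = edge_index
# Turn each edge's attributes into a per-channel scale for the neighbor's matrix.
scale = W_e(edge_attr).view(-1, in_channels, 1, 1)      # [num_edges, in_channels, 1, 1]
messages = phi_m(scale * x[src])                        # [num_edges, out_channels, 3, 3]

# "add" aggregation: sum incoming messages onto the CNN-transformed root features.
out = phi_r(x).index_add(0, dst, messages)              # [num_nodes, out_channels, 3, 3]
```

Swapping the final line for a mean or max reduction over incoming edges would correspond to the other aggregation schemes; the sum is used here only because it is the shortest to write.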
Parameters:
- in_channels (int): Number of channels in the input node matrix (e.g. if each node's features is a 3x5 matrix with 2 input channels, then in_channels=2)
- out_channels (int): Number of channels in the output node embedding
- matrix_dims (list or tuple): Dimensions of matrix associated with node (e.g. if each node's features is a 3x5 matrix, then matrix_dims=[3, 5])
- num_edge_attr (int): Number of edge attributes/features
- kernel_dims (list or tuple): Dimensions of the convolving kernel in the CNN
- aggr (string, optional): The message aggregation scheme to use ("add", "mean", "max")
- root_cnn (bool, optional): If set to False, the layer will not add the CNN-transformed root node features to the output
- bias (bool, optional): If set to False, the layer will not learn an additive bias
- **kwargs (optional): Additional arguments for torch.nn.Conv{1/2/3}d
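The relationship between matrix_dims, kernel_dims, and the spatial size of the output embedding can be worked out with the standard convolution size formula. The sketch below assumes the underlying torch.nn.Conv{1/2/3}d runs with its defaults (stride 1, no padding, dilation 1) unless overridden. The helper name `conv_output_dims` is hypothetical and not part of the package.

```python
def conv_output_dims(matrix_dims, kernel_dims, stride=1, padding=0, dilation=1):
    """Spatial output size of an unbatched convolution, one entry per dimension.

    Hypothetical helper: applies the usual formula
    floor((m + 2*padding - dilation*(k - 1) - 1) / stride) + 1 to each dimension.
    """
    if len(matrix_dims) != len(kernel_dims):
        raise ValueError("matrix_dims and kernel_dims must have the same length")
    return [
        (m + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for m, k in zip(matrix_dims, kernel_dims)
    ]


# A 5x5x5 node matrix convolved with a 2x3x3 kernel.
print(conv_output_dims([5, 5, 5], [2, 3, 3]))  # [4, 3, 3]
```

This is useful when stacking layers, since the next layer's matrix_dims would presumably need to match the spatial size produced by the previous one.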
Example Usage:
import torch
from matrix_conv import MatrixConv

# Convolutional layer
conv_layer = MatrixConv(
    in_channels=1,
    out_channels=10,
    matrix_dims=[5, 5, 5],
    num_edge_attr=3,
    kernel_dims=[2, 3, 3]
)

# Your input graph (see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs)
x = torch.randn((3, 1, 5, 5, 5), dtype=torch.float)  # Shape is [num_nodes, in_channels, *matrix_dims]
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)
edge_attr = torch.randn((4, 3), dtype=torch.float)

# Your output graph
x = conv_layer(x, edge_index, edge_attr)  # Shape is now [3, 10, 4, 3, 3]
To-Do: Show example of using this in a graph classifier (include stacking)