
Graph convolutional operator that uses a CNN as a filter


MatrixConv

MatrixConv is a graph convolutional filter for graphs whose node features are n-dimensional (1D, 2D, or 3D) matrices, such as images; a scene graph is one example. The filter applies a standard (non-graph) convolution, i.e. torch.nn.Conv{1/2/3}d, to transform the node features. Node embeddings are updated like so:

Where φ_r and φ_m are CNNs (torch.nn.Conv{1/2/3}d) and W_e is a weight matrix.

Installation

This module can be installed with pip:

$ pip install matrix_conv

Usage

MatrixConv is built on PyTorch Geometric and derives from its MessagePassing base class. It expects an input graph in which each node's features form a matrix (1D, 2D, or 3D). Like NNConv, MatrixConv also incorporates any available edge features when collecting messages from a node's neighbors.
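The toy sketch below is not MatrixConv's implementation. It is a pure-Python illustration that assumes scalar node features and one scalar weight per edge, standing in for the CNN-transformed matrices and edge-feature transform the layer actually uses. It only shows how messages travel along edge_index (PyTorch Geometric's default source-to-target flow) and how the "add", "mean" and "max" options combine them. The function name message_passing and the choice to scale each message by its edge weight are assumptions made for illustration.

```python
def message_passing(x, edge_index, edge_weight, aggr="add"):
    """Toy message passing with scalar node features and scalar edge weights.

    edge_index is [sources, targets]; with PyTorch Geometric's default
    source-to-target flow, edge k carries a message from node sources[k]
    (the neighbor j) to node targets[k] (the receiving node i).
    """
    sources, targets = edge_index
    inbox = [[] for _ in x]  # messages received by each node
    for j, i, w in zip(sources, targets, edge_weight):
        # Message from neighbor j to node i, modulated by the edge feature
        # (a hypothetical stand-in for the real edge-feature transform).
        inbox[i].append(x[j] * w)

    out = []
    for received in inbox:
        if not received:
            out.append(0.0)  # in this sketch, nodes with no incoming edges get 0.0
        elif aggr == "add":
            out.append(sum(received))
        elif aggr == "mean":
            out.append(sum(received) / len(received))
        elif aggr == "max":
            out.append(max(received))
        else:
            raise ValueError(f"unknown aggregation scheme: {aggr!r}")
    return out


# Same connectivity as the Example Usage graph: 0 <-> 1 <-> 2
x = [1.0, 2.0, 3.0]
edge_index = [[0, 1, 1, 2],
              [1, 0, 2, 1]]
edge_weight = [0.5, 1.0, 2.0, 1.0]

for aggr in ("add", "mean", "max"):
    print(aggr, message_passing(x, edge_index, edge_weight, aggr))
# add [2.0, 3.5, 4.0]
# mean [2.0, 1.75, 4.0]
# max [2.0, 3.0, 4.0]
```

Node 1 receives two messages (0.5 from node 0 and 3.0 from node 2), so it is the only node where the three aggregation schemes give different results.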

Parameters:

  • in_channels (int): Number of channels in each input node matrix (e.g. if each node's features are a 3x5 matrix with 2 channels, then in_channels=2)
  • out_channels (int): Number of channels in the output node embeddings
  • matrix_dims (list or tuple): Dimensions of the matrix associated with each node (e.g. if each node's features are a 3x5 matrix, then matrix_dims=[3, 5])
  • num_edge_attr (int): Number of edge attributes/features
  • kernel_dims (list or tuple): Dimensions of the convolving kernel in the CNNs
  • aggr (string, optional): The message aggregation scheme to use ("add", "mean", "max")
  • root_cnn (bool, optional): If set to False, the layer will not add the CNN-transformed root node features to the output
  • bias (bool, optional): If set to False, the layer will not learn an additive bias
  • **kwargs (optional): Additional arguments for torch.nn.Conv{1/2/3}d
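The following sketch spells out how these parameters relate to the shapes of the input tensors, following the shape comments in the Example Usage section: x is [num_nodes, in_channels, *matrix_dims], edge_index is [2, num_edges], and edge_attr is [num_edges, num_edge_attr]. check_inputs is a hypothetical helper, not part of matrix_conv. It takes plain tuples so it runs without PyTorch; since torch.Size is a tuple subclass, x.shape, edge_index.shape and edge_attr.shape could be passed directly.

```python
def check_inputs(x_shape, edge_index_shape, edge_attr_shape,
                 in_channels, matrix_dims, num_edge_attr):
    """Check input tensor shapes against MatrixConv-style parameters.

    Hypothetical helper: returns (num_nodes, num_edges) if everything is
    consistent, otherwise raises ValueError.
    """
    if len(matrix_dims) not in (1, 2, 3):
        raise ValueError("matrix_dims must have 1, 2, or 3 entries (Conv1d/2d/3d)")

    # x: [num_nodes, in_channels, *matrix_dims]
    num_nodes, channels, *dims = x_shape
    if channels != in_channels:
        raise ValueError(f"x has {channels} channels, expected {in_channels}")
    if dims != list(matrix_dims):
        raise ValueError(f"node matrices are {dims}, expected {list(matrix_dims)}")

    # edge_index: [2, num_edges]
    if len(edge_index_shape) != 2 or edge_index_shape[0] != 2:
        raise ValueError("edge_index must have shape [2, num_edges]")
    num_edges = edge_index_shape[1]

    # edge_attr: [num_edges, num_edge_attr]
    if tuple(edge_attr_shape) != (num_edges, num_edge_attr):
        raise ValueError(
            f"edge_attr must have shape ({num_edges}, {num_edge_attr}), "
            f"got {tuple(edge_attr_shape)}"
        )
    return num_nodes, num_edges


# Shapes from the Example Usage section
print(check_inputs((3, 1, 5, 5, 5), (2, 4), (4, 3),
                   in_channels=1, matrix_dims=[5, 5, 5], num_edge_attr=3))
# (3, 4)
```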

Example Usage:

import torch
from matrix_conv import MatrixConv

# Convolutional layer
conv_layer = MatrixConv(
    in_channels=1,
    out_channels=10,
    matrix_dims=[5, 5, 5],
    num_edge_attr=3,
    kernel_dims=[2, 3, 3]
)

# Your input graph (see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs)
x = torch.randn((3, 1, 5, 5, 5), dtype=torch.float) # Shape is [num_nodes, in_channels, *matrix_dims]
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)
edge_attr = torch.randn((4, 3), dtype=torch.float)

# Your output graph
x = conv_layer(x, edge_index, edge_attr) # Shape is now [3, 10, 4, 3, 3]
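The output shape [3, 10, 4, 3, 3] is consistent with the standard output-size formula for torch.nn.Conv{1/2/3}d, out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1, applied to each matrix dimension with PyTorch's defaults (stride 1, no padding, dilation 1). The sketch below assumes MatrixConv's CNNs follow that formula; conv_output_dims is a hypothetical helper, not part of the package.

```python
def conv_output_dims(matrix_dims, kernel_dims, stride=1, padding=0, dilation=1):
    """Spatial output size of torch.nn.Conv{1,2,3}d, applied per axis.

    Uses PyTorch's formula
        out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    with the same stride/padding/dilation on every axis.
    """
    if len(matrix_dims) != len(kernel_dims):
        raise ValueError("matrix_dims and kernel_dims must have the same length")
    return [
        (m + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for m, k in zip(matrix_dims, kernel_dims)
    ]


num_nodes, out_channels = 3, 10
out_dims = conv_output_dims([5, 5, 5], [2, 3, 3])
print([num_nodes, out_channels, *out_dims])  # [3, 10, 4, 3, 3]
```

Under the same assumption, stacking a second MatrixConv on this output would call for in_channels=10 and matrix_dims=[4, 3, 3], because each layer shrinks the node matrices by kernel - 1 along every axis when no padding is used.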

To-Do: Show example of using this in a graph classifier (include stacking)



Download files


Source Distribution

matrix_conv-1.0.1.tar.gz (3.6 kB)

File details

Details for the file matrix_conv-1.0.1.tar.gz.

File metadata

  • Download URL: matrix_conv-1.0.1.tar.gz
  • Size: 3.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/46.0.0 requests-toolbelt/0.9.1 tqdm/4.36.0 CPython/3.7.8

File hashes

Hashes for matrix_conv-1.0.1.tar.gz
  • SHA256: 98b140165a995b692485190cff45d352ccc571a444385c04d09614f563680930
  • MD5: 0f9a4448a2fa60a7e32d90dd40a7ad28
  • BLAKE2b-256: 12e445391edd9c7075b73d46279215ee02bf42527ad451e2ed422b0a0a322923
