
Graph convolutional operator that uses an LSTM as a filter


SeqConv

SeqConv is a PyTorch implementation of a graph convolutional operator that uses a long short-term memory (LSTM) network as a filter; that is, an LSTM is used to update node embeddings. This is useful for graph datasets where each node represents a sequence of vectors, such as a time series.

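$$x'_i = \phi_r(x_i) + \sum_{j \in \mathcal{N}(i)} \phi_m(x_j) \cdot h_{\Theta}(e_{i,j})$$

where φr and φm are LSTMs (torch.nn.LSTM) and hΘ is a neural network applied to the edge features e_{i,j}. Each LSTM's output is its last hidden state, hn, rather than the full sequence of output features.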
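The operator can be sketched in plain PyTorch to make the data flow concrete. This is a minimal illustration under stated assumptions, not the package's implementation: it assumes an NNConv-style update in which each neighbor's message LSTM output φm(x_j) is multiplied elementwise by hΘ(e_{i,j}), the messages are summed ("add" aggregation), and the root term φr(x_i) is added; it omits the bias. The helpers `last_hidden` and `seq_conv_update` and all sizes are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen to match the example usage in this README.
num_nodes, seq_len, in_channels, out_channels, num_edge_features = 3, 12, 1, 5, 2

phi_r = nn.LSTM(in_channels, out_channels, batch_first=True)  # root LSTM
phi_m = nn.LSTM(in_channels, out_channels, batch_first=True)  # message LSTM
h_theta = nn.Linear(num_edge_features, out_channels)          # edge network

def last_hidden(lstm, seqs):
    """Run an LSTM over [batch, seq_len, in_channels] and keep only h_n."""
    _, (h_n, _) = lstm(seqs)  # h_n: [num_layers, batch, hidden_size]
    return h_n[-1]            # final hidden state of the last layer

def seq_conv_update(x, edge_index, edge_attr):
    src, dst = edge_index                     # messages flow from x_j (src) to x_i (dst)
    root = last_hidden(phi_r, x)              # phi_r(x_i): [num_nodes, out_channels]
    msgs = last_hidden(phi_m, x[src]) * h_theta(edge_attr)  # phi_m(x_j) * h_theta(e_ij)
    return root.index_add(0, dst, msgs)       # root term + sum of incoming messages

x = torch.randn(num_nodes, seq_len, in_channels)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.randn(4, num_edge_features)
out = seq_conv_update(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([3, 5])
```

Because only h_n is kept, each node's whole sequence collapses to a single out_channels vector, which is why the output has no sequence dimension.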

Installation

This module can be installed with pip:

$ pip install seq_conv

Usage

SeqConv is built on PyTorch Geometric and derives from the MessagePassing module. It expects an input graph in which each node has a sequence of vectors associated with it. Like NNConv, SeqConv also incorporates any available edge features when collecting messages from a node's neighbors.
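One way to build such an input is to stack per-node sequences into a single tensor of shape [num_nodes, seq_len, in_channels]. The sketch below assumes every node's sequence has the same length (torch.stack requires equal shapes); the raw series, sizes, and variable names are hypothetical. Edge features are kept as floating-point values so that a module such as torch.nn.Linear can consume them.

```python
import torch

torch.manual_seed(0)

# Hypothetical raw data: 3 nodes, each with a 12-step series of 1-dimensional readings.
num_nodes, seq_len, in_channels = 3, 12, 1
node_series = [torch.randn(seq_len, in_channels) for _ in range(num_nodes)]

# Stack along a new leading dimension: [num_nodes, seq_len, in_channels].
x = torch.stack(node_series, dim=0)

# Edge list (row 0: source nodes, row 1: target nodes) and one float row per edge.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.randn(edge_index.size(1), 2)  # [num_edges, num_edge_features]

print(x.shape)          # torch.Size([3, 12, 1])
print(edge_attr.shape)  # torch.Size([4, 2])
```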

Parameters:

  • in_channels (int): Number of channels in the input node sequence (e.g. if each node has a sequence of vectors of size n associated with it, then in_channels = n)
  • out_channels (int): Number of channels in the output node embedding
  • edge_nn (torch.nn.Module): A neural network hΘ that maps edge features, edge_attr, of shape [-1, num_edge_features] to shape [-1, out_channels]
  • aggr (string, optional): The message aggregation scheme to use ("add", "mean", "max")
  • root_lstm (bool, optional): If set to False, the layer will not add the LSTM-transformed root node features to the output
  • bias (bool, optional): If set to False, the layer will not learn an additive bias
  • **kwargs (optional): Additional arguments for torch.nn.LSTM
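The aggr parameter decides how the messages arriving at a node are combined. The hypothetical helper `aggregate` below is a plain-Python illustration of the three listed schemes, not SeqConv or PyTorch Geometric code. It follows PyTorch Geometric's default convention that edge_index[0] holds source nodes and edge_index[1] holds target nodes, and, as a choice of this sketch, gives zeros to nodes with no incoming edges.

```python
from collections import defaultdict

def aggregate(edge_index, messages, num_nodes, aggr="add"):
    """Combine per-edge message vectors at each target node (hypothetical helper).

    edge_index: [sources, targets] as two equal-length lists.
    messages: one vector (list of floats) per edge, in edge order.
    """
    incoming = defaultdict(list)
    for target, msg in zip(edge_index[1], messages):
        incoming[target].append(msg)
    dim = len(messages[0])
    out = []
    for i in range(num_nodes):
        msgs = incoming.get(i, [])
        if not msgs:
            out.append([0.0] * dim)  # no neighbors: zeros (a choice of this sketch)
        elif aggr == "add":
            out.append([sum(col) for col in zip(*msgs)])
        elif aggr == "mean":
            out.append([sum(col) / len(msgs) for col in zip(*msgs)])
        elif aggr == "max":
            out.append([max(col) for col in zip(*msgs)])
        else:
            raise ValueError(f"unknown aggr: {aggr}")
    return out

# Same edges as the example usage; one hypothetical 2-dimensional message per edge.
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(aggregate(edge_index, messages, 3, "add"))   # [[3.0, 4.0], [8.0, 10.0], [5.0, 6.0]]
print(aggregate(edge_index, messages, 3, "mean"))  # [[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]]
print(aggregate(edge_index, messages, 3, "max"))   # [[3.0, 4.0], [7.0, 8.0], [5.0, 6.0]]
```

Node 1 is the only node with two incoming edges (from nodes 0 and 2), so it is where the three schemes differ.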

Example Usage:

import torch
from seq_conv import SeqConv

# Convolutional layer
conv_layer = SeqConv(
    in_channels=1,
    out_channels=5,
    edge_nn=torch.nn.Linear(2, 5)
)

# Your input graph (see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs)
x = torch.randn((3, 12, 1), dtype=torch.float) # Shape is [num_nodes, seq_len, in_channels]
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)
edge_attr = torch.randn((4, 2), dtype=torch.float) # Shape is [num_edges, num_edge_features]; must be floating point for torch.nn.Linear

# Your output graph
x = conv_layer(x, edge_index, edge_attr) # Shape is now [3, 5]

To-Do: Allow stacking of SeqConv layers.



Download files


Source Distribution

seq_conv-1.0.1.tar.gz (3.5 kB)
