Graph convolutional operator that uses an LSTM as a filter
SeqConv
SeqConv is a PyTorch implementation of a graph convolutional operator that uses long short-term memory (LSTM) networks to update node embeddings. This is useful for graph datasets where each node represents a sequence, such as a time series.
Here, φr and φm are LSTMs (torch.nn.LSTM) and hΘ is a neural network. Each LSTM contributes only its last hidden state, hn, rather than its full sequence of output features.
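To make the distinction between an LSTM's output features and its last hidden state concrete, the snippet below compares the two values returned by torch.nn.LSTM. It only demonstrates the standard PyTorch API; it is not code from seq_conv, and the sizes are arbitrary (chosen to match the usage example further down).

```python
import torch
from torch import nn

torch.manual_seed(0)
# Arbitrary sizes: 3 node sequences, each of length 12 with 1 channel per step
lstm = nn.LSTM(input_size=1, hidden_size=5, batch_first=True)
x = torch.randn(3, 12, 1)  # [num_nodes, seq_len, in_channels]

output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([3, 12, 5]): features for every time step
print(h_n.shape)     # torch.Size([1, 3, 5]): [num_layers, num_nodes, hidden_size]

# Final hidden state of the last (here, only) layer: one vector per node
node_embeddings = h_n[-1]  # [3, 5]
```

For a single-layer, unidirectional LSTM, output[:, -1] and h_n[-1] hold the same values; keeping only hn is what reduces each node's sequence to a single embedding.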
Installation
This module can be installed with pip:
$ pip install seq_conv
Usage
SeqConv is built on PyTorch Geometric and derives from the MessagePassing class. It expects an input graph in which each node's features are a sequence of vectors. Like NNConv, SeqConv also incorporates any available edge features when collecting messages from a node's neighbors.
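The message-passing idea can be sketched in plain PyTorch without PyTorch Geometric. This is a hypothetical re-implementation for illustration only, not the package's code: the class name SeqConvSketch is made up, and it assumes that each message is the elementwise product φm(xj) · hΘ(eji), that messages are summed (the "add" scheme), and that the root term φr(xi) is added to the aggregate. Edges follow PyTorch Geometric's default source-to-target convention (row 0 of edge_index is the source, row 1 the target).

```python
import torch
from torch import nn


class SeqConvSketch(nn.Module):
    """Hypothetical sketch, NOT the seq_conv implementation.

    Assumed update: x_i' = phi_r(x_i) + sum_{j -> i} phi_m(x_j) * h_theta(e_ji)
    (elementwise product, "add" aggregation, no bias term).
    """

    def __init__(self, in_channels: int, out_channels: int, edge_nn: nn.Module):
        super().__init__()
        self.phi_r = nn.LSTM(in_channels, out_channels, batch_first=True)  # root LSTM
        self.phi_m = nn.LSTM(in_channels, out_channels, batch_first=True)  # message LSTM
        self.edge_nn = edge_nn  # h_theta: [num_edges, num_edge_features] -> [num_edges, out_channels]

    @staticmethod
    def last_hidden(lstm: nn.LSTM, seqs: torch.Tensor) -> torch.Tensor:
        # nn.LSTM returns (output, (h_n, c_n)); keep the last layer's final hidden state
        _, (h_n, _) = lstm(seqs)
        return h_n[-1]  # [batch, out_channels]

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # row 0: source node j, row 1: target node i
        root = self.last_hidden(self.phi_r, x)  # [num_nodes, out_channels]
        # Encode every node's sequence once, then gather the encoding of each edge's source
        messages = self.last_hidden(self.phi_m, x)[src] * self.edge_nn(edge_attr)
        # Sum the messages arriving at each target node
        aggregated = torch.zeros_like(root).index_add(0, dst, messages)
        return root + aggregated
```

Encoding each node's sequence once and then indexing by source node avoids running the message LSTM separately for every edge, which matters when nodes have many neighbors.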
Parameters:
- in_channels (int): Number of channels in the input node sequence (e.g. if each node has a sequence of vectors of size n associated with it, then in_channels = n)
- out_channels (int): Number of channels in the output node embedding
- edge_nn (torch.nn.Module): A neural network hΘ that maps edge features, edge_attr, of shape [-1, num_edge_features] to shape [-1, out_channels]
- aggr (string, optional): The message aggregation scheme to use ("add", "mean", "max")
- root_lstm (bool, optional): If set to False, the layer will not add the LSTM-transformed root node features to the output
- bias (bool, optional): If set to False, the layer will not learn an additive bias
- **kwargs (optional): Additional arguments for torch.nn.LSTM
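To illustrate what the three aggr choices do, the sketch below combines per-edge messages at their target nodes using Tensor.scatter_reduce (available in PyTorch 1.12 and later). The helper name aggregate is hypothetical and not part of seq_conv; nodes that receive no messages are left at zero here, which is an assumption of this sketch.

```python
import torch


def aggregate(messages: torch.Tensor, dst: torch.Tensor, num_nodes: int,
              aggr: str = "add") -> torch.Tensor:
    """Hypothetical helper: reduce [num_edges, C] messages onto [num_nodes, C]."""
    reduce = {"add": "sum", "mean": "mean", "max": "amax"}[aggr]
    out = torch.zeros(num_nodes, messages.size(-1), dtype=messages.dtype)
    # scatter_reduce needs an index with the same shape as the source tensor
    index = dst.unsqueeze(-1).expand_as(messages)
    # include_self=False: the zero initial values do not take part in the reduction
    return out.scatter_reduce(0, index, messages, reduce=reduce, include_self=False)


messages = torch.tensor([[1.0], [3.0], [5.0]])  # one message per edge
dst = torch.tensor([0, 0, 1])                   # target node of each edge
for scheme in ("add", "mean", "max"):
    print(scheme, aggregate(messages, dst, num_nodes=3, aggr=scheme).flatten().tolist())
```

With these inputs, node 0 receives 1 and 3 (sum 4, mean 2, max 3), node 1 receives 5, and node 2 receives nothing.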
Example Usage:
import torch
from seq_conv import SeqConv
# Convolutional layer
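conv_layer = SeqConv(
    in_channels=1,
    out_channels=5,
    edge_nn=torch.nn.Linear(2, 5)  # maps 2 edge features to out_channels = 5
)
# Your input graph (see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs)
x = torch.randn((3, 12, 1), dtype=torch.float) # Shape is [num_nodes, seq_len, in_channels]
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)
# Edge features must be floating point: torch.randn does not support integer dtypes,
# and torch.nn.Linear expects float inputs
edge_attr = torch.randn((4, 2), dtype=torch.float) # Shape is [num_edges, num_edge_features]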
# Your output graph
x = conv_layer(x, edge_index, edge_attr) # Shape is now [3, 5]
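To-Do: Allow stacking of SeqConv layers. Each layer currently outputs a single vector per node (shape [num_nodes, out_channels]) rather than a sequence, so its output cannot be fed directly into another SeqConv.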