Graph convolutional operator that uses an LSTM as a filter
Project description
SeqConv
SeqConv is a PyTorch implementation of a graph convolutional operator that uses long short-term memory (LSTM) networks to update node embeddings. This is useful for graph datasets where each node represents a sequence, such as a time series.
In this operator, φr and φm are LSTMs (torch.nn.LSTM) and hΘ is a neural network. The output of each LSTM is its last hidden state, hn, rather than the full sequence of output features.
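To make the "last hidden state" point concrete, here is a minimal torch.nn.LSTM sketch. The sizes mirror the example usage; batch_first=True is an assumption made here so that the input is laid out as [num_nodes, seq_len, in_channels], and it is not taken from the seq_conv source.

```python
import torch

# Sizes mirroring the example usage: 3 nodes, each a sequence of 12 one-dimensional vectors
num_nodes, seq_len, in_channels, out_channels = 3, 12, 1, 5

# Assumption for this sketch: batch_first=True, so input is [batch, seq_len, features]
lstm = torch.nn.LSTM(in_channels, out_channels, batch_first=True)

x = torch.randn(num_nodes, seq_len, in_channels)
output, (h_n, c_n) = lstm(x)

# output: features at every time step, shape [num_nodes, seq_len, out_channels]
# h_n: only the final hidden state, shape [num_layers, num_nodes, out_channels]
embedding = h_n[-1]  # top layer's final hidden state, shape [num_nodes, out_channels]
```

Note that h_n is always laid out as [num_layers * num_directions, batch, hidden_size] regardless of batch_first, which is why h_n[-1] selects the top layer. For a single-direction LSTM it equals the last time step of output.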
Installation
This module can be installed with pip:
$ pip install seq_conv
Usage
SeqConv is built on PyTorch Geometric and derives from the MessagePassing class. It expects an input graph in which each node's features are a sequence of vectors. Like NNConv, SeqConv also incorporates any available edge features when collecting messages from a node's neighbors.
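The message-passing idea can be sketched in plain PyTorch, without PyTorch Geometric. ToySeqConv below is a hypothetical illustration, not the seq_conv implementation: it assumes each message is φm(xj) multiplied element-wise by hΘ(eji), uses "add" aggregation via index_add_, and adds a root term φr(xi) plus a bias, mirroring the root_lstm and bias parameters.

```python
import torch
from torch import nn


class ToySeqConv(nn.Module):
    """Hypothetical sketch for illustration only; not the seq_conv source."""

    def __init__(self, in_channels, out_channels, edge_nn):
        super().__init__()
        self.edge_nn = edge_nn  # h_Theta: [num_edges, num_edge_features] -> [num_edges, out_channels]
        self.message_lstm = nn.LSTM(in_channels, out_channels, batch_first=True)  # phi_m
        self.root_lstm = nn.LSTM(in_channels, out_channels, batch_first=True)     # phi_r
        self.bias = nn.Parameter(torch.zeros(out_channels))

    @staticmethod
    def last_hidden(lstm, seqs):
        # Keep only the final hidden state of the top layer: [batch, out_channels]
        _, (h_n, _) = lstm(seqs)
        return h_n[-1]

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index  # messages flow from source node j to target node i
        # Assumption: message = phi_m(x_j) scaled element-wise by h_Theta(e_ji)
        messages = self.last_hidden(self.message_lstm, x[src]) * self.edge_nn(edge_attr)
        # "add" aggregation: sum the incoming messages at each target node
        out = torch.zeros(x.size(0), messages.size(1), dtype=messages.dtype, device=messages.device)
        out.index_add_(0, dst, messages)
        # Root term and additive bias
        return out + self.last_hidden(self.root_lstm, x) + self.bias


torch.manual_seed(0)
layer = ToySeqConv(in_channels=1, out_channels=5, edge_nn=nn.Linear(2, 5))
x = torch.randn(3, 12, 1)                               # [num_nodes, seq_len, in_channels]
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # [2, num_edges], int64
edge_attr = torch.randn(4, 2)                            # [num_edges, num_edge_features]
out = layer(x, edge_index, edge_attr)                    # [num_nodes, out_channels]
```

Messages flow from edge_index[0] (source) to edge_index[1] (target), which matches PyTorch Geometric's default source_to_target flow.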
Parameters:
- in_channels (int): Number of channels in the input node sequence (e.g. if each node has a sequence of vectors of size n associated with it, then in_channels = n)
- out_channels (int): Number of channels in the output node embedding
- edge_nn (torch.nn.Module): A neural network hΘ that maps edge features, edge_attr, of shape [-1, num_edge_features] to shape [-1, out_channels]
- aggr (string, optional): The message aggregation scheme to use ("add", "mean", "max")
- root_lstm (bool, optional): If set to False, the layer will not add the LSTM-transformed root node features to the output
- bias (bool, optional): If set to False, the layer will not learn an additive bias
- **kwargs (optional): Additional arguments for torch.nn.LSTM
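The three aggr options differ only in how the per-edge messages arriving at a node are combined. The aggregate function below is a hypothetical, dependency-free sketch of those reductions; giving nodes with no incoming edges a zero vector is an assumption of this sketch.

```python
from collections import defaultdict


def aggregate(messages, targets, num_nodes, aggr="add"):
    """Hypothetical sketch: combine per-edge message vectors at their target nodes."""
    grouped = defaultdict(list)
    for msg, target in zip(messages, targets):
        grouped[target].append(msg)

    dim = len(messages[0])
    out = []
    for node in range(num_nodes):
        msgs = grouped.get(node)
        if not msgs:
            out.append([0.0] * dim)  # no incoming edges: zero vector in this sketch
            continue
        columns = list(zip(*msgs))  # one tuple per feature dimension
        if aggr == "add":
            out.append([sum(c) for c in columns])
        elif aggr == "mean":
            out.append([sum(c) / len(c) for c in columns])
        elif aggr == "max":
            out.append([max(c) for c in columns])
        else:
            raise ValueError(f"unknown aggr: {aggr!r}")
    return out


# Four edges whose targets match the example graph's edge_index[1]
messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 0.0]]
targets = [1, 0, 2, 1]
```

Node 1 receives two messages, so it is the only node where the three schemes give different results.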
Example Usage:
import torch
from seq_conv import SeqConv
# Convolutional layer
conv_layer = SeqConv(
    in_channels=1,
    out_channels=5,
    edge_nn=torch.nn.Linear(2, 5)
)
# Your input graph (see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html#data-handling-of-graphs)
x = torch.randn((3, 12, 1), dtype=torch.float) # Shape is [num_nodes, seq_len, in_channels]
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long) # Shape is [2, num_edges]
edge_attr = torch.randn((4, 2), dtype=torch.float) # Shape is [num_edges, num_edge_features]; must be float for edge_nn
# Output node embeddings
x = conv_layer(x, edge_index, edge_attr) # Shape is now [3, 5]
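A common failure with this kind of layer is a shape or dtype mismatch. The check_seqconv_inputs helper below is hypothetical (not part of seq_conv) and simply encodes the shapes stated in the example usage: x as [num_nodes, seq_len, in_channels], edge_index as [2, num_edges] integer indices, and edge_attr as [num_edges, num_edge_features] floats, since edge_attr is passed to edge_nn.

```python
import torch


def _require(condition, message):
    # Raise instead of using assert so the checks still run under python -O
    if not condition:
        raise ValueError(message)


def check_seqconv_inputs(x, edge_index, edge_attr, in_channels, num_edge_features):
    """Hypothetical helper: validate tensors against the shapes used in the example usage."""
    _require(x.dim() == 3 and x.size(2) == in_channels,
             "x must have shape [num_nodes, seq_len, in_channels]")
    _require(edge_index.dim() == 2 and edge_index.size(0) == 2,
             "edge_index must have shape [2, num_edges]")
    _require(edge_index.dtype == torch.long,
             "edge_index must hold integer (torch.long) node indices")
    _require(edge_index.numel() == 0 or int(edge_index.max()) < x.size(0),
             "edge_index refers to a node index outside x")
    _require(tuple(edge_attr.shape) == (edge_index.size(1), num_edge_features),
             "edge_attr must have shape [num_edges, num_edge_features]")
    _require(edge_attr.is_floating_point(),
             "edge_attr is passed to edge_nn, so it must be floating point")


x = torch.randn(3, 12, 1)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
edge_attr = torch.randn(4, 2)
check_seqconv_inputs(x, edge_index, edge_attr, in_channels=1, num_edge_features=2)  # passes silently
```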
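To-Do: Allow stacking of SeqConv layers. Each layer currently returns one out_channels vector per node, whereas SeqConv expects a sequence of vectors per node as input.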
Download files
Source Distribution
File details
Details for the file seq_conv-1.0.2.tar.gz.
File metadata
- Download URL: seq_conv-1.0.2.tar.gz
- Upload date:
- Size: 3.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/46.0.0 requests-toolbelt/0.9.1 tqdm/4.36.0 CPython/3.7.8
File hashes
Algorithm | Hash digest
---|---
SHA256 | 61ed61e332d82ab2ccb54569ee75ecbdcd5d0b4c3b0c44397f3ae5e5f057eb04
MD5 | b4f78f5ecf09cbe1a86ce4a8c9ef688e
BLAKE2b-256 | 734320d931366f1db4ff5823444786ba7625778379ad66d2c16bf5a0d385d6db
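The digests can be used to verify a downloaded archive. Below is a minimal sketch using Python's hashlib; the helper name file_digest_hex and the "blake2b-256" label are illustrative choices, with BLAKE2b-256 taken to mean BLAKE2b with a 32-byte digest.

```python
import hashlib


def file_digest_hex(path, algorithm="sha256", chunk_size=1 << 16):
    """Hash a file in chunks so large downloads need not fit in memory."""
    if algorithm == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)  # BLAKE2b with a 256-bit digest
    else:
        h = hashlib.new(algorithm)  # e.g. "sha256" or "md5"
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


EXPECTED_SHA256 = "61ed61e332d82ab2ccb54569ee75ecbdcd5d0b4c3b0c44397f3ae5e5f057eb04"
# After downloading the sdist:
# file_digest_hex("seq_conv-1.0.2.tar.gz") == EXPECTED_SHA256
```

pip can also enforce this itself in hash-checking mode, for example with a requirements line `seq_conv==1.0.2 --hash=sha256:<digest>` installed via `pip install --require-hashes -r requirements.txt`.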