
A PyTorch implementation of Graph Attention Networks, with experimental speedup features.

Project description

PyTorch Fast GAT Implementation

Fast GAT

This is my implementation of an old paper, Graph Attention Networks. However, instead of a standard implementation, this one introduces several techniques to speed up the process, which are described below.

Installation

pip install <NAME-TBA>

Alternatively,

git clone https://github.com/tatp22/pytorch-fast-GAT.git
cd pytorch-fast-GAT

What makes this repo faster?

What is great about this paper is that, besides its state-of-the-art performance on a number of benchmarks, it can be applied to any graph, regardless of its structure. However, the algorithm's runtime depends on the number of edges, and when the graph is dense, this means it can run in O(n^2) time, where n is the number of nodes.

Most sparsification techniques for graphs rely on somehow decreasing the number of edges. However, I will try out a different method: reducing the number of nodes in the internal representation. This will be done similarly to how the Linformer decreases the memory requirement of its internal matrices, which is by applying a learned projection matrix that transforms the input. A challenge here is that since this is a graph, not all nodes connect to all other nodes. My plan is to explore techniques to reduce the size of the graph (the number of nodes, that is), pass it into the GAT, and then upscale it back to the original size.

Seeing that sparse attention has been shown to perform just as well as traditional attention, could the same hold for graphs? I will run some experiments and see if this is indeed the case.

This is not yet implemented.

Note: This idea has not been tested. I do not know what its performance will be on real life applications, and it may or may not provide accurate results.
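To make the idea above a bit more concrete, here is a minimal sketch of what such a Linformer-style node reduction could look like. It is purely illustrative and not part of the package; the names NodeDownsample and NodeUpsample, and the choice of a single learned projection matrix, are my own assumptions, and the sketch only handles node features, leaving open the harder question of what happens to the edges.

import torch
import torch.nn as nn

class NodeDownsample(nn.Module):
    """Hypothetical sketch: project n node feature vectors down to k << n."""
    def __init__(self, num_nodes, reduced_nodes):
        super().__init__()
        # Learned (k, n) projection, analogous to Linformer's projection matrices
        self.proj = nn.Parameter(torch.randn(reduced_nodes, num_nodes) / num_nodes ** 0.5)

    def forward(self, nodes):
        # nodes: (n, d) -> (k, d)
        return self.proj @ nodes

class NodeUpsample(nn.Module):
    """Hypothetical sketch: map the k reduced nodes back to the original n."""
    def __init__(self, num_nodes, reduced_nodes):
        super().__init__()
        self.proj = nn.Parameter(torch.randn(num_nodes, reduced_nodes) / reduced_nodes ** 0.5)

    def forward(self, nodes):
        # nodes: (k, d) -> (n, d)
        return self.proj @ nodes

# Toy usage: shrink 1000 nodes to 64, run something in between, then restore the count
down, up = NodeDownsample(1000, 64), NodeUpsample(1000, 64)
x = torch.randn(1000, 16)
restored = up(down(x))  # shape (1000, 16)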

Code Example

Right now, there exist two different versions of the GAT: one for sparse graphs and one for dense graphs. The idea in the end is to use only the dense version, since the sparse version runs slower. However, the dense version cannot currently be used on very large graphs, since it creates a matrix of size (n, n), which quickly exhausts the system's memory; with n = 100,000 nodes, for example, a float32 matrix of that size already takes about 40 GB.

As an example, this is how to use the sparse version:

import torch
from fast_gat import GraphAttentionNetwork

# Node features: 4 nodes, each with a 3-dimensional feature vector
nodes = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]], dtype=torch.float)
# Directed adjacency: node index -> set of neighbor indices (zero-indexed)
edges = {0: {1, 2}, 1: {0, 2, 3}, 2: {0, 1}, 3: {1}}

depth = 3
heads = 3
input_dim = 3
inner_dim = 2

net = GraphAttentionNetwork(depth, heads, input_dim, inner_dim)

output = net(nodes, edges)
print(output)

Note that the modules assume the graph is directed and that the edges have already been processed so that the nodes are zero-indexed.
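If your raw edge list uses arbitrary node identifiers, a small amount of preprocessing is needed first. The snippet below is one possible way to do it (the names raw_edges and ids are just for illustration, not part of the package):

# Arbitrary node ids, directed edges
raw_edges = [("a", "b"), ("b", "c"), ("c", "a")]
# Map every node id to a consecutive integer starting at 0
ids = {node: i for i, node in enumerate(sorted({n for e in raw_edges for n in e}))}
# Build the adjacency dict expected by the module: node index -> set of neighbor indices
edges = {}
for src, dst in raw_edges:
    edges.setdefault(ids[src], set()).add(ids[dst])
# edges == {0: {1}, 1: {2}, 2: {0}}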

Downsampling method

The main thing I am experimenting with here is reducing the number of vertices in the input graph while keeping the edges connected in a way that makes sense. Some things that might work:

  • A learned iterative process; that is, some function f: V x V -> V that takes two nodes and merges them into one, run on the graph over several iterations. This could probably just be a single learned linear layer. The challenge here would be keeping the same procedure for upsampling, as well as preserving the edges somehow.
  • A fixed method related to the rank of each node. That is, let's say an edge connects v_i and v_j; then, by some fixed rule, the two are joined and later split again.

The learned iterative process looks more appealing; further work will be done to look into it.
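As a rough illustration of that first option, the sketch below implements f: V x V -> V as a single learned linear layer over the concatenation of two node vectors. The name PairMerge and everything else here are hypothetical; how to pick which pairs to merge, how to carry the edges along, and how to reverse the operation for upsampling are exactly the open questions mentioned above.

import torch
import torch.nn as nn

class PairMerge(nn.Module):
    """Hypothetical f: V x V -> V, merging two node vectors into one."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, v_i, v_j):
        # Concatenate the two node feature vectors and project back to one node vector
        return self.linear(torch.cat([v_i, v_j], dim=-1))

# Toy usage: merge nodes 0 and 1 of a (4, 3) node feature matrix
merge = PairMerge(dim=3)
nodes = torch.randn(4, 3)
merged = merge(nodes[0], nodes[1])  # shape (3,)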

TODO

Further work that needs to be done

  • Create some sort of downsampling/upsampling method
  • Figure out how to get this onto PyPI
  • Optional: Run a test suite on every CI run? Could be cool to have CI/CD here somehow

Citation

@misc{veličković2018graph,
      title={Graph Attention Networks}, 
      author={Petar Veličković and Guillem Cucurull and Arantxa Casanova and Adriana Romero and Pietro Liò and Yoshua Bengio},
      year={2018},
      eprint={1710.10903},
      archivePrefix={arXiv},
      primaryClass={stat.ML}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fast_gat-0.1.0.tar.gz (6.2 kB)

Uploaded Source

Built Distribution

fast_gat-0.1.0-py3-none-any.whl (7.3 kB)

Uploaded Python 3

File details

Details for the file fast_gat-0.1.0.tar.gz.

File metadata

  • Download URL: fast_gat-0.1.0.tar.gz
  • Upload date:
  • Size: 6.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for fast_gat-0.1.0.tar.gz

  • SHA256: 86ced2fc13bb7bf22a3c992986e9fe1c5dfe9cfb6745268347b3ce3ae4d8b293
  • MD5: 0c5c48680f38a315fac8fcdd30b47b4a
  • BLAKE2b-256: 5c3f8467f8775f4575b0158518731f810ef738fa526bd0aa53acce54af8afa70


File details

Details for the file fast_gat-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: fast_gat-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 7.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for fast_gat-0.1.0-py3-none-any.whl

  • SHA256: f75dea5b25ce0f5acab14519caf9855f5de3d854471d45027bce5747d8c69274
  • MD5: 6540a20a880d513fe1de9cb5dc4ec102
  • BLAKE2b-256: 806686be6489c8fb44104e17446ce78ca8d83eb53fb424fcc449c40e124350d3

