
General Base Layers for Graph Convolutions with Keras



Keras Graph Convolution Neural Networks

General | Requirements | Installation | Documentation | Implementation details | Literature | Data | Datasets | Training | Issues | Citing | References

General

The package kgcnn contains several layer classes to build graph convolution models in Keras with TensorFlow, PyTorch or Jax as backend. Some models from the literature are given as examples. Documentation is generated in docs. The focus of kgcnn is (batched) graph learning for molecules kgcnn.molecule and materials kgcnn.crystal. If you want to get in contact, feel free to open a discussion.

Note that kgcnn>=4.0.0 requires keras>=3.0.0. Previous versions of kgcnn were focused on ragged tensors of TensorFlow; hyperparameters of models for these versions should also transfer to kgcnn 4.0 by adding input_tensor_type: "ragged" and checking the order and dtype of the inputs.
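
For example, a migrated input section of such a hyperparameter config might look as follows (a sketch only; the input names, shapes and order are illustrative and depend on the respective model):

# Hypothetical hyperparameter snippet for migrating a kgcnn<4 ragged-tensor model.
model_config = {
    "inputs": [
        {"shape": (None, 64), "name": "node_attributes", "dtype": "float32", "ragged": True},
        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True},
    ],
    "input_tensor_type": "ragged",  # restore the pre-4.0 ragged input behaviour
}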

Requirements

Standard Python package requirements are installed automatically. However, you must make sure to install GPU/TPU acceleration for the backend of your choice.

Installation

Clone the repository or the latest release and install it in editable mode, or install the latest release via the Python Package Index:

pip install kgcnn

Documentation

Auto-documentation is generated at https://kgcnn.readthedocs.io/en/latest/index.html.

Implementation details

Representation

A graph of N nodes and M edges is commonly represented by a list of node or edge attributes, node_attr or edge_attr, respectively, plus a list of index pairs (i, j) that represent the directed edges of the graph: edge_index. The feature dimension of the attributes is denoted by F. Alternatively, an adjacency matrix A_ij of shape (N, N) can be used, which has 'one' entries where there is an edge between nodes and 'zeros' elsewhere. Consequently, the sum over A_ij gives the number of edges M.
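
As a concrete illustration, a toy graph with N=3 nodes, M=4 directed edges and F=2 features could be written down as follows (a minimal sketch using numpy):

import numpy as np

node_attr = np.array([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]])  # node attributes, shape (N, F)
edge_attr = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])  # edge attributes, shape (M, F)
edge_index = np.array([[0, 1], [1, 0], [1, 2], [2, 1]])  # index pairs (i, j), shape (M, 2)

# Equivalent adjacency matrix of shape (N, N); its sum recovers the number of edges M.
A = np.zeros((3, 3))
A[edge_index[:, 0], edge_index[:, 1]] = 1.0
assert int(A.sum()) == len(edge_index)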

Input

For learning on batches or single graphs, one of the following tensor representations can be chosen:

Batched Graphs
  • node_attr: Node attributes of shape (batch, N, F) and dtype float
  • edge_attr: Edge attributes of shape (batch, M, F) and dtype float
  • edge_index: Indices of shape (batch, M, 2) and dtype int
  • graph_attr: Graph attributes of shape (batch, F) and dtype float

Graphs are stacked along the batch dimension batch. Note that for flexibly sized graphs the tensors have to be padded up to a maximum N/M, or ragged tensors with a ragged rank of one are used.
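
For instance, a padded batch of two graphs with 2 and 3 nodes could look like this (a sketch with made-up values):

import numpy as np

node_attr = np.array([
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],  # graph 0 with 2 nodes and one padded row
    [[0.5, 0.6], [0.7, 0.8], [0.9, 1.0]],  # graph 1 with 3 nodes
])  # shape (batch=2, N=3, F=2)
total_nodes = np.array([2, 3])  # true node counts, needed to undo the padding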

Disjoint Graphs
  • node_attr: Node attributes of shape ([N], F) and dtype float
  • edge_attr: Edge attributes of shape ([M], F) and dtype float
  • edge_index: Indices of shape (2, [M]) and dtype int
  • batch_ID: Graph ID of shape ([N], ) and dtype int

Here, the tensors essentially represent a single graph that consists of the disjoint sub-graphs of the batch, a representation introduced by PyTorch Geometric (PyG). For pooling, the graph assignment is stored in batch_ID. Note that since Jax can not handle dynamic shapes, a padded disjoint representation is used, which assigns all padded nodes to a discarded graph with index zero.
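
The same batch of two graphs with 2 and 3 nodes in disjoint representation would read (a sketch):

import numpy as np

node_attr = np.random.rand(5, 2)  # all 2+3 nodes concatenated, shape ([N], F)
edge_index = np.array([[0, 1, 3, 4], [1, 0, 4, 3]])  # shape (2, [M]); indices of graph 1 are shifted by its node offset 2
batch_ID = np.array([0, 0, 1, 1, 1])  # node-to-graph assignment used for pooling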

Model

The keras layers in kgcnn.layers can be used with a PyG-compatible tensor representation, or even by simply wrapping a PyG model with TorchModuleWrapper. Efficient model loading can be achieved in multiple ways (see kgcnn.io). For the most simple keras-like behaviour, the model can be fed with batched padded or ragged tensors, which are converted to and from the disjoint representation that wraps the PyG-equivalent model. Here is an example of a minimal message-passing GNN:

import keras as ks
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.aggr import AggregateLocalEdges

# Example for padded input.
ns = ks.layers.Input(shape=(None, 64), dtype="float32", name="node_attributes")
e_idx = ks.layers.Input(shape=(None, 2), dtype="int64", name="edge_indices")
total_n = ks.layers.Input(shape=(), dtype="int64", name="total_nodes")  # Or mask
total_e = ks.layers.Input(shape=(), dtype="int64", name="total_edges")  # Or mask

n, idx, batch_id, _, _, _, _, _ = CastBatchedIndicesToDisjoint(uses_mask=False)([ns, e_idx, total_n, total_e])  # cast to disjoint
n_in_out = GatherNodes()([n, idx])  # gather node features for both ends of each edge
node_messages = ks.layers.Dense(64, activation='relu')(n_in_out)  # compute edge messages
node_updates = AggregateLocalEdges()([n, node_messages, idx])  # aggregate messages per node
n_node_updates = ks.layers.Concatenate()([n, node_updates])  # combine old state and update
n_embedding = ks.layers.Dense(1)(n_node_updates)
g_embedding = PoolingNodes()([total_n, n_embedding, batch_id])  # pool nodes per graph

message_passing = ks.models.Model(inputs=[ns, e_idx, total_n, total_e], outputs=g_embedding)
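
Assuming the padding convention above, the model can be called on dummy data like this (a minimal sketch; values are random and the padded edge entries beyond the true counts are ignored):

import numpy as np

nodes = np.random.rand(2, 3, 64).astype("float32")  # batch of 2 graphs padded to 3 nodes
indices = np.array([[[0, 1], [1, 0], [1, 2], [2, 1]],
                    [[0, 1], [1, 0], [0, 0], [0, 0]]], dtype="int64")  # padded edge indices
num_nodes = np.array([3, 2], dtype="int64")  # true node count per graph
num_edges = np.array([4, 2], dtype="int64")  # true edge count per graph

out = message_passing.predict([nodes, indices, num_nodes, num_edges])
print(out.shape)  # one embedding per graph, e.g. (2, 1)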

The actual message-passing model can be structured further, e.g. by subclassing the message-passing base layer:

import keras as ks
from kgcnn.layers.message import MessagePassingBase

class MyMessageNN(MessagePassingBase):

    def __init__(self, units, **kwargs):
        super(MyMessageNN, self).__init__(**kwargs)
        self.dense = ks.layers.Dense(units)
        self.add = ks.layers.Add()

    def message_function(self, inputs, **kwargs):
        # Messages are computed from the gathered node features; edge features are unused here.
        n_in, n_out, edges = inputs
        return self.dense(n_out, **kwargs)

    def update_nodes(self, inputs, **kwargs):
        # Residual update: add the aggregated messages to the previous node state.
        nodes, nodes_update = inputs
        return self.add([nodes, nodes_update], **kwargs)
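
The subclassed layer can then be called on disjoint tensors like any other keras layer. A minimal sketch, assuming the base layer is called with [nodes, edges, edge_index] in disjoint representation and indices of shape (2, [M]):

import numpy as np

nodes = np.random.rand(3, 16).astype("float32")  # disjoint node features, shape ([N], F)
edges = np.random.rand(4, 8).astype("float32")  # disjoint edge features, shape ([M], F)
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype="int64")  # shape (2, [M])

layer = MyMessageNN(units=16)  # units must match F for the residual Add to work
out = layer([nodes, edges, edge_index])  # updated node features of shape ([N], 16)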

Literature

The following models proposed in the literature have a module in literature. Each module usually exposes a make_model function to create a keras.models.Model. The models may, but do not have to, be built completely from kgcnn.layers and can, for example, include original implementations (with proper licensing).

…and many more; see the module documentation for the full list of implemented models.

Data

Data handling classes are given in kgcnn.data, which stores graphs as List[Dict].

Graph dictionary

Graphs are represented by a dictionary GraphDict of (numpy) arrays which behaves like a python dict. There are graph pre- and postprocessors in kgcnn.graph which take specific properties by name and apply a processing function or transformation.

[!IMPORTANT]
They can perform any operation, but note that GraphDict does not impose an actual graph structure! For example, when sorting edge indices, make sure that all edge attributes are sorted accordingly.

from kgcnn.graph import GraphDict
# Single graph.
graph = GraphDict({"edge_indices": [[1, 0], [0, 1]], "node_label": [[0], [1]]})
graph.set("graph_labels", [0])  # use set(), get() to assign (tensor) properties.
graph.set("edge_attributes", [[1.0], [2.0]])
graph.to_networkx()
# Modify with e.g. preprocessor.
from kgcnn.graph.preprocessor import SortEdgeIndices
SortEdgeIndices(edge_indices="edge_indices", edge_attributes="^edge_(?!indices$).*", in_place=True)(graph)

List of graph dictionaries

A MemoryGraphList should behave identically to a python list but contain only GraphDict items.

from kgcnn.data import MemoryGraphList
# List of graph dicts.
graph_list = MemoryGraphList([{"edge_indices": [[0, 1], [1, 0]]}, {"edge_indices": [[0, 0]]}, {}])
graph_list.clean(["edge_indices"])  # Remove graphs without property
graph_list.get("edge_indices")  # opposite is set()
# Easily cast to tensor; makes copy.
tensor = graph_list.tensor([{"name": "edge_indices"}])  # config of keras `Input` layer
# Or directly modify list.
for i, x in enumerate(graph_list):
    x.set("graph_number", [i])
print(len(graph_list), graph_list[:2])  # Also supports indexing lists.

Datasets

The MemoryGraphDataset inherits from MemoryGraphList but must be initialized with file information on disk that points to a data_directory for the dataset. The data_directory can have a subdirectory for files and/or a single file such as a CSV file:

├── data_directory
    ├── file_directory
       ├── *.*
       └── ... 
    ├── file_name
    └── dataset_name.kgcnn.pickle

A base dataset class is created with path and name information:

from kgcnn.data import MemoryGraphDataset
dataset = MemoryGraphDataset(data_directory="ExampleDir/", 
                             dataset_name="Example",
                             file_name=None, file_directory=None)
dataset.save()  # opposite is load(). 

The subclasses QMDataset, ForceDataset, MoleculeNetDataset, CrystalDataset and GraphTUDataset further provide the functions required for the specific dataset type to convert and process files such as '.txt', '.sdf', '.xyz' etc. Most subclasses implement prepare_data() and read_in_memory() with dataset-dependent arguments. An example for MoleculeNetDataset is shown below. For more details, find tutorials in notebooks.

from kgcnn.data.moleculenet import MoleculeNetDataset
# File directory and files must exist. 
# Here 'ExampleDir' and 'ExampleDir/data.csv' with columns "smiles" and "label".
dataset = MoleculeNetDataset(dataset_name="Example",
                             data_directory="ExampleDir/",
                             file_name="data.csv")
dataset.prepare_data(overwrite=True, smiles_column_name="smiles", add_hydrogen=True,
                     make_conformers=True, optimize_conformer=True, num_workers=None)
dataset.read_in_memory(label_column_name="label", add_hydrogen=False, 
                       has_conformers=True)

In data.datasets there are graph learning benchmark datasets as subclasses, which are downloaded from popular graph archives such as TUDatasets, MatBench or MoleculeNet. The subclasses GraphTUDataset2020, MatBenchDataset2020 and MoleculeNetDataset2018 download and read the available datasets by name. There are also specific dataset subclasses for each dataset to handle additional processing or downloading from individual sources:

from kgcnn.data.datasets.MUTAGDataset import MUTAGDataset
dataset = MUTAGDataset()  # inherits from GraphTUDataset2020
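
Since datasets behave like a MemoryGraphList, the loaded graphs can be cast to tensors directly, for example to feed a model; a minimal sketch (the property names below are assumed to exist for this dataset):

import numpy as np

# Property names such as "node_attributes" are dataset dependent and assumed here.
x = dataset.tensor([{"name": "node_attributes"}, {"name": "edge_indices"}])
y = np.array(dataset.get("graph_labels"))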

Downloaded datasets are stored in ~/.kgcnn/datasets on your computer. Please remove them manually if no longer required.

Training

A set of example training scripts can be found in training. The training scripts are configurable with a hyperparameter config file and command line arguments regarding model and dataset.
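
For illustration only, a training run could be started like this (script and flag names are hypothetical; consult the scripts and configs in training for the actual entry points):

python train_graph.py --hyper hyper/hyper_mutag.py --model GCN --dataset MUTAGDataset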

You can find a table of common benchmark datasets in results.

Issues

Some known issues to be aware of when using or building new models or layers with kgcnn:

  • Loading jagged or nested tensors into models is not working for the PyTorch backend.
  • The BatchNormalization layer does not support padding yet.
  • The Keras AUC metric does not seem to work for torch on CUDA.

Citing

If you want to cite this repo, please refer to our paper:

@article{REISER2021100095,
    title = {Graph neural networks in TensorFlow-Keras with RaggedTensor representation (kgcnn)},
    journal = {Software Impacts},
    pages = {100095},
    year = {2021},
    issn = {2665-9638},
    doi = {10.1016/j.simpa.2021.100095},
    url = {https://www.sciencedirect.com/science/article/pii/S266596382100035X},
    author = {Patrick Reiser and Andre Eberhard and Pascal Friederich}
}

References
