General Base Layers for Graph Convolutions with Keras
Keras Graph Convolution Neural Networks
General
The package kgcnn contains several layer classes to build graph convolution models in Keras with TensorFlow, PyTorch or Jax as backend. Some models from the literature are provided as examples in `kgcnn.literature`. Documentation is generated in docs. The focus of kgcnn is (batched) graph learning for molecules (`kgcnn.molecule`) and materials (`kgcnn.crystal`). If you want to get in contact, feel free to start a discussion.
Note that kgcnn>=4.0.0 requires keras>=3.0.0. Previous versions of kgcnn focused on TensorFlow ragged tensors. Hyperparameters for those models should also transfer to kgcnn 4.0 by adding `input_tensor_type: "ragged"` and checking the order and dtype of the inputs.
Requirements
Standard Python package requirements are installed automatically. However, you must make sure to install the GPU/TPU acceleration for the backend of your choice.
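Keras 3 picks its backend from the `KERAS_BACKEND` environment variable when `keras` is first imported (it can also be set in `~/.keras/keras.json`). A minimal sketch for choosing the backend before importing Keras or kgcnn; the helper `select_backend` is hypothetical and only restricts the choice to the three backends named in this README:

```python
import os

# Backends named in this README for kgcnn>=4.0.0.
SUPPORTED_BACKENDS = ("tensorflow", "torch", "jax")


def select_backend(name: str) -> str:
    """Hypothetical helper: set KERAS_BACKEND before `import keras` runs."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}, expected one of {SUPPORTED_BACKENDS}")
    # Keras reads this variable once at import time, so set it first.
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("torch")
# import keras
# print(keras.backend.backend())  # should then report "torch"
```

The matching accelerator build (e.g. a CUDA-enabled PyTorch or Jax wheel) still has to be installed separately.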
Installation
Clone the repository or download the latest release and install it in editable mode, or install the latest release from the Python Package Index:
```bash
pip install kgcnn
```
Documentation
Auto-generated documentation is available at https://kgcnn.readthedocs.io/en/latest/index.html.
Implementation details
Representation
A graph of $N$ nodes and $M$ edges is commonly represented by a list of node or edge attributes, `node_attr` or `edge_attr`, respectively, plus a list of index pairs $(i, j)$ that represent directed edges in the graph: `edge_index`. The feature dimension of the attributes is denoted by $F$.
Alternatively, the graph can be described by an adjacency matrix $A_{ij}$ of shape $(N, N)$, which has ones where there is an edge between two nodes and zeros elsewhere. Consequently, the sum over all entries, $\sum_{i,j} A_{ij} = M$, gives the number of edges.
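The two representations can be converted into each other directly. A plain-Python sketch (not kgcnn code), assuming there are no duplicate edges, since a duplicate would collapse into a single one-entry and the sum would no longer equal $M$:

```python
def edge_index_to_adjacency(edge_index, num_nodes):
    """Dense N x N adjacency matrix with A[i][j] = 1 for every directed edge (i, j)."""
    adjacency = [[0] * num_nodes for _ in range(num_nodes)]
    for i, j in edge_index:
        adjacency[i][j] = 1
    return adjacency


def adjacency_to_edge_index(adjacency):
    """Recover the (i, j) pairs from the non-zero entries, in row-major order."""
    return [(i, j) for i, row in enumerate(adjacency) for j, a in enumerate(row) if a]


# N = 3 nodes, M = 3 directed edges: 0->1, 1->0, 1->2.
edge_index = [(0, 1), (1, 0), (1, 2)]
A = edge_index_to_adjacency(edge_index, num_nodes=3)
num_edges = sum(sum(row) for row in A)  # sum over all A_ij equals M
print(A)          # [[0, 1, 0], [1, 0, 1], [0, 0, 0]]
print(num_edges)  # 3
```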
Input
For learning on batches or single graphs, the following tensor representations can be chosen:
Batched Graphs
- `node_attr`: Node attributes of shape `(batch, N, F)` and dtype float
- `edge_attr`: Edge attributes of shape `(batch, M, F)` and dtype float
- `edge_index`: Indices of shape `(batch, M, 2)` and dtype int
- `graph_attr`: Graph attributes of shape `(batch, F)` and dtype float
Graphs are stacked along the batch dimension `batch`. Note that for graphs of flexible size, the tensors have to be padded up to a maximum $N$/$M$, or ragged tensors with a ragged rank of one are used.
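Padding to a common size can be sketched in plain Python. This is an illustration, not kgcnn's own padding routine; the helper `pad_node_attributes` is hypothetical. It also returns the true node count per graph (the kind of value a `total_nodes` input carries) and a boolean mask:

```python
def pad_node_attributes(graphs, num_features, pad_value=0.0):
    """Pad per-graph node attribute lists to a common N_max.

    Returns the padded (batch, N_max, F) nested list, the true node count of
    each graph, and a boolean mask marking real (non-padding) nodes.
    """
    n_max = max(len(g) for g in graphs)
    padded, total_nodes, mask = [], [], []
    for g in graphs:
        num_pad = n_max - len(g)
        padding = [[pad_value] * num_features for _ in range(num_pad)]
        padded.append([list(row) for row in g] + padding)
        total_nodes.append(len(g))
        mask.append([True] * len(g) + [False] * num_pad)
    return padded, total_nodes, mask


# Two graphs with N = 2 and N = 3 nodes, F = 1.
graphs = [[[1.0], [2.0]], [[3.0], [4.0], [5.0]]]
padded, total_nodes, mask = pad_node_attributes(graphs, num_features=1)
print(padded)       # [[[1.0], [2.0], [0.0]], [[3.0], [4.0], [5.0]]]
print(total_nodes)  # [2, 3]
```

Either the counts or the mask tell downstream layers which entries are padding and must be ignored.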
Disjoint Graphs
- `node_attr`: Node attributes of shape `([N], F)` and dtype float
- `edge_attr`: Edge attributes of shape `([M], F)` and dtype float
- `edge_index`: Indices of shape `(2, [M])` and dtype int
- `batch_ID`: Graph ID of shape `([N], )` and dtype int
Here, the lists essentially represent one graph that consists of the disjoint sub-graphs of the batch, a representation introduced by PyTorch Geometric (PyG); `[N]` and `[M]` denote the total number of nodes and edges in the batch. For pooling, the graph assignment of each node is stored in `batch_ID`.
Note that Jax does not support dynamic shapes, so we use a padded disjoint representation that assigns all padding nodes to a discarded graph with index zero.
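The conversion from a batch of graphs to the disjoint representation boils down to shifting each graph's local edge indices by the number of nodes that come before it. A plain-Python sketch (not kgcnn's implementation); for the Jax-style variant it assumes that real graphs are numbered from 1 when padding is used, so that padding nodes can take graph ID 0, and for brevity it pads nodes only, not edges:

```python
def to_disjoint(node_lists, edge_lists, pad_nodes_to=None):
    """Merge a batch of graphs into one graph made of disjoint sub-graphs.

    node_lists:   per-graph lists of node attribute rows.
    edge_lists:   per-graph lists of (i, j) pairs with graph-local node indices.
    pad_nodes_to: optional static total node count (Jax-style). Padding nodes
                  get graph ID 0 and real graphs are numbered from 1 (assumed
                  convention). Edges are not padded in this sketch.
    Returns node_attr ([N], F), edge_index (2, [M]) and batch_ID ([N],).
    """
    node_attr, edge_index, batch_id = [], [[], []], []
    first_id = 0 if pad_nodes_to is None else 1
    offset = 0  # number of nodes in all previous graphs
    for graph_id, (nodes, edges) in enumerate(zip(node_lists, edge_lists), start=first_id):
        node_attr.extend(list(row) for row in nodes)
        batch_id.extend([graph_id] * len(nodes))
        for i, j in edges:
            # Shift local indices so they point into the merged node list.
            edge_index[0].append(i + offset)
            edge_index[1].append(j + offset)
        offset += len(nodes)
    if pad_nodes_to is not None:
        num_pad = pad_nodes_to - len(node_attr)
        if num_pad < 0:
            raise ValueError("pad_nodes_to is smaller than the total number of nodes")
        num_features = len(node_attr[0]) if node_attr else 0
        node_attr.extend([0.0] * num_features for _ in range(num_pad))
        batch_id.extend([0] * num_pad)
    return node_attr, edge_index, batch_id


nodes = [[[1.0], [2.0]], [[3.0], [4.0], [5.0]]]
edges = [[(0, 1), (1, 0)], [(0, 2)]]
x, idx, batch_id = to_disjoint(nodes, edges)
print(idx)       # [[0, 1, 2], [1, 0, 4]]
print(batch_id)  # [0, 0, 1, 1, 1]
x_pad, _, batch_id_pad = to_disjoint(nodes, edges, pad_nodes_to=7)
print(batch_id_pad)  # [1, 1, 2, 2, 2, 0, 0]
```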
Model
The Keras layers in `kgcnn.layers` can be used with PyG-compatible tensor representations, or even by simply wrapping a PyG model with `TorchModuleWrapper`. Efficient model loading can be achieved in multiple ways (see `kgcnn.io`).
For the simplest Keras-like behaviour, the model can be fed with batched padded or ragged tensors, which are converted to and from the disjoint representation internally, wrapping the PyG-equivalent model.
Here is an example of a minimal message passing GNN:
```python
import keras as ks
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.aggr import AggregateLocalEdges

# Example for padded input.
ns = ks.layers.Input(shape=(None, 64), dtype="float32", name="node_attributes")
e_idx = ks.layers.Input(shape=(None, 2), dtype="int64", name="edge_indices")
total_n = ks.layers.Input(shape=(), dtype="int64", name="total_nodes")  # Or mask
total_e = ks.layers.Input(shape=(), dtype="int64", name="total_edges")  # Or mask

# Cast the padded batch to the disjoint representation; batch_id assigns nodes to graphs.
n, idx, batch_id, _, _, _, _, _ = CastBatchedIndicesToDisjoint(uses_mask=False)([ns, e_idx, total_n, total_e])

# Gather the node features of each edge and turn them into messages.
n_in_out = GatherNodes()([n, idx])
node_messages = ks.layers.Dense(64, activation='relu')(n_in_out)
# Aggregate the messages per node and update the node embeddings.
node_updates = AggregateLocalEdges()([n, node_messages, idx])
n_node_updates = ks.layers.Concatenate()([n, node_updates])
n_embedding = ks.layers.Dense(1)(n_node_updates)
# Pool the node embeddings into one output per graph.
g_embedding = PoolingNodes()([total_n, n_embedding, batch_id])

message_passing = ks.models.Model(inputs=[ns, e_idx, total_n, total_e], outputs=g_embedding)
```
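What the gather, aggregate and pooling steps compute can be illustrated in plain Python on the disjoint representation. This is a conceptual sketch, not kgcnn's code: it assumes sum aggregation and sum pooling, and which row of `edge_index` counts as the receiving node is an assumption exposed through the `receiver` parameter. The messages must have the same length as the node features here so they can be summed in place:

```python
def gather_aggregate(nodes, edge_index, message_fn, receiver=0):
    """One message passing step: compute a message per edge, sum it at the receiving node."""
    sender = 1 - receiver
    num_features = len(nodes[0])
    aggregated = [[0.0] * num_features for _ in nodes]
    for i, j in zip(edge_index[receiver], edge_index[sender]):
        message = message_fn(nodes[i], nodes[j])  # must return num_features values
        for f in range(num_features):
            aggregated[i][f] += message[f]
    return aggregated


def pool_nodes(node_values, batch_id, num_graphs):
    """Sum node values per graph, using batch_ID as the graph assignment."""
    num_features = len(node_values[0])
    pooled = [[0.0] * num_features for _ in range(num_graphs)]
    for values, g in zip(node_values, batch_id):
        for f in range(num_features):
            pooled[g][f] += values[f]
    return pooled


nodes = [[1.0], [2.0], [3.0]]
edge_index = [[0, 1, 1], [1, 0, 2]]  # shape (2, [M])
batch_id = [0, 0, 0]                 # a single graph
# Message = features of the sending node.
updates = gather_aggregate(nodes, edge_index, message_fn=lambda x_i, x_j: x_j)
print(updates)                                      # [[2.0], [4.0], [0.0]]
print(pool_nodes(updates, batch_id, num_graphs=1))  # [[6.0]]
```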
The actual message passing model can be further structured, e.g. by subclassing the message passing base layer:
```python
import keras as ks
from kgcnn.layers.message import MessagePassingBase


class MyMessageNN(MessagePassingBase):

    def __init__(self, units, **kwargs):
        super(MyMessageNN, self).__init__(**kwargs)
        self.dense = ks.layers.Dense(units)
        self.add = ks.layers.Add()

    def message_function(self, inputs, **kwargs):
        # Node features at both ends of each edge, plus the edge features.
        n_in, n_out, edges = inputs
        return self.dense(n_out, **kwargs)

    def update_nodes(self, inputs, **kwargs):
        # Combine the previous node features with the aggregated messages.
        nodes, nodes_update = inputs
        return self.add([nodes, nodes_update], **kwargs)
```
Literature
The following models, proposed in the literature, have a module in `kgcnn.literature`. The module usually exposes a `make_model` function to create a `keras.models.Model`. The models can, but need not, be built entirely from `kgcnn.layers` and can, for example, include original implementations (with proper licensing).
- AttentiveFP: Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism by Xiong et al. (2019)
- CGCNN: Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties by Xie et al. (2018)
- CMPNN: Communicative Representation Learning on Attributed Molecular Graphs by Song et al. (2020)
- DGIN: Improved Lipophilicity and Aqueous Solubility Prediction with Composite Graph Neural Networks by Wieder et al. (2021)
- DimeNetPP: Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules by Klicpera et al. (2020)
- DMPNN: Analyzing Learned Molecular Representations for Property Prediction by Yang et al. (2019)
- EGNN: E(n) Equivariant Graph Neural Networks by Satorras et al. (2021)
- GAT: Graph Attention Networks by Veličković et al. (2018)
- GATv2: How Attentive are Graph Attention Networks? by Brody et al. (2021)
- GCN: Semi-Supervised Classification with Graph Convolutional Networks by Kipf et al. (2016)
- GIN: How Powerful are Graph Neural Networks? by Xu et al. (2019)
- GNNExplainer: GNNExplainer: Generating Explanations for Graph Neural Networks by Ying et al. (2019)
- GNNFilm: GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation by Marc Brockschmidt (2020)
- GraphSAGE: Inductive Representation Learning on Large Graphs by Hamilton et al. (2017)
- HamNet: HamNet: Conformation-Guided Molecular Representation with Hamiltonian Neural Networks by Li et al. (2021)
- HDNNP2nd: Atom-centered symmetry functions for constructing high-dimensional neural network potentials by Jörg Behler (2011)
- INorp: Interaction Networks for Learning about Objects, Relations and Physics by Battaglia et al. (2016)
- MAT: Molecule Attention Transformer by Maziarka et al. (2020)
- MEGAN: MEGAN: Multi-explanation Graph Attention Network by Teufel et al. (2023)
- Megnet: Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals by Chen et al. (2019)
- MoGAT: Multi-order graph attention network for water solubility prediction and interpretation by Lee et al. (2023)
- MXMNet: Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures by Zhang et al. (2020)
- NMPN: Neural Message Passing for Quantum Chemistry by Gilmer et al. (2017)
- PAiNN: Equivariant message passing for the prediction of tensorial properties and molecular spectra by Schütt et al. (2020)
- RGCN: Modeling Relational Data with Graph Convolutional Networks by Schlichtkrull et al. (2017)
- rGIN: Random Features Strengthen Graph Neural Networks by Sato et al. (2020)
- Schnet: SchNet – A deep learning architecture for molecules and materials by Schütt et al. (2017)
Data
Data handling classes are given in `kgcnn.data`, which stores graphs as `List[Dict]`.
Graph dictionary
Graphs are represented by a dictionary `GraphDict` of (numpy) arrays, which behaves like a Python `dict`.
There are graph pre- and postprocessors in `kgcnn.graph`, which take specific properties by name and apply a processing function or transformation.
[!IMPORTANT]
They can perform any operation, but note that `GraphDict` does not impose an actual graph structure! For example, when sorting edge indices, make sure that all edge attributes are sorted accordingly.
```python
from kgcnn.graph import GraphDict
from kgcnn.graph.preprocessor import SortEdgeIndices

# Single graph.
graph = GraphDict({"edge_indices": [[1, 0], [0, 1]], "node_label": [[0], [1]]})
graph.set("graph_labels", [0])  # use set(), get() to assign (tensor) properties.
graph.set("edge_attributes", [[1.0], [2.0]])
graph.to_networkx()

# Modify with e.g. preprocessor: sort edge indices and reorder every property
# matching "edge_*" (except "edge_indices") in the same way.
SortEdgeIndices(edge_indices="edge_indices", edge_attributes="^edge_(?!indices$).*", in_place=True)(graph)
```
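Why the attributes must follow the indices can be seen in a plain-Python sketch of edge sorting (an illustration, not the `SortEdgeIndices` implementation; the helper `sort_edges` is hypothetical). The permutation that sorts the index pairs is computed once and then applied to every per-edge property, so each attribute row stays attached to its edge:

```python
def sort_edges(edge_indices, edge_attributes):
    """Sort edges by (i, j) and reorder every edge attribute with the same permutation.

    edge_attributes: dict mapping property name -> per-edge list
                     (same length as edge_indices).
    """
    order = sorted(range(len(edge_indices)), key=lambda k: tuple(edge_indices[k]))
    sorted_indices = [edge_indices[k] for k in order]
    sorted_attributes = {name: [values[k] for k in order]
                         for name, values in edge_attributes.items()}
    return sorted_indices, sorted_attributes


# Same data as the GraphDict example: edge (1, 0) carries 1.0, edge (0, 1) carries 2.0.
indices, attributes = sort_edges([[1, 0], [0, 1]], {"edge_attributes": [[1.0], [2.0]]})
print(indices)     # [[0, 1], [1, 0]]
print(attributes)  # {'edge_attributes': [[2.0], [1.0]]}
```

Sorting only `edge_indices` would silently leave `[1.0]` attached to edge `(0, 1)`.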
List of graph dictionaries
A `MemoryGraphList` should behave identically to a Python list but contains only `GraphDict` items.
```python
from kgcnn.data import MemoryGraphList

# List of graph dicts.
graph_list = MemoryGraphList([{"edge_indices": [[0, 1], [1, 0]]}, {"edge_indices": [[0, 0]]}, {}])
graph_list.clean(["edge_indices"])  # Remove graphs without property
graph_list.get("edge_indices")  # opposite is set()

# Easily cast to tensor; makes copy.
tensor = graph_list.tensor([{"name": "edge_indices"}])  # config of keras `Input` layer

# Or directly modify list.
for i, x in enumerate(graph_list):
    x.set("graph_number", [i])

print(len(graph_list), graph_list[:2])  # Also supports indexing lists.
```
Datasets
The `MemoryGraphDataset` inherits from `MemoryGraphList` but must be initialized with file information on disk that points to a `data_directory` for the dataset.
The `data_directory` can have a subdirectory for files and/or a single file such as a CSV file:
```
├── data_directory
    ├── file_directory
    │   ├── *.*
    │   └── ...
    ├── file_name
    └── dataset_name.kgcnn.pickle
```
A base dataset class is created with path and name information:
```python
from kgcnn.data import MemoryGraphDataset

dataset = MemoryGraphDataset(data_directory="ExampleDir/",
                             dataset_name="Example",
                             file_name=None, file_directory=None)
dataset.save()  # opposite is load().
```
The subclasses `QMDataset`, `ForceDataset`, `MoleculeNetDataset`, `CrystalDataset` and `GraphTUDataset` additionally provide the functions required for their specific dataset type to convert and process files such as '.txt', '.sdf', '.xyz' etc.
Most subclasses implement `prepare_data()` and `read_in_memory()` with dataset-dependent arguments.
An example for `MoleculeNetDataset` is shown below.
For more details, see the tutorials in notebooks.
```python
from kgcnn.data.moleculenet import MoleculeNetDataset

# File directory and files must exist.
# Here 'ExampleDir' and 'ExampleDir/data.csv' with columns "smiles" and "label".
dataset = MoleculeNetDataset(dataset_name="Example",
                             data_directory="ExampleDir/",
                             file_name="data.csv")
dataset.prepare_data(overwrite=True, smiles_column_name="smiles", add_hydrogen=True,
                     make_conformers=True, optimize_conformer=True, num_workers=None)
dataset.read_in_memory(label_column_name="label", add_hydrogen=False,
                       has_conformers=True)
```
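The `MoleculeNetDataset` example expects `ExampleDir/data.csv` with the columns "smiles" and "label". A minimal sketch that creates such a file with the standard library only; the helper `write_example_csv` and the three toy molecules with their labels are made up purely for trying out the loader:

```python
import csv
from pathlib import Path


def write_example_csv(data_directory="ExampleDir/", file_name="data.csv"):
    """Create the data directory and a toy CSV with "smiles" and "label" columns."""
    directory = Path(data_directory)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / file_name
    # Hypothetical toy molecules and labels, only for testing the data pipeline.
    rows = [("CCO", 0), ("c1ccccc1", 1), ("CC(=O)O", 0)]
    with path.open("w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["smiles", "label"])
        writer.writerows(rows)
    return path
```

Calling `write_example_csv()` once before running the dataset example creates `ExampleDir/data.csv` relative to the current working directory.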
`kgcnn.data.datasets` contains graph learning benchmark datasets as subclasses, which are downloaded from popular graph archives such as TUDatasets, MatBench or MoleculeNet.
The subclasses `GraphTUDataset2020`, `MatBenchDataset2020` and `MoleculeNetDataset2018` download and read the available datasets by name.
There are also specific dataset subclasses for each dataset to handle additional processing or downloading from individual sources:
```python
from kgcnn.data.datasets.MUTAGDataset import MUTAGDataset

dataset = MUTAGDataset()  # inherits from GraphTUDataset2020
```
Downloaded datasets are stored in `~/.kgcnn/datasets` on your computer. Please remove them manually if they are no longer required.
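Before deleting anything by hand, it can help to see what the cache holds. A small standard-library sketch, assuming each downloaded dataset sits as its own file or subdirectory directly under `~/.kgcnn/datasets`; the helpers `list_cached_datasets` and `remove_cached_dataset` are hypothetical and not part of kgcnn:

```python
import shutil
from pathlib import Path

# Default cache location named in this README.
DEFAULT_CACHE = Path.home() / ".kgcnn" / "datasets"


def list_cached_datasets(cache_dir=DEFAULT_CACHE):
    """Return {entry name: size in bytes} for every file or directory in the cache."""
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return {}
    sizes = {}
    for entry in sorted(cache_dir.iterdir()):
        files = [entry] if entry.is_file() else [p for p in entry.rglob("*") if p.is_file()]
        sizes[entry.name] = sum(p.stat().st_size for p in files)
    return sizes


def remove_cached_dataset(name, cache_dir=DEFAULT_CACHE):
    """Delete one cached entry by name; does nothing if it is absent."""
    # Refuse empty names, "..", or paths so only a direct child of the cache is removed.
    if not name or name == ".." or Path(name).name != name:
        raise ValueError(f"Invalid dataset name: {name!r}")
    target = Path(cache_dir) / name
    if target.is_dir():
        shutil.rmtree(target)
    elif target.exists():
        target.unlink()


for name, size in list_cached_datasets().items():
    print(f"{name}: {size / 1e6:.1f} MB")
```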
Training
A set of example training scripts can be found in `training`. Training scripts are configurable with a hyperparameter config file and command-line arguments for model and dataset.
You can find a table of common benchmark datasets in `results`.
Issues
Some known issues to be aware of when using kgcnn or building new models or layers with it:
- Loading jagged or nested tensors into models does not work with the PyTorch backend.
- The BatchNormalization layer does not support padding yet.
- The Keras AUC metric does not seem to work with torch CUDA.
Citing
If you want to cite this repo, please refer to our paper:
```bibtex
@article{REISER2021100095,
  title = {Graph neural networks in TensorFlow-Keras with RaggedTensor representation (kgcnn)},
  journal = {Software Impacts},
  pages = {100095},
  year = {2021},
  issn = {2665-9638},
  doi = {https://doi.org/10.1016/j.simpa.2021.100095},
  url = {https://www.sciencedirect.com/science/article/pii/S266596382100035X},
  author = {Patrick Reiser and Andre Eberhard and Pascal Friederich}
}
```
File details
Details for the file kgcnn-4.0.1.tar.gz.
File metadata
- Download URL: kgcnn-4.0.1.tar.gz
- Upload date:
- Size: 347.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | db3868850b9f41ebf9b9f4dff74ad8d53a96f13045e97d1a1db158d69d9d8c4c
MD5 | 2cb4a25898dda5fcf73ec135bfbb0ceb
BLAKE2b-256 | c5d0941f50fa2d2c483bb63127305a3d942c2685b2407a7ee2dca98ec8bb4e0b
File details
Details for the file kgcnn-4.0.1-py3-none-any.whl.
File metadata
- Download URL: kgcnn-4.0.1-py3-none-any.whl
- Upload date:
- Size: 468.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | befb13820023b3796e2c70c76092ac6ef885a0734ffd664b333122eaafbdbb12
MD5 | fe77bcf831e24f9c9ff62c8380894ba4
BLAKE2b-256 | a816e3e5ec7dfb9710bede7d8941aa5cc64fa7fcc49975d227529f3d360c0ea4