Graph Neural Networks for Molecular Machine Learning
Graph Neural Networks with TensorFlow and Keras. Focused on Molecular Machine Learning.
Keras 3 does not currently support extension types. Once it does, MolGraph is expected to migrate to Keras 3.
Highlights
Build a Graph Neural Network with Keras' Sequential API:
from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

g = GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GATv2Conv(units=32),
    layers.GATv2Conv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])

pred = model(g)

# Save and load Keras model
model.save('/tmp/gatv2_model.keras')
loaded_model = keras.models.load_model('/tmp/gatv2_model.keras')
loaded_pred = loaded_model(g)
assert pred == loaded_pred
Combine outputs of GNN layers to improve predictive performance:
model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GNN([
        layers.FeatureProjection(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
    ]),
    layers.Readout(),
    keras.layers.Dense(units=128),
    keras.layers.Dense(units=1),
])

model.summary()
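One common way to combine the outputs of stacked GNN layers is to keep every intermediate node state and concatenate them per node before readout. The sketch below illustrates that idea in plain Python with a toy sum-aggregation layer. The names `toy_layer` and `combine_layer_outputs` are hypothetical, and whether layers.GNN combines outputs in exactly this way is an assumption rather than something stated here.

```python
def toy_layer(node_states, edge_src, edge_dst):
    """Toy message-passing step: each node adds the states of its in-neighbours."""
    updated = [list(state) for state in node_states]
    for src, dst in zip(edge_src, edge_dst):
        for k, value in enumerate(node_states[src]):
            updated[dst][k] += value
    return updated


def combine_layer_outputs(node_states, edge_src, edge_dst, num_layers):
    """Run `num_layers` toy layers and concatenate every intermediate output per node."""
    outputs = []
    for _ in range(num_layers):
        node_states = toy_layer(node_states, edge_src, edge_dst)
        outputs.append(node_states)
    # Concatenate along the feature axis: one longer vector per node.
    return [
        [value for layer_output in outputs for value in layer_output[i]]
        for i in range(len(node_states))
    ]


# Same two-node graph as in the first example: a single edge from node 0 to node 1.
combined = combine_layer_outputs([[4.0], [2.0]], edge_src=[0], edge_dst=[1], num_layers=3)
print(combined)
```

Each node ends up with one feature per layer, so later layers (readout, dense) can draw on shallow and deep neighbourhood information at once.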
Paper
See arXiv
Documentation
See readthedocs
Overview
Implementations
- Graph tensor (GraphTensor)
    - A composite tensor holding graph data.
    - Has a ragged state (multiple graphs) and a non-ragged state (single disjoint graph).
    - Can conveniently go between both states (merge(), separate()).
    - Can propagate node states (features) based on edges (propagate()).
    - Can add, update and remove graph data (update(), remove()).
    - Compatible with TensorFlow's APIs (including Keras). For instance, graph data (encoded as a GraphTensor) can now seamlessly be used with keras.Sequential, keras.Functional, tf.data.Dataset, and tf.saved_model APIs.
- Layers
    - Convolutional
        - GCNConv (GCNConv)
        - GINConv (GINConv)
        - GCNIIConv (GCNIIConv)
        - GraphSageConv (GraphSageConv)
    - Attentional
        - GATConv (GATConv)
        - GATv2Conv (GATv2Conv)
        - GTConv (GTConv)
        - GMMConv (GMMConv)
        - GatedGCNConv (GatedGCNConv)
        - AttentiveFPConv (AttentiveFPConv)
    - Message-passing
    - Distance-geometric
    - Pre- and post-processing
        - In addition to the GNN layers above, several other layers simplify model building. See readout/, preprocessing/, postprocessing/, positional_encoding/.
- Models
    - Although model building is easy with MolGraph, there are some built-in GNN models:
        - GIN
        - MPNN
        - DMPNN
    - And models for improved interpretability of GNNs:
        - SaliencyMapping
        - IntegratedSaliencyMapping
        - SmoothGradSaliencyMapping
        - GradientActivationMapping (Recommended)
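The two GraphTensor states and edge-based propagation listed above can be pictured with a minimal plain-Python sketch. The functions below only borrow the names merge, separate and propagate for readability. Their dictionary-based data layout, the `graph_indicator` field, and the choice of sum aggregation are assumptions made for illustration, not MolGraph's actual implementation.

```python
def merge(graphs):
    """Ragged -> disjoint: concatenate nodes, shift edge indices, record graph membership."""
    node_feature, edge_src, edge_dst, graph_indicator = [], [], [], []
    for graph_index, graph in enumerate(graphs):
        offset = len(node_feature)  # index of this graph's first node in the merged graph
        node_feature.extend(graph["node_feature"])
        edge_src.extend(s + offset for s in graph["edge_src"])
        edge_dst.extend(d + offset for d in graph["edge_dst"])
        graph_indicator.extend([graph_index] * len(graph["node_feature"]))
    return {"node_feature": node_feature, "edge_src": edge_src,
            "edge_dst": edge_dst, "graph_indicator": graph_indicator}


def separate(disjoint):
    """Disjoint -> ragged. Assumes each graph has at least one node, stored contiguously."""
    indicator = disjoint["graph_indicator"]
    num_graphs = max(indicator) + 1 if indicator else 0
    graphs = [{"node_feature": [], "edge_src": [], "edge_dst": []} for _ in range(num_graphs)]
    offsets = {}
    for node_index, graph_index in enumerate(indicator):
        offsets.setdefault(graph_index, node_index)
        graphs[graph_index]["node_feature"].append(disjoint["node_feature"][node_index])
    for src, dst in zip(disjoint["edge_src"], disjoint["edge_dst"]):
        graph_index = indicator[src]
        graphs[graph_index]["edge_src"].append(src - offsets[graph_index])
        graphs[graph_index]["edge_dst"].append(dst - offsets[graph_index])
    return graphs


def propagate(disjoint):
    """Sum the features of source nodes into their destination nodes."""
    width = len(disjoint["node_feature"][0])
    out = [[0.0] * width for _ in disjoint["node_feature"]]
    for src, dst in zip(disjoint["edge_src"], disjoint["edge_dst"]):
        for k, value in enumerate(disjoint["node_feature"][src]):
            out[dst][k] += value
    return out


graphs = [
    {"node_feature": [[4.0], [2.0]], "edge_src": [0], "edge_dst": [1]},
    {"node_feature": [[1.0], [3.0], [5.0]], "edge_src": [0, 2], "edge_dst": [1, 1]},
]
disjoint = merge(graphs)
messages = propagate(disjoint)
```

Because edge indices are offset during merging, no edge ever crosses from one graph into another, which is why a whole batch can be processed as a single disjoint graph.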
Requirements/dependencies
- Python (version >= 3.10)
- TensorFlow (version 2.15.*)
- RDKit (version 2023.9.*)
- Pandas
- IPython
Installation
For CPU users:
pip install molgraph
For GPU users:
pip install molgraph[gpu]
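Since the requirements pin specific TensorFlow and RDKit versions, it can help to check what is actually installed before running anything. The snippet below is a small sketch using only the standard library. The helper `installed_version` is a hypothetical name, and the distribution names passed to it (in particular `rdkit`) are assumptions about how the packages are published.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# Report the versions most relevant to the requirements listed above.
for name in ("molgraph", "tensorflow", "rdkit"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```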
Now run your first program with MolGraph:
from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']

x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GNNInputLayer(type_spec=x_train.spec),
    layers.GATConv(units=32),
    layers.GATConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1])
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(gnn_model)
maps = gam_model(x_train.separate())
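The atom and bond encoders in the program above turn chemical properties into numeric vectors. A common scheme for this, sketched below in plain Python, is to one-hot encode each property and concatenate the results. The function names and vocabularies here are hypothetical, and it is an assumption that chemistry.Featurizer behaves this way for the listed features.

```python
def one_hot(value, allowable_set):
    """Encode `value` as a one-hot list over `allowable_set` (all zeros if unseen)."""
    return [1.0 if value == candidate else 0.0 for candidate in allowable_set]


def featurize_atom(atom, feature_sets):
    """Concatenate the one-hot encodings of every listed atom property."""
    vector = []
    for property_name, allowable_set in feature_sets.items():
        vector.extend(one_hot(atom[property_name], allowable_set))
    return vector


# Hypothetical vocabularies; a real featurizer would define its own.
feature_sets = {
    "symbol": ["C", "N", "O"],
    "hybridization": ["SP", "SP2", "SP3"],
}

atom = {"symbol": "O", "hybridization": "SP3"}
encoded = featurize_atom(atom, feature_sets)
print(encoded)
```

The length of the resulting vector is the sum of the vocabulary sizes, so adding features to the encoder widens the node feature dimension the first GNN layer sees.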