Implementations of graph neural networks for molecular machine learning
Project description
MolGraph: Graph Neural Networks for Molecular Machine Learning
This is an early release; things are still being updated, added, and experimented with, so API compatibility may break in the future. Any feedback is welcome!
Important update: As of version 0.6.0, the GraphTensor is a tf.experimental.ExtensionType. Although user code will likely break when updating to 0.6.0, the new GraphTensor was implemented to avoid this as much as possible. See the GraphTensor documentation and the GraphTensor walkthrough for more information on how to use it. Old features will for the most part raise deprecation warnings (though in the near future they will raise errors). Likely cause of breakage: the GraphTensor obtained from chemistry.MolecularGraphEncoder is now in its "non-ragged" state by default. The non-ragged GraphTensor can now be batched; i.e., a disjoint molecular graph, encoded by nested tf.Tensor values, can be passed to a tf.data.Dataset and subsequently batched, unbatched, etc. There is no need to separate the GraphTensor beforehand and then merge it again. Finally, there is no need to pass an input type_spec to a keras.Sequential model, making it even easier to build and use GNN models:
```python
from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

model = keras.Sequential([
    layers.GINConv(units=32),
    layers.GINConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])

output = model(
    GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])
)
```
Paper
See arXiv
Documentation
See readthedocs
Implementations
- Graph tensor (GraphTensor)
  - A composite tensor holding graph data.
  - Has a ragged state (multiple graphs) and a non-ragged state (a single disjoint graph).
  - Can conveniently go between both states (merge(), separate()).
  - Can propagate node information (features) based on edges (propagate()).
  - Can add, update and remove graph data (update(), remove()).
  - As it is implemented with TF's ExtensionType API, it is compatible with TensorFlow's APIs (including Keras). For instance, graph data (encoded as a GraphTensor) can seamlessly be used with the keras.Sequential, Keras functional, tf.data.Dataset, and tf.saved_model APIs.
- Layers
  - Convolutional
    - GCNConv
    - GINConv
    - GCNIIConv
    - GraphSageConv
  - Attentional
    - GATConv
    - GATv2Conv
    - GTConv
    - GMMConv
    - GatedGCNConv
    - AttentiveFPConv
  - Message-passing
  - Distance-geometric
  - Pre- and post-processing
    - In addition to the aforementioned GNN layers, there are also several other layers that improve model building. See readout/, preprocessing/, postprocessing/, positional_encoding/.
- Models
  - Although model building is easy with MolGraph, there are some built-in GNN models:
    - GIN
    - MPNN
    - DMPNN
  - And models for improved interpretability of GNNs:
    - SaliencyMapping
    - IntegratedSaliencyMapping
    - SmoothGradSaliencyMapping
    - GradientActivationMapping (recommended)
Changelog
For a detailed list of changes, see the CHANGELOG.md.
Requirements/dependencies
- Python (version >= 3.6 recommended)
- TensorFlow (version >= 2.13.0 recommended)
- RDKit (version >= 2022.3.5 recommended)
- Pandas (version >= 1.0.3 recommended)
- IPython (version == 8.12.0 recommended)
Installation
Install via pip:
```
pip install molgraph
```
Install via docker:
```
git clone https://github.com/akensert/molgraph.git
cd molgraph/docker
docker build -t molgraph-tf[-gpu][-jupyter]/molgraph:0.0 molgraph-tf[-gpu][-jupyter]/
docker run -it [-p 8888:8888] molgraph-tf[-gpu][-jupyter]/molgraph:0.0
```
Now run your first program with MolGraph:
```python
from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']
x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GATConv(units=32, name='gat_conv_1'),
    layers.GATConv(units=32, name='gat_conv_2'),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1]),
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(
    model=gnn_model, layer_names=['gat_conv_1', 'gat_conv_2'])
maps = gam_model(x_train.separate())
```