Implementations of graph neural networks for molecular machine learning

Project description

MolGraph: Graph Neural Networks for Molecular Machine Learning

This is an early release; things are still being updated, added, and experimented with. Hence, API compatibility may break in the future. Any feedback is welcome!

Important update: As of version 0.6.0, the GraphTensor is a tf.experimental.ExtensionType. Although user code will likely break when updating to 0.6.0, the new GraphTensor was implemented to avoid this as much as possible. See the GraphTensor documentation and the GraphTensor walkthrough for more information on how to use it. Old features will for the most part raise deprecation warnings (though in the near future they will raise errors).

Likely cause of breakage: the GraphTensor is now by default in its "non-ragged" state when obtained from chemistry.MolecularGraphEncoder. The non-ragged GraphTensor can now be batched; i.e., a disjoint molecular graph, encoded by nested tf.Tensor values, can be passed to a tf.data.Dataset and subsequently batched, unbatched, etc. There is no need to separate the GraphTensor beforehand and then merge it again. Finally, there is no need to pass an input type_spec to the keras.Sequential model, making it even easier to build and use GNN models:

from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

model = keras.Sequential([
    layers.GINConv(units=32),
    layers.GINConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])
output = model(
    GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])
)
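
As described above, the same GraphTensor can also be fed directly to a tf.data.Dataset. A minimal sketch of that workflow (the use of tf.data.Dataset.from_tensor_slices and the batch size are illustrative assumptions, not prescribed by MolGraph):

import tensorflow as tf
from molgraph import GraphTensor

graph_tensor = GraphTensor(
    node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

# The non-ragged GraphTensor can be passed straight to tf.data.Dataset;
# no separate()/merge() round trip is needed before batching.
ds = tf.data.Dataset.from_tensor_slices(graph_tensor).batch(32)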

Paper

See arXiv

Documentation

See readthedocs

Implementations

  • Graph tensor (GraphTensor)
    • A composite tensor holding graph data.
    • Has a ragged state (multiple graphs) and a non-ragged state (single disjoint graph)
    • Can conveniently go between both states (merge(), separate()); see the short sketch after this list
    • Can propagate node information (features) based on edges (propagate())
    • Can add, update and remove graph data (update(), remove())
    • As it is implemented with TensorFlow's ExtensionType API, it is compatible with TensorFlow's APIs (including Keras). For instance, graph data (encoded as a GraphTensor) can seamlessly be used with the keras.Sequential, keras.Functional, tf.data.Dataset, and tf.saved_model APIs.
  • Layers
  • Models
    • Although model building is easy with MolGraph, there are some built-in GNN models:
      • GIN
      • MPNN
      • DMPNN
    • And models for improved interpretability of GNNs:
      • SaliencyMapping
      • IntegratedSaliencyMapping
      • SmoothGradSaliencyMapping
      • GradientActivationMapping (Recommended)
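
A short sketch of the GraphTensor state handling referenced in the list above (the toy graph is purely illustrative, and the exact aggregation performed by propagate() is described in the GraphTensor documentation):

from molgraph import GraphTensor

# A single disjoint (non-ragged) graph: two nodes connected by one edge.
graph_tensor = GraphTensor(
    node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

# separate() -> ragged state (one sub-graph per molecule);
# merge()    -> back to a single disjoint graph.
ragged = graph_tensor.separate()
merged = ragged.merge()

# propagate() passes node features along the edges of the (merged) graph.
propagated = merged.propagate()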

Changelog

For a detailed list of changes, see the CHANGELOG.md.

Requirements/dependencies

  • Python (version >= 3.6 recommended)
  • TensorFlow (version >= 2.13.0 recommended)
  • RDKit (version >= 2022.3.5 recommended)
  • Pandas (version >= 1.0.3 recommended)
  • IPython (version == 8.12.0 recommended)

Installation

Install via pip:

pip install molgraph

Install via docker:

git clone https://github.com/akensert/molgraph.git
cd molgraph/docker
docker build -t molgraph-tf[-gpu][-jupyter]/molgraph:0.0 molgraph-tf[-gpu][-jupyter]/
docker run -it [-p 8888:8888] molgraph-tf[-gpu][-jupyter]/molgraph:0.0

Now run your first program with MolGraph:

from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']

x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GATConv(units=32, name='gat_conv_1'),
    layers.GATConv(units=32, name='gat_conv_2'),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1])
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(
    model=gnn_model, layer_names=['gat_conv_1', 'gat_conv_2'])

maps = gam_model(x_train.separate())
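
Since the GraphTensor is compatible with the tf.saved_model API (see Implementations above), the trained model can also be serialized with standard TensorFlow tooling. A brief sketch (the target directory is just an example):

import tensorflow as tf

# Save the trained GNN model to a SavedModel directory (path is illustrative).
tf.saved_model.save(gnn_model, '/tmp/molgraph_gnn_model')

# Reload it later, e.g. for inference.
reloaded = tf.saved_model.load('/tmp/molgraph_gnn_model')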

Download files

Download the file for your platform.

Source Distribution

molgraph-0.6.0.tar.gz (112.9 kB)


Built Distribution

molgraph-0.6.0-py3-none-any.whl (199.8 kB)


File details

Details for the file molgraph-0.6.0.tar.gz.

File metadata

  • Download URL: molgraph-0.6.0.tar.gz
  • Upload date:
  • Size: 112.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.6

File hashes

Hashes for molgraph-0.6.0.tar.gz

  • SHA256: 0c94358e691f74d65d1997cb2327f230f65be2ab30c16fa6cdc9b39208cbb08b
  • MD5: cdb64914312e2fa4b307c151b795ef3c
  • BLAKE2b-256: 0e2fa73d64153813de7c3b103cd0ff4e8f0dde5852b605d5943a9888b0f5784d


File details

Details for the file molgraph-0.6.0-py3-none-any.whl.

File metadata

  • Download URL: molgraph-0.6.0-py3-none-any.whl
  • Upload date:
  • Size: 199.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.6

File hashes

Hashes for molgraph-0.6.0-py3-none-any.whl

  • SHA256: 2b72a39fa910e4120dd9e79921b1f594aa418c1b5d2ca31f7c8687d0320c619e
  • MD5: 909577ad11f5edd6393d9dbc0861bcd9
  • BLAKE2b-256: 481b4612b6a2bd9d52cbc38e67717a6d3863dfc6b8b50ef9fea7374e1535af1c

