
Graph Neural Networks for Molecular Machine Learning

Project description


Graph Neural Networks with TensorFlow and Keras. Focused on Molecular Machine Learning.

Keras 3 does not currently support extension types, which MolGraph's GraphTensor relies on. Once it does, MolGraph will hopefully migrate to Keras 3.

Highlights

Build a Graph Neural Network with Keras' Sequential API:

from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

g = GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GATv2Conv(units=32),
    layers.GATv2Conv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])

pred = model(g)

# Save and load Keras model
model.save('/tmp/gatv2_model.keras')
loaded_model = keras.models.load_model('/tmp/gatv2_model.keras')
loaded_pred = loaded_model(g)
# Compare element-wise so the check also holds for a batch of graphs
assert (pred.numpy() == loaded_pred.numpy()).all()

Combine outputs of GNN layers to improve predictive performance:

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GNN([
        layers.FeatureProjection(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
    ]),
    layers.Readout(),
    keras.layers.Dense(units=128),
    keras.layers.Dense(units=1),
])

model.summary()
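
The idea behind combining layer outputs can be sketched in plain NumPy. This sketch assumes that layers.GNN concatenates the node states produced by each wrapped layer, a jumping-knowledge-style design. That is an assumption, so check the MolGraph documentation to confirm. The functions sum_neighbors, gin_like_layer and stacked_gnn are hypothetical illustrations, not MolGraph code, and the input is taken to be already projected to 32 features (the role FeatureProjection plays above).

```python
import numpy as np

def sum_neighbors(h, edge_src, edge_dst):
    """Sum the states of source nodes into their destination nodes."""
    out = np.zeros_like(h)
    np.add.at(out, edge_dst, h[edge_src])  # unbuffered scatter-add over edges
    return out

def gin_like_layer(h, edge_src, edge_dst, w, eps=0.0):
    """Simplified GIN-style update: ReLU(((1 + eps) * h + neighbor sum) @ w)."""
    agg = (1.0 + eps) * h + sum_neighbors(h, edge_src, edge_dst)
    return np.maximum(agg @ w, 0.0)

def stacked_gnn(h, edge_src, edge_dst, weights):
    """Apply the layers in sequence and concatenate every layer's output."""
    outputs = []
    for w in weights:
        h = gin_like_layer(h, edge_src, edge_dst, w)
        outputs.append(h)
    return np.concatenate(outputs, axis=-1)

rng = np.random.default_rng(0)
h0 = rng.normal(size=(3, 32))            # 3 nodes, 32 (projected) features
edge_src = np.array([0, 1, 1, 2])
edge_dst = np.array([1, 0, 2, 1])
weights = [rng.normal(size=(32, 32)) for _ in range(3)]

combined = stacked_gnn(h0, edge_src, edge_dst, weights)
print(combined.shape)  # (3, 96)
```

Concatenating keeps shallow (local) and deep (wider-context) node representations side by side, so the readout can use both.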

Paper

See arXiv

Documentation

See readthedocs

Overview


Implementations

  • Graph tensor (GraphTensor)
    • A composite tensor holding graph data.
    • Has a ragged state (multiple graphs) and a non-ragged state (single disjoint graph).
    • Can conveniently go between both states (merge(), separate()).
    • Can propagate node states (features) based on edges (propagate()).
    • Can add, update and remove graph data (update(), remove()).
    • Compatible with TensorFlow's APIs (including Keras). For instance, graph data encoded as a GraphTensor can be used seamlessly with keras.Sequential, the Keras Functional API, tf.data.Dataset and tf.saved_model.
  • Layers
  • Models
    • Although building custom models with MolGraph is easy, some GNN models are built in:
      • GIN
      • MPNN
      • DMPNN
    • There are also models for improving the interpretability of GNNs:
      • SaliencyMapping
      • IntegratedSaliencyMapping
      • SmoothGradSaliencyMapping
      • GradientActivationMapping (Recommended)
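
The two GraphTensor states and node propagation can be illustrated with a simplified NumPy model. It assumes each graph is a (node_feature, edge_src, edge_dst) triple and that the merged state tracks graph membership with a per-node graph_indicator array. The functions merge, separate and propagate only mirror the GraphTensor method names for illustration. They are not MolGraph's implementation, and propagate here performs plain sum aggregation only.

```python
import numpy as np

def merge(graphs):
    """Merge (node_feature, edge_src, edge_dst) graphs into one disjoint graph.

    Edge indices of each graph are offset by the number of nodes before it,
    and graph_indicator records which graph each node belongs to.
    """
    features, srcs, dsts, indicator = [], [], [], []
    offset = 0
    for i, (x, src, dst) in enumerate(graphs):
        features.append(x)
        srcs.append(np.asarray(src, dtype=np.int64) + offset)
        dsts.append(np.asarray(dst, dtype=np.int64) + offset)
        indicator.append(np.full(len(x), i))
        offset += len(x)
    return (np.concatenate(features), np.concatenate(srcs),
            np.concatenate(dsts), np.concatenate(indicator))

def separate(x, src, dst, graph_indicator):
    """Split a merged graph back into per-graph triples (nodes stored contiguously)."""
    graphs = []
    for i in np.unique(graph_indicator):
        node_idx = np.flatnonzero(graph_indicator == i)
        offset = node_idx[0]
        edge_mask = graph_indicator[src] == i
        graphs.append((x[node_idx], src[edge_mask] - offset, dst[edge_mask] - offset))
    return graphs

def propagate(x, src, dst):
    """One message-passing step: sum the states of incoming neighbors."""
    out = np.zeros_like(x)
    np.add.at(out, dst, x[src])
    return out

g1 = (np.array([[4.0], [2.0]]), [0], [1])
g2 = (np.array([[1.0], [3.0], [5.0]]), [0, 1], [1, 2])

x, src, dst, ind = merge([g1, g2])   # src -> [0, 2, 3], dst -> [1, 3, 4]
print(propagate(x, src, dst).ravel().tolist())  # [0.0, 4.0, 0.0, 1.0, 3.0]
```

Because edges never cross graph boundaries after offsetting, propagation on the merged (disjoint) graph gives the same result as propagating each graph separately, which is why a batch can be processed as one large graph.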

Requirements/dependencies

  • Python (version >= 3.10)
  • TensorFlow (version 2.15.*)
  • RDKit (version 2023.9.*)
  • Pandas
  • IPython

Installation

For CPU users:

pip install molgraph

For GPU users:

pip install molgraph[gpu]

Now run your first program with MolGraph:

from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']

x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GNNInputLayer(type_spec=x_train.spec),
    layers.GATConv(units=32),
    layers.GATConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1])
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(gnn_model)

maps = gam_model(x_train.separate())
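
For intuition, gradient activation mapping can be thought of as Grad-CAM applied to node states: each feature channel is weighted by its average gradient, and the weighted channels are summed per node. This sketch assumes that formulation for a single layer. MolGraph's GradientActivationMapping may differ in details (for example, which layers it uses and how it combines them), and gradient_activation_map is a hypothetical helper. To keep it self-contained, the gradients come from a toy model whose derivative is known analytically rather than from TensorFlow.

```python
import numpy as np

def gradient_activation_map(activations, gradients):
    """Grad-CAM-style node importance for one graph.

    activations: (num_nodes, num_features) node states from one GNN layer.
    gradients:   (num_nodes, num_features) gradient of the prediction
                 with respect to those node states.
    """
    alpha = gradients.mean(axis=0)   # one weight per feature channel
    scores = activations @ alpha     # weighted sum of channels for each node
    return np.maximum(scores, 0.0)   # keep only positive contributions

# Toy model: y = sum over nodes of (A @ w), i.e. a sum readout followed by a
# linear layer, so dy/dA is simply w repeated for every node.
A = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [1.0, 1.0]])
w = np.array([0.5, -1.0])
grads = np.tile(w, (A.shape[0], 1))

print(gradient_activation_map(A, grads).tolist())  # [0.5, 0.0, 0.0]
```

In a real model, the gradients would be obtained with automatic differentiation (e.g. tf.GradientTape) with respect to a chosen GNN layer's output.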



Download files


Source Distribution

molgraph-0.6.14.tar.gz (111.0 kB)

Uploaded Source

Built Distribution

molgraph-0.6.14-py3-none-any.whl (199.1 kB)

Uploaded Python 3

File details

Details for the file molgraph-0.6.14.tar.gz.

File metadata

  • Download URL: molgraph-0.6.14.tar.gz
  • Upload date:
  • Size: 111.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for molgraph-0.6.14.tar.gz
  • SHA256: 79cd27cc27e7c55cf8d5cc5e6b59587d3abdb6538d78d3854f0ff4d5c8a4b732
  • MD5: 60555a2a6589f5aee8a8cc018ad6d278
  • BLAKE2b-256: a8d98c7b978c9d5357026dcf0ac18f4acc807f510fb93bb8e17d668bef8320b3


File details

Details for the file molgraph-0.6.14-py3-none-any.whl.

File metadata

  • Download URL: molgraph-0.6.14-py3-none-any.whl
  • Upload date:
  • Size: 199.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for molgraph-0.6.14-py3-none-any.whl
  • SHA256: a32df8842fd4a70e329e2a8ad5e7f49dbee5a9f2fc8935067cb73f87cbe6b84e
  • MD5: b003e7272ec56bc2b6219fc3b4b70b60
  • BLAKE2b-256: 77e58444546b3dabeb882263f225fe4275da5415bdb5ff107da273bec30bd2d8

