
Graph Neural Networks for Molecular Machine Learning



Graph Neural Networks with TensorFlow and Keras, focused on molecular machine learning.

Keras 3 does not currently support extension types. MolGraph is expected to migrate to Keras 3 once it does.

Highlights

Build a Graph Neural Network with Keras' Sequential API:

from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

g = GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GATv2Conv(units=32),
    layers.GATv2Conv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])

pred = model(g)

# Save and load Keras model
model.save('/tmp/gatv2_model.keras')
loaded_model = keras.models.load_model('/tmp/gatv2_model.keras')
loaded_pred = loaded_model(g)
# Predictions should be unchanged after the save/load round trip
assert (pred == loaded_pred).numpy().all()

Combine outputs of GNN layers to improve predictive performance:

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GNN([
        layers.FeatureProjection(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
    ]),
    layers.Readout(),
    keras.layers.Dense(units=128),
    keras.layers.Dense(units=1),
])

model.summary()
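To make "combining outputs of GNN layers" concrete, here is a minimal NumPy sketch of one common approach: concatenate the node features produced by each layer, then sum them per graph as a readout. This is a conceptual illustration only. It assumes concatenation and sum pooling; how layers.GNN and layers.Readout actually combine and pool is not stated here and may differ. All names (layer_outputs, graph_indicator, and so on) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, units = 5, 4

# Hypothetical node -> graph membership: two graphs with 2 and 3 nodes.
graph_indicator = np.array([0, 0, 1, 1, 1])

# Stand-ins for the node features produced by three successive GNN layers.
layer_outputs = [rng.normal(size=(num_nodes, units)) for _ in range(3)]

# Combine (assumed strategy): concatenate per-layer node features
# along the feature axis, so later layers see all intermediate states.
combined = np.concatenate(layer_outputs, axis=-1)  # shape (5, 12)

# Readout (assumed strategy): sum node features within each graph.
num_graphs = int(graph_indicator.max()) + 1
graph_features = np.zeros((num_graphs, combined.shape[-1]))
np.add.at(graph_features, graph_indicator, combined)  # shape (2, 12)

print(combined.shape, graph_features.shape)
```

The per-graph vectors in graph_features play the role of the input to the dense layers that follow the readout in the model above.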

Paper

See arXiv

Documentation

See readthedocs

Overview

[Figure: overview of MolGraph]

Implementations

  • Graph tensor (GraphTensor)
    • A composite tensor holding graph data.
    • Has a ragged state (multiple graphs) and a non-ragged state (single disjoint graph).
    • Can conveniently go between both states (merge(), separate()).
    • Can propagate node states (features) based on edges (propagate()).
    • Can add, update and remove graph data (update(), remove()).
    • Compatible with TensorFlow's APIs (including Keras). For instance, graph data encoded as a GraphTensor can be used directly with the keras.Sequential, keras.Functional, tf.data.Dataset, and tf.saved_model APIs.
  • Layers
  • Models
    • Although model building is easy with MolGraph, there are some built-in GNN models:
      • GIN
      • MPNN
      • DMPNN
    • And models for improved interpretability of GNNs:
      • SaliencyMapping
      • IntegratedSaliencyMapping
      • SmoothGradSaliencyMapping
      • GradientActivationMapping
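The two states of a graph tensor described above (several separate graphs versus one disjoint graph) and edge-based propagation can be illustrated with plain NumPy. The sketch below is not MolGraph's implementation. It uses plain dicts and hypothetical helper functions named after the GraphTensor methods, assumes every graph has at least one node, and assumes propagation means summing source-node features into destination nodes.

```python
import numpy as np

def merge(graphs):
    """Merge per-graph dicts into one disjoint graph.

    Edge indices are shifted by the number of nodes that precede
    each graph, so they keep pointing at the right nodes.
    """
    sizes = [len(g["node_feature"]) for g in graphs]
    offsets = np.cumsum([0] + sizes[:-1])
    return {
        "node_feature": np.concatenate([g["node_feature"] for g in graphs]),
        "edge_src": np.concatenate(
            [g["edge_src"] + o for g, o in zip(graphs, offsets)]),
        "edge_dst": np.concatenate(
            [g["edge_dst"] + o for g, o in zip(graphs, offsets)]),
        "graph_indicator": np.repeat(np.arange(len(graphs)), sizes),
    }

def separate(graph):
    """Split a disjoint graph back into a list of per-graph dicts."""
    indicator = graph["graph_indicator"]
    edge_graph = indicator[graph["edge_src"]]  # graph id of every edge
    parts = []
    for i in range(int(indicator.max()) + 1):
        node_mask = indicator == i
        edge_mask = edge_graph == i
        offset = np.flatnonzero(node_mask)[0]  # index of the graph's first node
        parts.append({
            "node_feature": graph["node_feature"][node_mask],
            "edge_src": graph["edge_src"][edge_mask] - offset,
            "edge_dst": graph["edge_dst"][edge_mask] - offset,
        })
    return parts

def propagate(graph):
    """Sum source-node features into each destination node."""
    out = np.zeros_like(graph["node_feature"])
    np.add.at(out, graph["edge_dst"], graph["node_feature"][graph["edge_src"]])
    return out

g1 = {"node_feature": np.array([[4.0], [2.0]]),
      "edge_src": np.array([0]), "edge_dst": np.array([1])}
g2 = {"node_feature": np.array([[1.0], [3.0], [5.0]]),
      "edge_src": np.array([0, 1]), "edge_dst": np.array([1, 2])}

merged = merge([g1, g2])
parts = separate(merged)
messages = propagate(merged)
```

Because the merged graph has no edges between its components, a single propagation step over it is equivalent to propagating within each graph independently, which is what makes the disjoint state convenient for batching.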

Requirements/dependencies

  • Python (version >= 3.10)
    • TensorFlow (version 2.15.*)
    • RDKit (version 2023.9.*)
    • Pandas
    • IPython
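The version constraints above can be checked before installing. The following is a small sketch using only the standard library. The distribution names ("tensorflow", "rdkit") are assumptions about how these packages are published on PyPI, and the helper names are hypothetical; adjust them if your environment uses a different build.

```python
import sys
from importlib import metadata

# Assumed PyPI distribution names mapped to the required version prefix.
REQUIRED = {"tensorflow": "2.15.", "rdkit": "2023.9."}

def installed_version(dist_name):
    """Return the installed version string, or None if not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

def check_environment(required=REQUIRED):
    """Return a list of human-readable problems (empty if none found)."""
    problems = []
    if sys.version_info < (3, 10):
        found = sys.version.split()[0]
        problems.append(f"Python >= 3.10 required, found {found}")
    for name, prefix in required.items():
        version = installed_version(name)
        if version is None:
            problems.append(f"{name} is not installed")
        elif not version.startswith(prefix):
            problems.append(f"{name} {prefix}* required, found {version}")
    return problems

for problem in check_environment():
    print(problem)
```

An empty result only means the pinned versions match; it does not verify GPU support or that the packages import cleanly.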

Installation

For CPU users:

pip install molgraph

For GPU users:

pip install molgraph[gpu]

Now run your first program with MolGraph:

from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']

x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GNNInputLayer(type_spec=x_train.spec),
    layers.GATConv(units=32),
    layers.GATConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1])
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(gnn_model)

maps = gam_model(x_train.separate())

