
Graph Neural Networks for Molecular Machine Learning



Graph neural networks with TensorFlow and Keras, focused on molecular machine learning.

Keras 3 does not currently support extension types. MolGraph will hopefully migrate to Keras 3 once it does.

Highlights

Build a Graph Neural Network with Keras' Sequential API:

import numpy as np
from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

# A single graph: two nodes with one feature each, and one edge (0 -> 1)
g = GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GATv2Conv(units=32),
    layers.GATv2Conv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])

pred = model(g)

# Save and load Keras model
model.save('/tmp/gatv2_model.keras')
loaded_model = keras.models.load_model('/tmp/gatv2_model.keras')
loaded_pred = loaded_model(g)

# Compare predictions numerically instead of relying on a tensor's truth value
np.testing.assert_allclose(pred.numpy(), loaded_pred.numpy())
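To make the edge list of `g` concrete, here is a minimal plain-Python sketch of one message-passing step followed by a graph-level readout. It assumes messages flow from `edge_src` to `edge_dst` and uses plain sum aggregation with no learnable weights or self-loops; it is not how `GATv2Conv` works (that layer weights messages with learned attention), and `propagate`/`sum_readout` here are hypothetical helpers, not MolGraph functions.

```python
# Hypothetical sketch: sum-aggregation message passing plus a sum readout.
# Not MolGraph's implementation; GATv2Conv uses learned attention weights.

def propagate(node_feature, edge_src, edge_dst):
    """Sum the features of incoming neighbors (messages flow src -> dst)."""
    dim = len(node_feature[0])
    out = [[0.0] * dim for _ in node_feature]
    for s, d in zip(edge_src, edge_dst):
        for k in range(dim):
            out[d][k] += node_feature[s][k]
    return out

def sum_readout(node_feature):
    """Pool all node features of a single graph into one graph-level vector."""
    dim = len(node_feature[0])
    return [sum(row[k] for row in node_feature) for k in range(dim)]

# Same toy graph as `g`: two nodes and one edge 0 -> 1
node_feature = [[4.0], [2.0]]
edge_src, edge_dst = [0], [1]

messages = propagate(node_feature, edge_src, edge_dst)
print(messages)               # [[0.0], [4.0]]
print(sum_readout(messages))  # [4.0]
```

Node 0 has no incoming edge, so it receives nothing; node 1 receives the feature of node 0. The readout then turns node-level states into one vector per graph, which is what lets a `Dense` layer produce a graph-level prediction.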

Combine the outputs of multiple GNN layers with layers.GNN to improve predictive performance:

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GNN([
        layers.FeatureProjection(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
        layers.GINConv(units=32),
    ]),
    layers.Readout(),
    keras.layers.Dense(units=128),
    keras.layers.Dense(units=1),
])

model.summary()
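One common way to combine the outputs of several GNN layers is to concatenate each node's states from every layer, so that later stages see both local (early-layer) and wider (late-layer) neighborhood information. The sketch below shows that idea in plain Python; whether `layers.GNN` concatenates, sums, or otherwise merges its sub-layer outputs is an assumption here, and `combine_layer_outputs` is a hypothetical helper.

```python
# Hypothetical sketch: combine per-layer node states by concatenation.
# That layers.GNN combines outputs this way is an assumption, not confirmed.

def combine_layer_outputs(layer_outputs):
    """Concatenate node features from several layers, node by node.

    layer_outputs: one node-feature matrix per layer, each a list of rows
    with one row per node.
    """
    num_nodes = len(layer_outputs[0])
    if any(len(out) != num_nodes for out in layer_outputs):
        raise ValueError("all layers must produce states for the same nodes")
    return [
        [value for out in layer_outputs for value in out[i]]
        for i in range(num_nodes)
    ]

# Toy outputs of three layers for a 2-node graph (2 units each)
layer1 = [[0.1, 0.2], [0.3, 0.4]]
layer2 = [[1.0, 1.1], [1.2, 1.3]]
layer3 = [[2.0, 2.1], [2.2, 2.3]]

combined = combine_layer_outputs([layer1, layer2, layer3])
print(combined[0])  # [0.1, 0.2, 1.0, 1.1, 2.0, 2.1]
```

Concatenation grows the feature dimension with the number of layers; summing would keep it fixed but requires every layer to have the same number of units.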

Paper

See arXiv

Documentation

See readthedocs

Overview

[Figure: MolGraph overview]

Implementations

  • Graph tensor (GraphTensor)
    • A composite tensor holding graph data.
    • Has a ragged state (multiple graphs) and a non-ragged state (single disjoint graph).
    • Can convert between the two states (merge(), separate()).
    • Can propagate node states (features) based on edges (propagate()).
    • Can add, update and remove graph data (update(), remove()).
    • Compatible with TensorFlow's APIs, including Keras. For instance, graph data encoded as a GraphTensor can be used directly with the Keras Sequential and Functional APIs, tf.data.Dataset, and tf.saved_model.
  • Layers
  • Models
    • Although building models with MolGraph is easy, some GNN models are also built in:
      • GIN
      • MPNN
      • DMPNN
    • And models for improved interpretability of GNNs:
      • SaliencyMapping
      • IntegratedSaliencyMapping
      • SmoothGradSaliencyMapping
      • GradientActivationMapping (Recommended)
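The two GraphTensor states described above can be illustrated with a plain-Python sketch: a batch of graphs (ragged state) is stacked into one disjoint graph (non-ragged state) by offsetting edge indices and recording which graph each node belongs to, and the reverse step undoes this. The dictionary field names and the `graph_indicator` layout are assumptions for illustration; this is not MolGraph's internal implementation of merge() and separate().

```python
# Hypothetical sketch of merging a batch of graphs into one disjoint graph
# and separating it again. Field names are assumptions, not MolGraph internals.

def merge(graphs):
    """Stack per-graph data into one disjoint graph, offsetting edge indices."""
    node_feature, edge_src, edge_dst, graph_indicator = [], [], [], []
    offset = 0
    for i, g in enumerate(graphs):
        node_feature.extend(g["node_feature"])
        edge_src.extend(s + offset for s in g["edge_src"])
        edge_dst.extend(d + offset for d in g["edge_dst"])
        graph_indicator.extend([i] * len(g["node_feature"]))
        offset += len(g["node_feature"])
    return {
        "node_feature": node_feature,
        "edge_src": edge_src,
        "edge_dst": edge_dst,
        "graph_indicator": graph_indicator,
    }

def separate(merged):
    """Split a disjoint graph back into per-graph data with local indices."""
    num_graphs = max(merged["graph_indicator"], default=-1) + 1
    graphs = [{"node_feature": [], "edge_src": [], "edge_dst": []}
              for _ in range(num_graphs)]
    local_index = []  # global node index -> index within its own graph
    for feature, gid in zip(merged["node_feature"], merged["graph_indicator"]):
        local_index.append(len(graphs[gid]["node_feature"]))
        graphs[gid]["node_feature"].append(feature)
    for s, d in zip(merged["edge_src"], merged["edge_dst"]):
        gid = merged["graph_indicator"][s]  # edges never cross graphs
        graphs[gid]["edge_src"].append(local_index[s])
        graphs[gid]["edge_dst"].append(local_index[d])
    return graphs

batch = [
    {"node_feature": [[4.0], [2.0]], "edge_src": [0], "edge_dst": [1]},
    {"node_feature": [[1.0], [3.0], [5.0]], "edge_src": [0, 1], "edge_dst": [1, 2]},
]
merged = merge(batch)
print(merged["edge_src"], merged["edge_dst"])  # [0, 2, 3] [1, 3, 4]
print(merged["graph_indicator"])               # [0, 0, 1, 1, 1]
```

The per-node graph index is what allows a readout to pool nodes graph by graph after merging, while the offsets keep edges of different graphs from colliding.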

Requirements/dependencies

  • Python (version >= 3.10)
  • TensorFlow (version 2.15.*)
  • RDKit (version 2023.9.*)
  • Pandas
  • IPython

Installation

For CPU users:

pip install molgraph

For GPU users:

pip install "molgraph[gpu]"

Now run your first program with MolGraph:

from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']

x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GNNInputLayer(type_spec=x_train.spec),
    layers.GATConv(units=32),
    layers.GATConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1])
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(gnn_model)

maps = gam_model(x_train.separate())
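As a rough intuition for what a gradient activation map computes, the sketch below applies the generic Grad-CAM recipe to one graph: each feature channel of a GNN layer is weighted by its gradient averaged over nodes, and each node's score is the ReLU of its weighted activations. Activations and gradients are passed in directly here (in practice they come from automatic differentiation), and MolGraph's GradientActivationMapping may differ in detail, for example in which layers it uses or how it normalizes the maps. `gradient_activation_map` is a hypothetical helper.

```python
# Hypothetical Grad-CAM-style node importance for a single graph.
# activations[n][k]: feature k of node n in some GNN layer
# gradients[n][k]:   d(prediction)/d(activations[n][k])
# MolGraph's GradientActivationMapping may differ from this generic recipe.

def gradient_activation_map(activations, gradients):
    num_nodes = len(activations)
    num_features = len(activations[0])
    # Channel weights: gradient of each feature averaged over all nodes
    weights = [
        sum(gradients[n][k] for n in range(num_nodes)) / num_nodes
        for k in range(num_features)
    ]
    # Node scores: weighted sum of the node's activations, clipped at zero
    return [
        max(0.0, sum(w * a for w, a in zip(weights, activations[n])))
        for n in range(num_nodes)
    ]

activations = [[1.0, 0.0], [3.0, 1.0], [0.0, 1.0]]
gradients = [[0.5, -1.5], [0.5, 0.0], [0.5, 0.0]]
print(gradient_activation_map(activations, gradients))  # [0.5, 1.0, 0.0]
```

Here the channel weights are 0.5 and -0.5, so the third node, which only activates the negatively weighted feature, gets a score of zero.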



Download files


Source Distribution

molgraph-0.6.13.tar.gz (110.9 kB)

Built Distribution

molgraph-0.6.13-py3-none-any.whl (199.0 kB)

File details

Details for the file molgraph-0.6.13.tar.gz.

File metadata

  • Download URL: molgraph-0.6.13.tar.gz
  • Upload date:
  • Size: 110.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for molgraph-0.6.13.tar.gz
Algorithm Hash digest
SHA256 8c38a11a06d911da7e62c93bcdf8bd89df4aae49fa79e666d9923c783d08e6c7
MD5 7a9babf630a62ee6a0105f88facb6402
BLAKE2b-256 20b3ee8813f9bf5c876f226b1d34152f2e4fe84bc889edfcdf2e0b405fcc7164


File details

Details for the file molgraph-0.6.13-py3-none-any.whl.

File metadata

  • Download URL: molgraph-0.6.13-py3-none-any.whl
  • Upload date:
  • Size: 199.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for molgraph-0.6.13-py3-none-any.whl
Algorithm Hash digest
SHA256 5537038fd3f83301cbd8c86ab83b62442d3f8155cb83eba8a35bd4f582bb5970
MD5 8e96339836ce923d28cf7465fef8c455
BLAKE2b-256 e58b4572b72cb8825fb184770c11b4e461830febc3b60e63682040c1fb8d6c75

