
Graph Neural Networks for Molecular Machine Learning



Graph Neural Networks with TensorFlow and Keras. Focused on Molecular Machine Learning.

Keras 3 does not currently support TensorFlow extension types, which GraphTensor is built on. Once it does, the intention is to migrate MolGraph to Keras 3.

Highlights

Build a Graph Neural Network with Keras' Sequential API:

from molgraph import GraphTensor
from molgraph import layers
from tensorflow import keras

g = GraphTensor(node_feature=[[4.], [2.]], edge_src=[0], edge_dst=[1])

model = keras.Sequential([
    layers.GNNInput(type_spec=g.spec),
    layers.GATv2Conv(units=32),
    layers.GATv2Conv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1),
])

pred = model(g)

# Save and load Keras model
model.save('/tmp/gatv2_model.keras')
loaded_model = keras.models.load_model('/tmp/gatv2_model.keras')
loaded_pred = loaded_model(g)
# The reloaded model should reproduce the original predictions exactly
assert (pred.numpy() == loaded_pred.numpy()).all()
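The constructor arguments of `g` can be read as a small edge list. The following NumPy sketch, which is independent of MolGraph, rebuilds the same two-node graph, assuming that each edge points from `edge_src` to `edge_dst`, and shows one sum-aggregation step followed by a sum readout. The helper `to_adjacency` is hypothetical; the aggregation in `GATv2Conv` is attention-weighted, and MolGraph's `Readout` may pool differently.

```python
import numpy as np

# Illustrative only: the same graph as `g`, as plain NumPy arrays.
node_feature = np.array([[4.0], [2.0]])  # one row per node, one column per feature
edge_src = np.array([0])                 # edge i starts at node edge_src[i] ...
edge_dst = np.array([1])                 # ... and ends at node edge_dst[i]

def to_adjacency(num_nodes, src, dst):
    """Hypothetical helper: dense adjacency matrix with A[dst, src] = 1 per edge."""
    adj = np.zeros((num_nodes, num_nodes))
    adj[dst, src] = 1.0
    return adj

adj = to_adjacency(len(node_feature), edge_src, edge_dst)

# One round of sum aggregation: each node receives the features of its source neighbours.
aggregated = adj @ node_feature

# A sum readout collapses all node features of this single graph into one vector.
readout = node_feature.sum(axis=0)
```

Here node 1 receives the feature of node 0 (`aggregated` is `[[0.], [4.]]`), while node 0 receives nothing because the only edge points away from it.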

Paper

See arXiv

Documentation

See readthedocs

Overview


Implementations

  • Graph tensor (GraphTensor)
    • A composite tensor holding graph data.
    • Has a ragged state (a batch of separate graphs) and a non-ragged state (the same graphs merged into a single disjoint graph).
    • Can convert between the two states (merge(), separate()).
    • Can propagate node states (features) based on edges (propagate()).
    • Can add, update and remove graph data (update(), remove()).
    • Compatible with TensorFlow's APIs, including Keras. For instance, graph data encoded as a GraphTensor can be used directly with keras.Sequential, the Keras functional API, tf.data.Dataset, and tf.saved_model.
  • Layers
  • Models
    • Building models with MolGraph is straightforward, but a few GNN models are also built in:
      • GIN
      • MPNN
      • DMPNN
    • And models for improved interpretability of GNNs:
      • SaliencyMapping
      • IntegratedSaliencyMapping
      • SmoothGradSaliencyMapping
      • GradientActivationMapping (Recommended)
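The GraphTensor operations listed above (merge(), separate(), propagate()) can be illustrated with a plain NumPy sketch. The dict-based graph format, the `graph_indicator` field name and the sum aggregation in `propagate` are assumptions chosen for illustration; they are not MolGraph's API or its internal implementation.

```python
import numpy as np

# Hypothetical ragged state: a list of separate graphs, each a dict of arrays.
graphs = [
    {"node_feature": np.array([[1.0], [2.0]]),
     "edge_src": np.array([0, 1]), "edge_dst": np.array([1, 0])},
    {"node_feature": np.array([[3.0], [4.0], [5.0]]),
     "edge_src": np.array([0, 1]), "edge_dst": np.array([1, 2])},
]

def merge(graphs):
    """Combine separate graphs into one disjoint graph by offsetting edge indices."""
    sizes = np.array([len(g["node_feature"]) for g in graphs])
    offsets = np.concatenate([[0], np.cumsum(sizes)[:-1]])  # first node index of each graph
    return {
        "node_feature": np.concatenate([g["node_feature"] for g in graphs]),
        "edge_src": np.concatenate([g["edge_src"] + o for g, o in zip(graphs, offsets)]),
        "edge_dst": np.concatenate([g["edge_dst"] + o for g, o in zip(graphs, offsets)]),
        # Records which graph each node came from, so the merge can be undone.
        "graph_indicator": np.repeat(np.arange(len(graphs)), sizes),
    }

def separate(merged):
    """Split a merged disjoint graph back into per-graph dicts."""
    ids = merged["graph_indicator"]
    result = []
    for gid in np.unique(ids):
        node_ids = np.flatnonzero(ids == gid)    # nodes of one graph are contiguous
        offset = node_ids[0]
        edge_mask = ids[merged["edge_src"]] == gid
        result.append({
            "node_feature": merged["node_feature"][node_ids],
            "edge_src": merged["edge_src"][edge_mask] - offset,
            "edge_dst": merged["edge_dst"][edge_mask] - offset,
        })
    return result

def propagate(merged):
    """Sum each edge's source-node features into its destination node."""
    out = np.zeros_like(merged["node_feature"])
    # np.add.at accumulates correctly even when a destination index repeats.
    np.add.at(out, merged["edge_dst"], merged["node_feature"][merged["edge_src"]])
    return out

merged = merge(graphs)
new_state = propagate(merged)
recovered = separate(merged)
```

Because the graphs stay disconnected after merging, propagation never mixes information between molecules, which is what makes the single-disjoint-graph state safe for batched message passing.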
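The interpretability models listed above attribute a prediction to the input node (atom) features through gradients. As a rough illustration of the integrated-gradients idea that a name like IntegratedSaliencyMapping suggests, the NumPy sketch below replaces a trained GNN with a toy model that has an analytic gradient. The toy model, the all-zeros baseline and the step count are assumptions; MolGraph's own implementation may differ in all three.

```python
import numpy as np

# Toy stand-in for a GNN: f(x) = sum over nodes of tanh(x @ w), x is (num_nodes, num_features).
w = np.array([0.5, -1.0, 2.0])

def f(x):
    return np.tanh(x @ w).sum()

def grad_f(x):
    # d/dx tanh(x @ w) = (1 - tanh(x @ w)^2) * w, evaluated per node
    return (1.0 - np.tanh(x @ w) ** 2)[:, None] * w[None, :]

def integrated_gradients(x, baseline, steps=200):
    """Midpoint Riemann approximation of integrated gradients along a straight path."""
    alphas = (np.arange(steps) + 0.5) / steps
    total = np.zeros_like(x)
    for a in alphas:
        total += grad_f(baseline + a * (x - baseline))
    return (x - baseline) * total / steps

x = np.array([[1.0, 0.0, 0.5], [0.2, 1.0, 0.3]])  # two nodes, three features each
baseline = np.zeros_like(x)
attributions = integrated_gradients(x, baseline)
node_scores = attributions.sum(axis=1)  # one importance score per node (atom)
```

A useful sanity check is completeness: the attributions sum to approximately `f(x) - f(baseline)`, so the per-node scores distribute the prediction across atoms.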

Requirements/dependencies

  • Python (version >= 3.10)
  • TensorFlow (version 2.15.*)
  • RDKit (version 2023.9.*)
  • Pandas
  • IPython

Installation

For CPU users:

pip install molgraph

For GPU users:

pip install "molgraph[gpu]"

Now run your first program with MolGraph:

from tensorflow import keras
from molgraph import chemistry
from molgraph import layers
from molgraph import models

# Obtain dataset, specifically ESOL
esol = chemistry.datasets.get('esol')

# Define molecular graph encoder
atom_encoder = chemistry.Featurizer([
    chemistry.features.Symbol(),
    chemistry.features.Hybridization(),
    # ...
])

bond_encoder = chemistry.Featurizer([
    chemistry.features.BondType(),
    # ...
])

encoder = chemistry.MolecularGraphEncoder(atom_encoder, bond_encoder)

# Obtain graphs and associated labels
x_train = encoder(esol['train']['x'])
y_train = esol['train']['y']

x_test = encoder(esol['test']['x'])
y_test = esol['test']['y']

# Build model via Keras API
gnn_model = keras.Sequential([
    layers.GNNInputLayer(type_spec=x_train.spec),
    layers.GATConv(units=32),
    layers.GATConv(units=32),
    layers.Readout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=y_train.shape[-1])
])

# Compile, fit and evaluate
gnn_model.compile(optimizer='adam', loss='mae')
gnn_model.fit(x_train, y_train, epochs=50)
scores = gnn_model.evaluate(x_test, y_test)

# Compute gradient activation maps
gam_model = models.GradientActivationMapping(gnn_model)

maps = gam_model(x_train.separate())
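The Featurizer and MolecularGraphEncoder step turns each molecule into node features plus a directed edge list. The sketch below imitates that idea without RDKit. The vocabularies, the extra out-of-vocabulary slot, the helper names `one_hot` and `encode_molecule`, and the choice to emit every bond as two directed edges are all assumptions for illustration, not MolGraph's actual encoding; bond features are omitted.

```python
import numpy as np

# Hypothetical vocabularies; a real featurizer would read these properties from RDKit atoms.
SYMBOLS = ["C", "N", "O"]
HYBRIDIZATIONS = ["SP", "SP2", "SP3"]

def one_hot(value, vocabulary):
    """One-hot encode `value`; an extra last slot catches values outside the vocabulary."""
    vec = np.zeros(len(vocabulary) + 1)
    idx = vocabulary.index(value) if value in vocabulary else len(vocabulary)
    vec[idx] = 1.0
    return vec

def encode_molecule(atoms, bonds):
    """atoms: list of (symbol, hybridization); bonds: list of (i, j) atom-index pairs."""
    node_feature = np.stack([
        np.concatenate([one_hot(s, SYMBOLS), one_hot(h, HYBRIDIZATIONS)])
        for s, h in atoms
    ])
    # Each undirected bond becomes two directed edges: i -> j and j -> i.
    edge_src = np.array([i for i, _ in bonds] + [j for _, j in bonds])
    edge_dst = np.array([j for _, j in bonds] + [i for i, _ in bonds])
    return node_feature, edge_src, edge_dst

# Ethanol heavy atoms (C-C-O), with hydrogens left implicit.
atoms = [("C", "SP3"), ("C", "SP3"), ("O", "SP3")]
bonds = [(0, 1), (1, 2)]
node_feature, edge_src, edge_dst = encode_molecule(atoms, bonds)
```

Concatenating one-hot blocks keeps each atom property in its own fixed slice of the feature vector, so adding another feature to the featurizer list simply widens `node_feature`.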



Download files


Source Distribution

molgraph-0.6.11.tar.gz (110.5 kB)


Built Distribution

molgraph-0.6.11-py3-none-any.whl (198.3 kB)


File details

Details for the file molgraph-0.6.11.tar.gz.

File metadata

  • Download URL: molgraph-0.6.11.tar.gz
  • Size: 110.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for molgraph-0.6.11.tar.gz
Algorithm Hash digest
SHA256 b213c26c59e4f31da2256e2bb30dba03d0b01eb75770bf8b41adb7a0174cfcb8
MD5 6dceb35c268d1b5a34363bdc20bb1af1
BLAKE2b-256 d00ecb305e5b2d9b696e7aac6d53c2f1c5f51029a3c46f4d919f5a01756e2063


File details

Details for the file molgraph-0.6.11-py3-none-any.whl.

File metadata

  • Download URL: molgraph-0.6.11-py3-none-any.whl
  • Size: 198.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for molgraph-0.6.11-py3-none-any.whl
Algorithm Hash digest
SHA256 a5b670ff8cfd756f63f803b436a55128ef816308d83fa242dc48aa31285bafbf
MD5 7a95780416dec008a507c52317ba2d9d
BLAKE2b-256 2422dda74c60680012da7349adbf829ab02f6f22404cda38429d5fc442d50948
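The published digests let you check that a downloaded file was not altered. A minimal sketch using Python's standard `hashlib` module computes a file's SHA256 in chunks and compares it with the published value; the helper names `sha256_of_file` and `verify_sha256` are hypothetical.

```python
import hashlib

def sha256_of_file(path, chunk_size=1 << 16):
    """Compute the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # iter() with a sentinel keeps reading until read() returns b"" at end of file.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_sha256(path, expected_hex):
    """Return True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, `verify_sha256("molgraph-0.6.11.tar.gz", "b213c26c59e4f31da2256e2bb30dba03d0b01eb75770bf8b41adb7a0174cfcb8")` should return True for an unmodified copy of the source distribution.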

