PhAST: Physics-Aware, Scalable, and Task-specific GNNs for Accelerated Catalyst Design

This repository contains implementations of two of the PhAST components presented in the paper:

  • PhysEmbedding, which builds an embedding vector from atomic numbers by concatenating:
    • A learned embedding for the atom's group
    • A learned embedding for the atom's period
    • A fixed or learned embedding from a set of known physical properties, as reported by the mendeleev package
    • In the case of the OC20 dataset, a learned embedding for the atom's tag (adsorbate, catalyst surface or catalyst sub-surface)
  • Tag-based graph rewiring strategies for the OC20 dataset:
    • remove_tag0_nodes deletes all nodes in the graph with tag 0 and recomputes edges

    • one_supernode_per_graph replaces all tag 0 atoms with a single new atom

    • one_supernode_per_atom_type replaces all tag 0 atoms of a given element with its own super node

See also: https://github.com/vict0rsch/faenet

Installation

pip install phast

⚠️ The above installation does not include torch_geometric, a complex dependency whose installation varies with your setup. You have to install it yourself if you want to use the graph rewiring functions of phast.

☮️ Ignore torch_geometric if you only care about the PhysEmbeddings.
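
Since torch_geometric is optional, a script that uses both parts of phast may want to check for it before importing the rewiring functions. The following is a minimal sketch using only the standard library; the helper name `has_module` is hypothetical and not part of phast.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module called `name` can be found."""
    # find_spec returns None when no such top-level module is installed.
    return importlib.util.find_spec(name) is not None


if has_module("torch_geometric"):
    print("torch_geometric found: the graph rewiring functions should be usable")
else:
    print("torch_geometric missing: stick to PhysEmbedding")
```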

Getting started

Physical embeddings

[Figure: Embedding illustration]

import torch
from phast.embedding import PhysEmbedding

z = torch.randint(1, 85, (3, 12)) # batch of 3 graphs with 12 atoms each
phys_embedding = PhysEmbedding(
    z_emb_size=32, # default
    period_emb_size=32, # default
    group_emb_size=32, # default
    properties_proj_size=32, # default is 0 -> no learned projection
    n_elements=85, # default
)
h = phys_embedding(z) # h.shape = (3, 12, 128)

tags = torch.randint(0, 3, (3, 12))
phys_embedding = PhysEmbedding(
    tag_emb_size=32, # default is 0, this is OC20-specific
    final_proj_size=64, # default is 0, no projection, just the concat. of embeds.
)

h = phys_embedding(z, tags) # h.shape = (3, 12, 64)

# Assuming torch_geometric is installed:
data = torch.load("examples/data/is2re_bs3.pt")
h = phys_embedding(data.atomic_numbers.long(), data.tags) # h.shape = (261, 64)
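
To make the concatenation idea concrete, here is a simplified pure-Python sketch, not the phast implementation. It looks up an atom's period from its atomic number and concatenates a period row with a tag row, the same way four 32-wide parts add up to the 128-wide output in the first example above. The tables are random stand-ins for learned weights, and the names `period_of`, `make_table` and `embed` are hypothetical.

```python
import random
from bisect import bisect_left

# Atomic numbers of the noble gases, which close periods 1 to 7.
PERIOD_ENDS = [2, 10, 18, 36, 54, 86, 118]


def period_of(z: int) -> int:
    """Return the periodic-table period (1 to 7) of atomic number z."""
    if not 1 <= z <= 118:
        raise ValueError(f"atomic number out of range: {z}")
    return bisect_left(PERIOD_ENDS, z) + 1


def make_table(n_rows: int, dim: int, seed: int) -> list:
    """Random rows standing in for a learned embedding table (assumption)."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(n_rows)]


PERIOD_TABLE = make_table(n_rows=8, dim=4, seed=0)  # row 0 is unused
TAG_TABLE = make_table(n_rows=3, dim=2, seed=1)     # one row per tag 0, 1, 2


def embed(z: int, tag: int) -> list:
    """Concatenate the period row and the tag row for a single atom."""
    return PERIOD_TABLE[period_of(z)] + TAG_TABLE[tag]


h = embed(z=78, tag=1)  # platinum sits in period 6
print(len(h))           # width is the sum of the parts: 4 + 2
```

The real PhysEmbedding also includes group and physical-property parts and, when final_proj_size is set, projects the concatenation to that size.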

Graph rewiring

[Figure: Rewiring illustration]

from copy import deepcopy
import torch
from phast.graph_rewiring import (
    remove_tag0_nodes,
    one_supernode_per_graph,
    one_supernode_per_atom_type,
)

data = torch.load("./examples/data/is2re_bs3.pt")  # 3 batched OC20 IS2RE data samples
print(
    "Data initially contains {} graphs, a total of {} atoms and {} edges".format(
        len(data.natoms), data.ptr[-1], len(data.cell_offsets)
    )
)
rewired_data = remove_tag0_nodes(deepcopy(data))
print(
    "Data without tag-0 nodes contains {} graphs, a total of {} atoms and {} edges".format(
        len(rewired_data.natoms), rewired_data.ptr[-1], len(rewired_data.cell_offsets)
    )
)
rewired_data = one_supernode_per_graph(deepcopy(data))
print(
    "Data with one super node per graph contains a total of {} atoms and {} edges".format(
        rewired_data.ptr[-1], len(rewired_data.cell_offsets)
    )
)
rewired_data = one_supernode_per_atom_type(deepcopy(data))
print(
    "Data with one super node per atom type contains a total of {} atoms and {} edges".format(
        rewired_data.ptr[-1], len(rewired_data.cell_offsets)
    )
)

This prints:

Data initially contains 3 graphs, a total of 261 atoms and 11596 edges
Data without tag-0 nodes contains 3 graphs, a total of 64 atoms and 1236 edges
Data with one super node per graph contains a total of 67 atoms and 1311 edges
Data with one super node per atom type contains a total of 71 atoms and 1421 edges
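
The atom counts follow from the strategies: removing tag-0 nodes leaves 64 atoms, and adding one super node for each of the 3 graphs gives 67. As a rough illustration of the idea on a single graph, here is a pure-Python sketch that works on a tag list and a directed edge list. It is not the phast implementation: it filters or redirects existing edges rather than recomputing them, ignores positions and cell offsets, and assumes the super node keeps tag 0. The function names are hypothetical.

```python
def remove_tag0(tags, edges):
    """Drop tag-0 nodes and keep only edges between surviving nodes."""
    keep = [i for i, t in enumerate(tags) if t != 0]
    new_index = {old: new for new, old in enumerate(keep)}
    new_edges = [
        (new_index[u], new_index[v])
        for u, v in edges
        if u in new_index and v in new_index
    ]
    return [tags[i] for i in keep], new_edges


def one_supernode(tags, edges):
    """Merge all tag-0 nodes into one super node placed after the kept nodes."""
    keep = [i for i, t in enumerate(tags) if t != 0]
    new_index = {old: new for new, old in enumerate(keep)}
    supernode = len(keep)  # index of the new node
    new_edges, seen = [], set()
    for u, v in edges:
        edge = (new_index.get(u, supernode), new_index.get(v, supernode))
        # Skip self-loops (e.g. edges between two merged nodes) and duplicates.
        if edge[0] == edge[1] or edge in seen:
            continue
        seen.add(edge)
        new_edges.append(edge)
    # Assumption: the super node is labeled with tag 0.
    return [tags[i] for i in keep] + [0], new_edges


tags = [0, 0, 1, 2]  # two tag-0 atoms, one tag-1 atom, one tag-2 atom
edges = [(0, 1), (1, 0), (0, 2), (1, 2), (2, 3), (3, 2)]

print(remove_tag0(tags, edges))    # ([1, 2], [(0, 1), (1, 0)])
print(one_supernode(tags, edges))  # ([1, 2, 0], [(2, 0), (0, 1), (1, 0)])
```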

Tests

This requires poetry. Make sure torch and torch_geometric are installed in your environment before running the tests. Because of CUDA/torch compatibility constraints, neither torch nor torch_geometric is an explicit dependency, so both must be installed independently.

git clone git@github.com:vict0rsch/phast.git
cd phast
poetry install --with dev
pytest --cov=phast --cov-report term-missing

When testing on Macs, you may encounter a Library Not Loaded error.

Requires Python <3.12 because mendeleev (0.14.0) requires Python >=3.8.1,<3.12.
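
To check an interpreter against this constraint before installing, a small standard-library sketch could look like the following. The function name `python_supported` is hypothetical; the bounds are the ones stated above for mendeleev 0.14.0.

```python
import sys


def python_supported(version=None) -> bool:
    """Return True if a (major, minor, micro) version satisfies >=3.8.1,<3.12."""
    v = tuple(sys.version_info[:3]) if version is None else tuple(version)
    return (3, 8, 1) <= v < (3, 12)


if not python_supported():
    print("This Python version is outside the range supported by mendeleev 0.14.0")
```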


