
PyTorch implementation for FAENet from 'FAENet: Frame Averaging Equivariant GNN for Materials Modeling'



FAENet: Frame Averaging Equivariant GNN for Materials modeling

This repository contains an implementation of the paper FAENet: Frame Averaging Equivariant GNN for Materials modeling, accepted at ICML 2023. More precisely, you will find:

  • FrameAveraging: the transform that projects your pytorch-geometric data into the canonical space defined in the paper.
  • FAENet GNN model for material modeling.
  • model_forward: a high-level forward function that computes appropriate model predictions for the Frame Averaging method, i.e. handling the different frames and mapping to equivariant predictions.

Source code: https://github.com/vict0rsch/faenet

Installation

pip install faenet

⚠️ To the best of our knowledge, the above installation requires Python >= 3.8, torch > 1.11 and torch_geometric > 2.1. The mendeleev and pandas packages are also required to derive physics-aware atom embeddings in FAENet.

Getting started

Frame Averaging Transform

FrameAveraging is a transform applicable to pytorch-geometric Data objects. You can choose among several options, ranging from Full FA to Stochastic FA (in 2D or 3D), as well as data augmentation (DA). The transform should be applied to each item as it is loaded, i.e. in the item-loading method of your Dataset class (e.g. __getitem__()). Note that although this transform is specific to pytorch-geometric data objects, it can easily be extended to new settings, since the core functions frame_averaging_2D() and frame_averaging_3D() generalise to other data formats.

import torch
from torch_geometric.data import Data
from faenet.transform import FrameAveraging

frame_averaging = "3D"  # symmetry preservation method used: {"3D", "2D", "DA", ""}
fa_method = "stochastic"  # frame averaging method: {"stochastic", "det", "all", "se3-stochastic", "se3-det", "se3-all", ""}
transform = FrameAveraging(frame_averaging, fa_method)

g = Data(pos=torch.randn(5, 3))  # toy PyG graph g: 5 atoms with 3D positions (real graphs carry more attributes)
g = transform(g)  # transform the PyG graph g
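To make the "canonical space" idea more concrete, here is a minimal NumPy sketch of how 3D frames can be built from the principal axes of the centred atom positions, in the spirit of the paper. It is not the library's frame_averaging_3D(): the names pca_frames and stochastic_frame are hypothetical, and the sketch assumes the covariance eigenvalues are distinct, so that each principal axis is only defined up to its sign. Flipping the three signs gives 8 frames (O(3)); keeping only those with determinant +1 gives 4 frames (SE(3)); stochastic FA samples one of them.

```python
import itertools

import numpy as np


def pca_frames(pos, se3=False):
    """Hypothetical sketch: return (projected_positions, frame) pairs for PCA frames.

    pos: (n_atoms, 3) array. Assumes distinct covariance eigenvalues, so each
    principal axis is only ambiguous up to its sign.
    """
    centered = pos - pos.mean(axis=0)         # removes translations
    cov = centered.T @ centered                # 3x3 (unnormalised) covariance
    _, eigvecs = np.linalg.eigh(cov)           # columns are the principal axes
    frames = []
    for signs in itertools.product((1.0, -1.0), repeat=3):
        frame = eigvecs * np.array(signs)      # flip the sign of each axis
        if se3 and np.linalg.det(frame) < 0:   # keep proper rotations only
            continue
        frames.append((centered @ frame, frame))
    return frames  # 8 frames for O(3), 4 frames for SE(3)


def stochastic_frame(pos, rng, se3=False):
    """Hypothetical sketch of stochastic FA: sample a single frame at random."""
    frames = pca_frames(pos, se3=se3)
    return frames[rng.integers(len(frames))]


rng = np.random.default_rng(0)
pos = rng.normal(size=(6, 3))  # toy positions for 6 atoms
print(len(pca_frames(pos)), len(pca_frames(pos, se3=True)))  # 8 4
```

Rotating or translating pos leaves the set of projected positions unchanged, which is what makes a model applied in these frames insensitive to the input's orientation.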

Model forward for Frame Averaging

model_forward() aggregates model predictions when Frame Averaging is applied, as stipulated by Equation (1) of the paper. It must be used to compute predictions whenever Frame Averaging is enabled.

from faenet.fa_forward import model_forward

preds = model_forward(
    batch=batch,  # batch from a DataLoader
    model=model,  # e.g. FAENet(**kwargs)
    frame_averaging="3D",  # one of {"2D", "3D", "DA", ""}
    mode="train",  # for training
    crystal_task=True,  # for crystals, with periodic boundary conditions (PBC)
)
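The aggregation performed by model_forward() can be illustrated with a minimal NumPy sketch of frame averaging: a predictor that is not equivariant on its own is evaluated in every frame, invariant outputs (e.g. energies) are averaged directly, and vector outputs (e.g. forces) are first mapped back to the input axes before averaging. This is not faenet's code; pca_frames, toy_model and frame_averaged_predict are hypothetical names, and the PCA frames assume distinct covariance eigenvalues.

```python
import itertools

import numpy as np


def pca_frames(pos):
    """Hypothetical helper: all 8 O(3) PCA frames of the centred positions."""
    centered = pos - pos.mean(axis=0)
    _, eigvecs = np.linalg.eigh(centered.T @ centered)
    return [
        (centered @ (eigvecs * np.array(s)), eigvecs * np.array(s))
        for s in itertools.product((1.0, -1.0), repeat=3)
    ]


def toy_model(pos):
    """Hypothetical stand-in for a GNN, NOT rotation-invariant on its own.

    Returns a scalar "energy" and per-atom "forces" (minus its gradient),
    expressed in whatever axes the input positions use.
    """
    w = np.array([1.0, 2.0, 3.0])  # axis-dependent weights break symmetry
    energy = float(np.sum(pos**2 @ w))
    forces = -2.0 * pos * w
    return energy, forces


def frame_averaged_predict(model, pos):
    """Average model outputs over all frames (hypothetical sketch)."""
    energies, forces = [], []
    for projected, frame in pca_frames(pos):
        e, f = model(projected)
        energies.append(e)          # invariant output: average as is
        forces.append(f @ frame.T)  # equivariant output: back to input axes
    return np.mean(energies), np.mean(forces, axis=0)
```

With this averaging, rotating the input leaves the energy unchanged and rotates the forces by the same rotation, even though toy_model itself has neither property.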

FAENet GNN

Implementation of the FAENet GNN model, compatible with any dataset or transform. In short, FAENet is a very simple, scalable and expressive model. Since it does not explicitly preserve data symmetries, it can directly process relative atom positions without restriction, which is very efficient. Note that the training procedure is not provided here.

from faenet.model import FAENet

model = FAENet(**kwargs)  # kwargs: model hyperparameters
preds = model(batch)  # batch: batched PyG graphs, e.g. from a DataLoader
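The point about relative positions can be sketched in NumPy: build edges within a distance cutoff and compute the relative position vectors along them. The helper radius_graph_edges and the cutoff value are hypothetical, not faenet's API. Distances are rotation-invariant, but the relative vectors rotate with the input, so a model that consumes them directly relies on Frame Averaging (or data augmentation) to handle symmetries.

```python
import numpy as np


def radius_graph_edges(pos, cutoff):
    """Hypothetical helper: directed edges (i, j), i != j, with |pos[j] - pos[i]| < cutoff.

    Returns a (2, n_edges) edge index, (n_edges, 3) relative positions and
    (n_edges,) distances.
    """
    diff = pos[None, :, :] - pos[:, None, :]  # diff[i, j] = pos[j] - pos[i]
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)  # drop self-loops
    src, dst = np.nonzero(mask)
    rel_pos = diff[src, dst]  # raw relative positions, not rotation-invariant
    return np.stack([src, dst]), rel_pos, dist[src, dst]
```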

[Figure: FAENet architecture]

Eval

The eval_model_symmetries() function helps you evaluate the equivariant, invariant and other properties of a model, as we did in the paper.
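The kind of check involved can be sketched as follows. This is a hypothetical NumPy illustration, not eval_model_symmetries()'s signature or metrics: it applies random rotations to the input positions and measures how much a scalar prediction changes (invariance error) and how far a vector prediction is from rotating with the input (equivariance error). The predict(pos) -> (scalar, vectors) interface is an assumption.

```python
import numpy as np


def random_rotation(rng):
    """Random proper rotation (det = +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # make the decomposition unique
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1            # turn a reflection into a rotation
    return q


def symmetry_errors(predict, pos, n_trials=10, seed=0):
    """Hypothetical sketch: mean invariance and equivariance errors.

    predict(pos) must return (scalar, (n_atoms, 3) array of vectors).
    """
    rng = np.random.default_rng(seed)
    e_ref, f_ref = predict(pos)
    inv_err, eq_err = [], []
    for _ in range(n_trials):
        rot = random_rotation(rng)
        e_rot, f_rot = predict(pos @ rot.T)
        inv_err.append(abs(e_rot - e_ref))                # scalar should not change
        eq_err.append(np.abs(f_rot - f_ref @ rot.T).mean())  # vectors should rotate
    return float(np.mean(inv_err)), float(np.mean(eq_err))
```

A truly invariant/equivariant predictor yields errors at floating-point precision, while one that depends on absolute orientation yields clearly non-zero errors.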

Tests

The /tests folder contains several useful unit tests. Feel free to have a look at them to explore how the model can be used. For more advanced examples, please refer to the full repository used in our ICML paper to make predictions on the OC20 (IS2RE and S2EF), QM9 and QM7-X datasets.

Running the tests requires poetry. Make sure torch and torch_geometric are installed in your environment before running them: because of CUDA/torch compatibility issues, neither torch nor torch_geometric is listed as an explicit dependency, so both must be installed independently.

git clone git@github.com:vict0rsch/faenet.git
cd faenet
poetry install --with dev
pytest --cov=faenet --cov-report term-missing

When testing on Macs, you may encounter a Library Not Loaded error.

Contact

Authors: Alexandre Duval (alexandre.duval@mila.quebec) and Victor Schmidt (schmidtv@mila.quebec). We welcome your questions and feedback via email or GitHub Issues.


Download files

Source distribution: faenet-0.1.1.tar.gz (20.4 kB)

Built distribution: faenet-0.1.1-py3-none-any.whl (21.0 kB, Python 3)
