Frame Averaging
Frame Averaging - Pytorch (wip)
A PyTorch implementation of a simple way to enable (Stochastic) Frame Averaging for any network. This technique was recently adopted by Prescient Design in AbDiffuser.
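For reference, the equivariant frame averaging operator from Puny et al. (2021), listed under Citations, can be written as follows. Here $\Phi$ is the wrapped network, $F(X)$ is the frame (a finite set of group elements computed from the input $X$), and $\rho_1$, $\rho_2$ are the group representations acting on the input and output spaces. The notation follows the paper, not this package's code.

$$
\langle \Phi \rangle_F(X) = \frac{1}{|F(X)|} \sum_{g \in F(X)} \rho_2(g)\, \Phi\!\left(\rho_1(g)^{-1} X\right)
$$

The stochastic variant from FAENet replaces the sum with a single term, using one $g$ sampled at random from $F(X)$.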
Install
$ pip install frame-averaging-pytorch
Usage
import torch
from frame_averaging_pytorch import FrameAverage
# contrived neural network
net = torch.nn.Linear(3, 3)
# wrap the network with FrameAverage
net = FrameAverage(
    net,
    dim = 3,           # defaults to 3 for spatial, but can be any value
    stochastic = True  # whether to use stochastic variant from FAENet (one frame sampled at random)
)
# pass your input to the network as usual
points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()
out = net(points, frame_average_mask = mask)
out.shape # (4, 1024, 3)
# frame averaging is automatically taken care of, as though the network were unwrapped
Alternatively, you can carry out the frame averaging manually:
import torch
from frame_averaging_pytorch import FrameAverage
# contrived neural network
net = torch.nn.Linear(3, 3)
# frame average module without passing in network
fa = FrameAverage()
# pass the 3d points and mask to FrameAverage forward
points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()
framed_inputs, frame_average_fn = fa(points, frame_average_mask = mask)
# network forward
net_out = net(framed_inputs)
# frame average
frame_averaged = frame_average_fn(net_out)
frame_averaged.shape # (4, 1024, 3)
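The two-step pattern above (frame the inputs, run the network, then average) can also be sketched without the library. The following is a minimal illustration and not this package's implementation: it assumes PCA-based frames with all 8 sign flips, as described by Puny et al., handles a single unbatched and unmasked point cloud, and the name `make_frames` is hypothetical.

```python
import itertools
import torch

def make_frames(points):
    """Hypothetical helper: return (framed_inputs, average_fn) for points of shape (n, 3).

    Assumes PCA frames with all 8 sign flips; no batching and no masking.
    """
    centroid = points.mean(dim=0, keepdim=True)       # (1, 3)
    centered = points - centroid                      # (n, 3)
    cov = centered.T @ centered                       # (3, 3) symmetric covariance (unnormalized)
    _, eigvecs = torch.linalg.eigh(cov)               # columns are the principal axes

    # all 2^3 sign patterns for the three axes: (8, 3)
    signs = torch.tensor(
        list(itertools.product([1.0, -1.0], repeat=3)), dtype=points.dtype
    )
    # each frame is the eigenvector matrix with its columns sign-flipped: (8, 3, 3)
    frames = eigvecs.unsqueeze(0) * signs.unsqueeze(1)

    # express the centered points in every frame: (8, n, 3)
    framed = torch.einsum('nd,fde->fne', centered, frames)

    def average_fn(net_out):
        # rotate each per-frame output back, restore the centroid, average over frames
        back = torch.einsum('fne,fde->fnd', net_out, frames) + centroid
        return back.mean(dim=0)                       # (n, 3)

    return framed, average_fn

# usage: check equivariance under a random orthogonal transform plus translation
torch.manual_seed(0)
net = torch.nn.Linear(3, 3).double()                  # contrived network, float64 for a tight check
points = torch.randn(64, 3, dtype=torch.float64)

with torch.no_grad():
    framed, average_fn = make_frames(points)
    out = average_fn(net(framed))

    rot, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))  # random orthogonal matrix
    shift = torch.randn(1, 3, dtype=torch.float64)

    framed_t, average_fn_t = make_frames(points @ rot + shift)
    out_t = average_fn_t(net(framed_t))

# transforming the input transforms the averaged output in the same way
assert torch.allclose(out_t, out @ rot + shift, atol=1e-6)
```

Returning a closure keeps the frames and centroid out of the caller's way, which mirrors the `framed_inputs, frame_average_fn` pair shown above. Averaging over all 8 frames costs 8 forward passes; sampling a single frame, as the `stochastic` option does, trades exact equivariance per call for one forward pass.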
Citations
@article{Puny2021FrameAF,
    title   = {Frame Averaging for Invariant and Equivariant Network Design},
    author  = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.03336},
    url     = {https://api.semanticscholar.org/CorpusID:238419638}
}
@article{Duval2023FAENetFA,
    title   = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
    author  = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.05577},
    url     = {https://api.semanticscholar.org/CorpusID:258564608}
}