Frame Averaging
Frame Averaging - Pytorch (wip)
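Pytorch implementation of a simple way to enable (Stochastic) Frame Averaging for any network. Frame averaging (Puny et al., 2021) makes an arbitrary network invariant or equivariant to a symmetry group by averaging its outputs over a small, input-dependent set of frames; the stochastic variant from FAENet samples a single frame instead. This technique was recently adopted by Prescient Design in AbDiffuser.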
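To make the idea concrete, here is a minimal NumPy sketch of the PCA-based frame construction from Puny et al. for point clouds, assuming an `(n, d)` input and a network mapping `(n, d)` to `(n, d)`. The centroid removes translation, the covariance eigenvectors give an orthonormal basis, and because each eigenvector is only defined up to sign, the output is averaged over all `2**d` sign choices. The helpers `pca_frames` and `frame_average` are hypothetical names, this is not the package's internal code, and it assumes the covariance has distinct eigenvalues (degenerate clouds need extra care).

```python
import itertools
import numpy as np

def pca_frames(points):
    """Return the centroid and the 2**d PCA frames of an (n, d) point cloud."""
    centroid = points.mean(axis=0)
    centered = points - centroid
    cov = centered.T @ centered                   # (d, d) unnormalized covariance
    _, eigvecs = np.linalg.eigh(cov)              # columns are eigenvectors, ascending eigenvalues
    d = points.shape[1]
    # each eigenvector is only defined up to sign, so enumerate every sign choice
    signs = np.array(list(itertools.product((-1.0, 1.0), repeat=d)))   # (2**d, d)
    frames = eigvecs[None, :, :] * signs[:, None, :]                   # (2**d, d, d)
    return centroid, frames

def frame_average(net, points):
    """Average net over all frames: rotation/reflection equivariant, translation invariant output."""
    centroid, frames = pca_frames(points)
    centered = points - centroid
    outputs = []
    for frame in frames:
        framed = centered @ frame                 # express the points in this frame
        out = net(framed)                         # unconstrained network, (n, d) -> (n, d)
        outputs.append(out @ frame.T)             # rotate the output back to the input frame
    return np.mean(outputs, axis=0)

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 3))
net = lambda x: np.tanh(x @ W)                    # hypothetical, non-equivariant network
points = rng.normal(size=(128, 3)) * np.array([3.0, 2.0, 1.0])   # distinct principal axes
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))      # random orthogonal matrix
shift = rng.normal(size=3)

out = frame_average(net, points)
out_transformed = frame_average(net, points @ q + shift)
print(np.allclose(out_transformed, out @ q))      # prints True
```

Rotating (or reflecting) and translating the input only permutes and sign-flips the set of frames, so the averaged output rotates with the input even though `net` itself has no symmetry built in.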
Install
$ pip install frame-averaging-pytorch
Usage
```python
import torch
from frame_averaging_pytorch import FrameAverage

# contrived neural network
net = torch.nn.Linear(3, 3)

# wrap the network with FrameAverage
net = FrameAverage(
    net,
    dim = 3,           # defaults to 3 for spatial, but can be any value
    stochastic = True  # whether to use stochastic variant from FAENet (one frame sampled at random)
)

# pass your input to the network as usual
points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()

out = net(points, frame_average_mask = mask)
out.shape # (4, 1024, 3)

# frame averaging is automatically taken care of, as though the network were unwrapped
```
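The `frame_average_mask` and `stochastic` options can be illustrated with a similar NumPy sketch. Their semantics here are an assumption inferred from the argument names and the FAENet reference, not taken from the package's code: padded points (mask `False`) are excluded from the centroid and covariance, and the stochastic variant runs the network in one randomly sampled frame instead of all `2**d`. `masked_pca_frames` and `stochastic_frame_average` are hypothetical names, and this version handles a single unbatched cloud.

```python
import itertools
import numpy as np

def masked_pca_frames(points, mask):
    """Centroid and 2**d PCA frames of one padded (n, d) cloud, using only rows where mask is True."""
    valid = points[mask]
    centroid = valid.mean(axis=0)
    centered = valid - centroid
    _, eigvecs = np.linalg.eigh(centered.T @ centered)
    signs = np.array(list(itertools.product((-1.0, 1.0), repeat=points.shape[-1])))
    return centroid, eigvecs[None, :, :] * signs[:, None, :]

def stochastic_frame_average(net, points, mask, rng):
    """Run net in one randomly sampled frame instead of averaging over all of them."""
    centroid, frames = masked_pca_frames(points, mask)
    frame = frames[rng.integers(len(frames))]      # one of the 2**d frames, chosen uniformly
    out = net((points - centroid) @ frame) @ frame.T
    return np.where(mask[:, None], out, 0.0)       # zero the outputs at padded positions

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 3))
net = lambda x: np.tanh(x @ W)                     # hypothetical pointwise network
points = rng.normal(size=(10, 3)) * np.array([3.0, 2.0, 1.0])
mask = np.arange(10) < 7                           # last 3 rows are padding

out = stochastic_frame_average(net, points, mask, np.random.default_rng(1))
padded = points.copy()
padded[~mask] = 100.0                              # change only the padding
out_padded = stochastic_frame_average(net, padded, mask, np.random.default_rng(1))
print(np.allclose(out, out_padded))                # prints True: padding does not leak into valid outputs
```

Sampling one frame cuts the cost from `2**d` forward passes to one; the trade-off is that the output is equivariant only in expectation over the sampled frame, since averaging over every frame recovers the full frame-averaged result.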
Citations
@article{Puny2021FrameAF,
title = {Frame Averaging for Invariant and Equivariant Network Design},
author = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.03336},
url = {https://api.semanticscholar.org/CorpusID:238419638}
}
@article{Duval2023FAENetFA,
title = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
author = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.05577},
url = {https://api.semanticscholar.org/CorpusID:258564608}
}
File details
Details for the file frame_averaging_pytorch-0.0.14.tar.gz (source distribution).
File metadata
- Download URL: frame_averaging_pytorch-0.0.14.tar.gz
- Size: 220.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 38af1611bbb80aeb55bd5814b56b309eba52383fcf59f4a3efc0c2a09ab45622
MD5 | 6af58908eed9b5f309e0b1f0dbadd627
BLAKE2b-256 | cc5421cc4b0c4802739603ee22ab066fec1d63ce8e459138e984c13c12e123ad
File details
Details for the file frame_averaging_pytorch-0.0.14-py3-none-any.whl (built distribution).
File metadata
- Download URL: frame_averaging_pytorch-0.0.14-py3-none-any.whl
- Size: 3.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.9.19
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0b1fb411151724fecd963c7c0444216bd2b10b45912a917d25046f4e1f919b30
MD5 | 4590c467550639ee79ab92ad0d553fa7
BLAKE2b-256 | d9073a63ccec3783cca6f241f1bbfd70b24cc47f22c5a5a80fbb8216ee1fc797
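A downloaded file can be checked against the SHA256 digests listed above using Python's standard `hashlib`. This is a generic sketch: the helper names `sha256_of` and `matches_published_hash` and the `EXPECTED_SHA256` mapping are illustrative, with the digests copied from the hash tables for version 0.0.14.

```python
import hashlib
import os

# SHA256 digests listed for version 0.0.14
EXPECTED_SHA256 = {
    "frame_averaging_pytorch-0.0.14.tar.gz":
        "38af1611bbb80aeb55bd5814b56b309eba52383fcf59f4a3efc0c2a09ab45622",
    "frame_averaging_pytorch-0.0.14-py3-none-any.whl":
        "0b1fb411151724fecd963c7c0444216bd2b10b45912a917d25046f4e1f919b30",
}

def sha256_of(path, chunk_size=1 << 16):
    """Hex SHA256 digest of a file, read in chunks so it is never fully loaded into memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_published_hash(path):
    """True only if the file's name is known and its SHA256 equals the published digest."""
    expected = EXPECTED_SHA256.get(os.path.basename(path))
    return expected is not None and sha256_of(path) == expected
```

For example, after downloading the wheel, `matches_published_hash("frame_averaging_pytorch-0.0.14-py3-none-any.whl")` should return `True` if the file is intact.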