Frame Averaging - Pytorch (wip)
A PyTorch implementation of a simple way to enable (Stochastic) Frame Averaging for any network.
Install
$ pip install frame-averaging-pytorch
Usage
import torch
from frame_averaging_pytorch import FrameAverage
# contrived neural network
net = torch.nn.Linear(3, 3)
# wrap the network with FrameAverage
net = FrameAverage(net)
# pass your input to the network as usual
points = torch.randn(2, 4, 1024, 3)
out = net(points)
out.shape # (2, 4, 1024, 3)
# frame averaging is automatically taken care of, as though the network were unwrapped
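For readers who want to see what such a wrapper does conceptually, here is a minimal sketch of equivariant frame averaging over PCA frames, following the general construction in the Puny et al. paper cited below. It is not this library's implementation: the helper name `frame_average_sketch`, the restriction to a single unbatched point cloud, and the choice to enumerate all 8 sign-flip frames are assumptions made for illustration.

```python
import itertools
import torch

def frame_average_sketch(fn, points):
    # Hypothetical helper, not part of frame_averaging_pytorch.
    # points: (n, 3) point cloud; fn: any map from (n, 3) to (n, 3).
    centroid = points.mean(dim=0, keepdim=True)
    centered = points - centroid

    # Principal axes of the cloud: the columns of eigvecs are orthonormal
    # eigenvectors of the 3x3 covariance (assumes distinct eigenvalues).
    cov = centered.T @ centered
    _, eigvecs = torch.linalg.eigh(cov)

    outputs = []
    # Each eigenvector is only defined up to sign, giving 2^3 = 8 frames.
    for signs in itertools.product((1.0, -1.0), repeat=3):
        frame = eigvecs * torch.tensor(signs, dtype=points.dtype)  # flip columns
        local = centered @ frame              # coordinates in this frame
        out = fn(local) @ frame.T + centroid  # map the output back
        outputs.append(out)

    # Average the back-projected outputs over all frames.
    return torch.stack(outputs).mean(dim=0)

torch.manual_seed(0)

# A small network that is not equivariant on its own.
net = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 3)
).double()

points = torch.randn(32, 3, dtype=torch.float64)
# Random orthogonal matrix (a rotation or a reflection) and a random translation.
ortho, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
shift = torch.randn(1, 3, dtype=torch.float64)

with torch.no_grad():
    out = frame_average_sketch(net, points)
    out_moved = frame_average_sketch(net, points @ ortho + shift)

# Equivariance check: transforming the input should transform the output the same way.
max_err = (out_moved - (out @ ortho + shift)).abs().max().item()
print(max_err)
```

When the covariance eigenvalues are distinct, `max_err` should sit at floating-point round-off even though `net` itself is not equivariant. A stochastic variant, in the sense of the FAENet paper cited below, would evaluate one randomly sampled frame per call instead of all eight.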
Citations
@article{Puny2021FrameAF,
title = {Frame Averaging for Invariant and Equivariant Network Design},
author = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.03336},
url = {https://api.semanticscholar.org/CorpusID:238419638}
}
@article{Duval2023FAENetFA,
title = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
author = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.05577},
url = {https://api.semanticscholar.org/CorpusID:258564608}
}