
Frame Averaging


Frame Averaging - Pytorch (wip)

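PyTorch implementation of a simple way to enable (stochastic) frame averaging for any network. Frame averaging (Puny et al., 2021) makes a network invariant or equivariant to a symmetry group, such as rotations, reflections and translations of a point cloud, by averaging its outputs over a small, input-dependent set of frames. The stochastic variant from FAENet (Duval et al., 2023) instead uses a single frame sampled at random.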
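For reference, the frame-averaging operator from Puny et al. (2021) is written out below, together with the PCA frames the paper uses for 3D point clouds under E(3). The notation ($\Phi$ for the wrapped network, $\rho_1$, $\rho_2$ for the input and output group actions, $g = (R, t)$) follows the paper. This package may organize the computation differently.

```latex
% Frame averaging of a network \Phi over the frame \mathcal{F}(X) \subseteq G
\[
\langle \Phi \rangle_{\mathcal{F}}(X) = \frac{1}{|\mathcal{F}(X)|} \sum_{g \in \mathcal{F}(X)} \rho_2(g)\, \Phi\big(\rho_1(g)^{-1} X\big)
\]

% E(3) action on a point cloud X \in \mathbb{R}^{n \times 3}, with g = (R, t)
\[
\rho_1(g) X = X R^\top + \mathbf{1} t^\top, \qquad \rho_1(g)^{-1} X = (X - \mathbf{1} t^\top) R
\]

% PCA frames: t is the centroid, v_1, v_2, v_3 are eigenvectors of the covariance C
\[
t = \tfrac{1}{n} X^\top \mathbf{1}, \qquad C = (X - \mathbf{1} t^\top)^\top (X - \mathbf{1} t^\top)
\]
\[
\mathcal{F}(X) = \big\{ \big( [\alpha_1 v_1,\ \alpha_2 v_2,\ \alpha_3 v_3],\ t \big) : \alpha_i \in \{-1, 1\} \big\}, \qquad |\mathcal{F}(X)| = 2^3 = 8
\]
```

In the stochastic variant, the average over $\mathcal{F}(X)$ is replaced by a single $g$ drawn at random from $\mathcal{F}(X)$, so the network runs once instead of eight times.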

Install

$ pip install frame-averaging-pytorch

Usage

import torch
from frame_averaging_pytorch import FrameAverage

# contrived neural network

net = torch.nn.Linear(3, 3)

# wrap the network with FrameAverage

net = FrameAverage(
    net,
    dim = 3,           # defaults to 3 for spatial, but can be any value
    stochastic = True  # whether to use stochastic variant from FAENet (one frame sampled at random)
)

# pass your input to the network as usual

points = torch.randn(2, 4, 1024, 3)

out = net(points)

out.shape # (2, 4, 1024, 3)

# frame averaging is automatically taken care of, as though the network were unwrapped
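To make the wrapper's behaviour concrete, here is a minimal sketch of full and stochastic frame averaging for a single point cloud of shape (n, d). It is written in NumPy rather than PyTorch so it runs stand-alone. `pca_frames` and `frame_average` are hypothetical names, not part of the frame-averaging-pytorch API, and the sketch follows the PCA frame construction of Puny et al. rather than the package's source. It assumes the network maps points to points, so the centroid is added back to the output. Vector outputs such as forces would skip that step.

```python
import numpy as np
from itertools import product


def pca_frames(points):
    """Hypothetical helper (not the frame-averaging-pytorch API).

    Returns the centroid of an (n, d) point cloud and its 2**d PCA frames:
    orthogonal (d, d) matrices whose columns are the covariance eigenvectors,
    taken with every combination of sign flips.
    """
    centroid = points.mean(axis=0)
    centered = points - centroid
    _, eigvecs = np.linalg.eigh(centered.T @ centered)  # columns are eigenvectors
    d = points.shape[1]
    # multiply column j by signs[j] for every sign pattern in {-1, +1}^d
    frames = [eigvecs * np.array(signs) for signs in product([-1.0, 1.0], repeat=d)]
    return centroid, np.stack(frames)


def frame_average(fn, points, stochastic=False, rng=None):
    """Hypothetical helper: frame-average fn over the PCA frames of points.

    With stochastic=True, one frame is sampled at random (FAENet-style)
    instead of averaging over all 2**d frames. Assumes fn maps an (n, d)
    array of points to an (n, d) array of points.
    """
    centroid, frames = pca_frames(points)
    if stochastic:
        if rng is None:
            rng = np.random.default_rng()
        frames = frames[rng.integers(len(frames))][None]  # keep one frame, shape (1, d, d)
    outputs = []
    for R in frames:
        canonical = (points - centroid) @ R   # express the points in the frame's coordinates
        out = fn(canonical)                   # run the unmodified network
        outputs.append(out @ R.T + centroid)  # map the output back to the input's coordinates
    return np.mean(outputs, axis=0)


# Toy usage: a fixed nonlinear map stands in for the wrapped network.
rng = np.random.default_rng(0)
points = rng.normal(size=(50, 3)) * np.array([3.0, 2.0, 1.0])  # well-separated PCA axes
W = rng.normal(size=(3, 3))


def net(p):
    return np.tanh(p @ W)


Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix (rotation or reflection)
t = rng.normal(size=3)                         # random translation

out = frame_average(net, points)
out_moved = frame_average(net, points @ Q.T + t)
print(np.allclose(out_moved, out @ Q.T + t))   # prints True: full averaging is E(3)-equivariant
```

Full averaging runs the network once per frame (2^d = 8 times for 3D points). The stochastic variant trades that cost for a single forward pass. In this sketch, `stochastic=True` makes the output equivariant only in distribution, because which sign-flipped frame is drawn is random.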

Citations

@article{Puny2021FrameAF,
    title   = {Frame Averaging for Invariant and Equivariant Network Design},
    author  = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.03336},
    url     = {https://api.semanticscholar.org/CorpusID:238419638}
}
@article{Duval2023FAENetFA,
    title   = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
    author  = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.05577},
    url     = {https://api.semanticscholar.org/CorpusID:258564608}
}


