
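Frame Averaging - PyTorch (wip)

PyTorch implementation of a simple way to enable Frame Averaging (Puny et al., 2021), or its stochastic variant from FAENet (Duval et al., 2023), for any network. Frame averaging makes an arbitrary network invariant or equivariant to a symmetry group by averaging the network over a small, input-dependent set of group transformations (a frame). This technique was recently adopted by Prescient Design in AbDiffuser.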
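For intuition, here is a minimal sketch of deterministic frame averaging for point clouds, using the PCA-based frames described by Puny et al. It centers each cloud, takes the principal axes of its covariance, and enumerates all 2^d sign choices for those axes (8 frames in 3D). The network runs once per frame on the canonicalized points, each prediction is rotated back and re-translated, and the results are averaged, which makes the wrapper equivariant to rotations, reflections, and translations. This is a standalone illustration, not this package's code: pca_frames and frame_average are hypothetical names, and the sketch assumes the covariance has distinct eigenvalues and that the wrapped network maps (batch, n, d) to (batch, n, d).

```python
import itertools

import torch
from torch import nn

# Illustrative sketch only, not the frame_averaging_pytorch implementation.


def pca_frames(points: torch.Tensor):
    """Return the 2^d PCA frames and the centroid of each point cloud.

    points: (batch, n, d). Assumes the covariance has distinct eigenvalues.
    """
    centroid = points.mean(dim=1, keepdim=True)              # (b, 1, d)
    centered = points - centroid
    cov = centered.transpose(1, 2) @ centered                 # (b, d, d)
    _, eigvecs = torch.linalg.eigh(cov)                       # columns = principal axes
    d = points.shape[-1]
    # every combination of +/-1 signs, one per principal axis: (2^d, d)
    signs = torch.tensor(
        list(itertools.product([-1.0, 1.0], repeat=d)),
        dtype=points.dtype,
        device=points.device,
    )
    # frames[b, k] = U_b @ diag(signs[k]): column j of U_b flipped by signs[k, j]
    frames = eigvecs.unsqueeze(1) * signs[None, :, None, :]   # (b, 2^d, d, d)
    return frames, centroid


def frame_average(net: nn.Module, points: torch.Tensor) -> torch.Tensor:
    """Average net over all PCA frames so the output is E(d)-equivariant."""
    frames, centroid = pca_frames(points)
    # express the centered cloud in every frame: (b, 2^d, n, d)
    canon = (points - centroid).unsqueeze(1) @ frames
    b, k, n, d = canon.shape
    out = net(canon.reshape(b * k, n, d)).reshape(b, k, n, d)
    # rotate each prediction back to the input frame and restore the translation
    out = out @ frames.transpose(-1, -2) + centroid.unsqueeze(1)
    return out.mean(dim=1)


# usage: wrap the same contrived network as in the README
net = nn.Linear(3, 3)
points = torch.randn(4, 1024, 3)
out = frame_average(net, points)  # shape (4, 1024, 3)
```

The cost is 2^d forward passes per input, which is exactly what the stochastic variant avoids by sampling a single frame.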

Install

$ pip install frame-averaging-pytorch

Usage

import torch
from frame_averaging_pytorch import FrameAverage

# contrived neural network

net = torch.nn.Linear(3, 3)

# wrap the network with FrameAverage

net = FrameAverage(
    net,
    dim = 3,           # defaults to 3 for spatial, but can be any value
    stochastic = True  # whether to use stochastic variant from FAENet (one frame sampled at random)
)

# pass your input to the network as usual

points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()

out = net(points, frame_average_mask = mask)

out.shape # (4, 1024, 3)

# frame averaging is automatically taken care of, as though the network were unwrapped
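The usage example passes a boolean frame_average_mask alongside the points. One plausible reading, not confirmed against the package source, is that it marks real points versus padding so the frame is computed from valid points only. The sketch below illustrates that idea together with the FAENet-style stochastic variant, which samples one of the 2^d frames per example instead of averaging over all of them. masked_stochastic_frame and stochastic_frame_average are hypothetical helpers, and zeroing the padded outputs is an assumption of this sketch.

```python
import torch
from torch import nn

# Assumption: mask is True for real points and False for padding.


def masked_stochastic_frame(points, mask, generator=None):
    """Sample one PCA frame per example, using only points where mask is True.

    points: (b, n, d) float tensor; mask: (b, n) bool tensor.
    """
    m = mask.unsqueeze(-1).to(points.dtype)                    # (b, n, 1)
    count = m.sum(dim=1, keepdim=True).clamp(min=1.0)          # (b, 1, 1)
    centroid = (points * m).sum(dim=1, keepdim=True) / count   # masked mean
    centered = (points - centroid) * m                         # padded rows -> 0
    cov = centered.transpose(1, 2) @ centered                  # (b, d, d)
    _, eigvecs = torch.linalg.eigh(cov)
    b, d = points.shape[0], points.shape[-1]
    # one random +/-1 per principal axis selects one of the 2^d frames
    signs = torch.randint(0, 2, (b, 1, d), generator=generator, device=points.device)
    signs = signs.to(points.dtype) * 2 - 1
    return eigvecs * signs, centroid                           # columns sign-flipped


def stochastic_frame_average(net, points, mask, generator=None):
    """Single-frame averaging: equivariant only in expectation over the sampled frame."""
    frame, centroid = masked_stochastic_frame(points, mask, generator)
    canon = (points - centroid) @ frame
    out = net(canon) @ frame.transpose(1, 2) + centroid
    # assumption: padded positions are zeroed in the output
    return out.masked_fill(~mask.unsqueeze(-1), 0.0)


# usage: the last 24 points of every cloud are padding
net = nn.Linear(3, 3)
points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024, dtype=torch.bool)
mask[:, 1000:] = False
out = stochastic_frame_average(net, points, mask)  # shape (4, 1024, 3)
```

Each call costs one forward pass instead of 2^d, at the price of the output varying with the sampled frame.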

Citations

@article{Puny2021FrameAF,
    title   = {Frame Averaging for Invariant and Equivariant Network Design},
    author  = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.03336},
    url     = {https://api.semanticscholar.org/CorpusID:238419638}
}
@article{Duval2023FAENetFA,
    title   = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
    author  = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.05577},
    url     = {https://api.semanticscholar.org/CorpusID:258564608}
}
