
Frame Averaging


Frame Averaging - Pytorch (wip)

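PyTorch implementation of a simple way to enable (Stochastic) Frame Averaging for any network, making it invariant or equivariant to a symmetry group of its input (Puny et al., 2021). The stochastic variant follows FAENet (Duval et al., 2023). This technique was recently adopted by Prescient Design in AbDiffuser.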
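To make the idea concrete, here is a minimal, dependency-free sketch of frame averaging for 2D points. It assumes the PCA-based frames Puny et al. describe for rotations, reflections and translations: center the cloud, take its principal axes, and enumerate every sign flip of those axes (2^dim frames), then run the network once per frame, map each output back, and average. This is not the library's implementation; `pca_frames`, `frame_average`, `toy_net` and `rigid_motion` are hypothetical names, and the closed-form eigenvector only holds for dim = 2 with distinct eigenvalues.

```python
import itertools
import math


def pca_frames(points):
    """Hypothetical helper: centroid plus the 2**2 sign-flipped principal-axis frames (2D only)."""
    n = len(points)
    cx = sum(x for x, _ in points) / n
    cy = sum(y for _, y in points) / n
    # entries of the 2x2 covariance matrix [[a, b], [b, c]] of the centered cloud
    a = sum((x - cx) ** 2 for x, _ in points) / n
    b = sum((x - cx) * (y - cy) for x, y in points) / n
    c = sum((y - cy) ** 2 for _, y in points) / n
    # closed-form direction of the eigenvector with the largest eigenvalue
    theta = 0.5 * math.atan2(2 * b, a - c)
    e1 = (math.cos(theta), math.sin(theta))
    e2 = (-math.sin(theta), math.cos(theta))
    # every sign choice for the two axes gives one frame (4 frames in 2D)
    frames = [
        ((s1 * e1[0], s1 * e1[1]), (s2 * e2[0], s2 * e2[1]))
        for s1, s2 in itertools.product((1.0, -1.0), repeat=2)
    ]
    return (cx, cy), frames


def frame_average(net, points):
    """Canonicalize in each frame, run `net`, map outputs back, and average."""
    (cx, cy), frames = pca_frames(points)
    total = [[0.0, 0.0] for _ in points]
    for u, v in frames:
        # coordinates of each centered point along the frame axes u and v
        local = [
            ((x - cx) * u[0] + (y - cy) * u[1], (x - cx) * v[0] + (y - cy) * v[1])
            for x, y in points
        ]
        for acc, (p, q) in zip(total, net(local)):
            # back to world coordinates: p * u + q * v + centroid
            acc[0] += p * u[0] + q * v[0] + cx
            acc[1] += p * u[1] + q * v[1] + cy
    k = len(frames)
    return [(sx / k, sy / k) for sx, sy in total]


def toy_net(points):
    """A deliberately non-equivariant stand-in for a neural network."""
    return [(math.tanh(x) + 0.3 * y, x * y) for x, y in points]


def rigid_motion(points, phi, shift):
    """Rotate every point by angle phi, then translate by shift."""
    cs, sn = math.cos(phi), math.sin(phi)
    return [(cs * x - sn * y + shift[0], sn * x + cs * y + shift[1]) for x, y in points]


pts = [(0.0, 0.0), (2.0, 0.5), (1.0, 3.0), (-1.5, 1.0), (0.5, -2.0)]
moved = rigid_motion(pts, 0.7, (3.0, -1.0))
lhs = frame_average(toy_net, moved)                                # transform, then network
rhs = rigid_motion(frame_average(toy_net, pts), 0.7, (3.0, -1.0))  # network, then transform
print(max(abs(a - b) for p, q in zip(lhs, rhs) for a, b in zip(p, q)))  # floating-point round-off only
```

Because the set of frames moves together with the input, `lhs` and `rhs` agree up to round-off even though `toy_net` on its own is not equivariant.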

Install

$ pip install frame-averaging-pytorch

Usage

import torch
from frame_averaging_pytorch import FrameAverage

# contrived neural network

net = torch.nn.Linear(3, 3)

# wrap the network with FrameAverage

net = FrameAverage(
    net,
    dim = 3,           # defaults to 3 for spatial, but can be any value
    stochastic = True  # whether to use stochastic variant from FAENet (one frame sampled at random)
)

# pass your input to the network as usual

points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()

out = net(points, frame_average_mask = mask)

out.shape # (4, 1024, 3)

# frame averaging is automatically taken care of, as though the network were unwrapped
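With `stochastic = True`, one frame is sampled at random instead of averaging over all of them, following FAENet. The pure-Python sketch below contrasts the two modes. For brevity it assumes the points are already centered and aligned with their principal axes, so each frame reduces to a sign flip per axis; `sign_frames`, `run_in_frame`, `full_average`, `stochastic_average` and `toy_net` are hypothetical names, not the library's API.

```python
import itertools
import random


def sign_frames(dim):
    """All 2**dim sign patterns; each one stands for a frame (an axis flip)."""
    return list(itertools.product((1.0, -1.0), repeat=dim))


def run_in_frame(net, points, signs):
    """Flip axes into the frame, apply net, flip the outputs back.

    Assumption: points are already centered and expressed in their principal
    axes, so changing frame reduces to per-axis sign flips.
    """
    local = [tuple(s * x for s, x in zip(signs, p)) for p in points]
    return [tuple(s * y for s, y in zip(signs, q)) for q in net(local)]


def full_average(net, points):
    """Exact frame averaging: one forward pass per frame (2**dim passes)."""
    outs = [run_in_frame(net, points, s) for s in sign_frames(len(points[0]))]
    # zip over frames per point, then over coordinates, and take the mean
    return [tuple(sum(c) / len(outs) for c in zip(*rows)) for rows in zip(*outs)]


def stochastic_average(net, points, rng):
    """Stochastic variant: a single forward pass in one uniformly sampled frame."""
    signs = rng.choice(sign_frames(len(points[0])))
    return run_in_frame(net, points, signs)


def toy_net(points):
    """Non-equivariant per-point map standing in for a network."""
    return [(x + y * y, y - z, x * z) for x, y, z in points]


rng = random.Random(0)
pts = [(1.0, 2.0, -0.5), (-0.3, 0.4, 1.2)]
exact = full_average(toy_net, pts)              # 8 forward passes in 3D
sample = stochastic_average(toy_net, pts, rng)  # 1 forward pass, one of the 8 frame outputs
```

The trade-off is variance for compute: the stochastic output matches the exact average only in expectation over the random frame choice, but it needs a single forward pass instead of 2^dim.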

Alternatively, you can carry out frame averaging manually, without handing the network to FrameAverage:

import torch
from frame_averaging_pytorch import FrameAverage

# contrived neural network

net = torch.nn.Linear(3, 3)

# frame average module without passing in network

fa = FrameAverage()

# pass the 3d points and mask to FrameAverage forward

points = torch.randn(4, 1024, 3)
mask = torch.ones(4, 1024).bool()

framed_inputs, frame_average_fn = fa(points, frame_average_mask = mask)

# network forward

net_out = net(framed_inputs)

# frame average

frame_averaged = frame_average_fn(net_out)

frame_averaged.shape # (4, 1024, 3)
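The manual path splits the work into two pieces: canonicalized inputs, one per frame, and a function that maps the network's per-frame outputs back and averages them. The sketch below mirrors that shape in pure Python for 2D points. It assumes the mask only excludes padding from the frame statistics (centroid and covariance) and that frames are PCA axes with every sign flip; the library's exact tensor layout is not shown here, so this version keeps one list per frame. `frame_inputs`, `average_fn` and `net` are hypothetical names.

```python
import itertools
import math


def frame_inputs(points, mask):
    """Hypothetical 2D analogue of `fa(points, frame_average_mask = mask)`.

    Returns one canonicalized copy of the points per frame, plus a function that
    maps per-frame outputs back to world coordinates and averages them.
    Assumption: frame statistics use only points where mask is True.
    """
    kept = [p for p, m in zip(points, mask) if m]
    n = len(kept)
    cx = sum(x for x, _ in kept) / n
    cy = sum(y for _, y in kept) / n
    a = sum((x - cx) ** 2 for x, _ in kept) / n
    b = sum((x - cx) * (y - cy) for x, y in kept) / n
    c = sum((y - cy) ** 2 for _, y in kept) / n
    theta = 0.5 * math.atan2(2 * b, a - c)  # principal axis of the masked cloud
    e1, e2 = (math.cos(theta), math.sin(theta)), (-math.sin(theta), math.cos(theta))
    frames = [((s1 * e1[0], s1 * e1[1]), (s2 * e2[0], s2 * e2[1]))
              for s1, s2 in itertools.product((1.0, -1.0), repeat=2)]
    # every point (padding included) is expressed in each frame's axes
    framed = [[((x - cx) * u[0] + (y - cy) * u[1], (x - cx) * v[0] + (y - cy) * v[1])
               for x, y in points] for u, v in frames]

    def average_fn(outputs):
        """outputs[f][i] is the network output for point i under frame f."""
        k = len(frames)
        result = []
        for i in range(len(points)):
            sx = sum(o[i][0] * u[0] + o[i][1] * v[0] for o, (u, v) in zip(outputs, frames))
            sy = sum(o[i][0] * u[1] + o[i][1] * v[1] for o, (u, v) in zip(outputs, frames))
            result.append((sx / k + cx, sy / k + cy))
        return result

    return framed, average_fn


def net(points):
    """Any per-point map standing in for the network."""
    return [(2 * x + math.sin(y), x - y) for x, y in points]


real = [(0.0, 0.0), (2.0, 0.5), (1.0, 3.0), (-1.5, 1.0)]
padded = real + [(0.0, 0.0), (0.0, 0.0)]  # two padding rows
mask = [True] * len(real) + [False] * 2

framed, average_fn = frame_inputs(padded, mask)
out = average_fn([net(f) for f in framed])  # one network call per frame
```

Because the frame is computed from the masked points only, the outputs for the real points match a run on the unpadded cloud; the padded rows still produce values, which a caller would ignore.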

Citations

@article{Puny2021FrameAF,
    title   = {Frame Averaging for Invariant and Equivariant Network Design},
    author  = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.03336},
    url     = {https://api.semanticscholar.org/CorpusID:238419638}
}
@article{Duval2023FAENetFA,
    title   = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
    author  = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.05577},
    url     = {https://api.semanticscholar.org/CorpusID:258564608}
}



Download files

Source distribution: frame_averaging_pytorch-0.0.19.tar.gz (220.6 kB)

Built distribution: frame_averaging_pytorch-0.0.19-py3-none-any.whl

Hashes for frame_averaging_pytorch-0.0.19.tar.gz

Algorithm     Hash digest
SHA256        35191fb5ff8b25cac5630561c4cf0fc852e277a741f4ea174f10cc91634af933
MD5           684d4e7af96001479b662ea3e9f3b448
BLAKE2b-256   4ad67e9cd83077ad81b0bb7408491dc3ed31253fdc9f8a93cae92ca54c086594


Hashes for frame_averaging_pytorch-0.0.19-py3-none-any.whl

Algorithm     Hash digest
SHA256        3acfa731a2c8118ddbdb3acb681cf4737fce3c1eb00f66e1a7a9c0ee87d364d8
MD5           34fdab3c025ee4d77fe66919cf0c8763
BLAKE2b-256   36a6e70e34e9bbfa894c794dfbebf44200b234263b37d93cef5c6669f21759e7
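To check a downloaded file against the SHA256 digests listed above, Python's standard `hashlib` is enough; the sketch below streams the file in chunks so large archives need not fit in memory. The helper names `sha256_of` and `verify`, and the assumption that the file keeps its published name on disk, are illustrative only. pip can also enforce digests itself in hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib
import os


def sha256_of(path, chunk_size=1 << 16):
    """Hypothetical helper: SHA256 hex digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digests copied from the tables above
EXPECTED = {
    "frame_averaging_pytorch-0.0.19.tar.gz":
        "35191fb5ff8b25cac5630561c4cf0fc852e277a741f4ea174f10cc91634af933",
    "frame_averaging_pytorch-0.0.19-py3-none-any.whl":
        "3acfa731a2c8118ddbdb3acb681cf4737fce3c1eb00f66e1a7a9c0ee87d364d8",
}


def verify(path):
    """True if the file matches the published SHA256 for its (assumed unchanged) file name."""
    return sha256_of(path) == EXPECTED[os.path.basename(path)]


# usage, assuming the sdist was saved to the working directory:
# verify("frame_averaging_pytorch-0.0.19.tar.gz")
```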

