
Frame Averaging


Frame Averaging - Pytorch (wip)

A PyTorch implementation of a simple wrapper that enables Frame Averaging (Puny et al., 2021), or its stochastic variant from FAENet (Duval et al., 2023), for any network.
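For reference, Frame Averaging (Puny et al., 2021) makes an arbitrary network $\Phi$ invariant or equivariant by averaging over a small, input-dependent frame $\mathcal{F}(X)$ of group elements, with $\rho_1$ and $\rho_2$ the group actions on inputs and outputs. For point clouds under E(d), the paper builds the frame from PCA: $t$ is the centroid and $R$ stacks the eigenvectors $v_1, \dots, v_d$ of the covariance $C$ with every choice of signs, so $|\mathcal{F}(X)| = 2^d$ when the eigenvalues are distinct. The $\rho_2$ shown here assumes outputs that transform like the input points (as in the usage example, where output and input shapes match); for invariant outputs $\rho_2$ is the identity.

```latex
\langle \Phi \rangle_{\mathcal{F}}(X) = \frac{1}{|\mathcal{F}(X)|} \sum_{g \in \mathcal{F}(X)} \rho_2(g)\, \Phi\!\left(\rho_1(g)^{-1} X\right)

% E(d) frames for a point cloud X \in \mathbb{R}^{n \times d} (points as rows)
t = \tfrac{1}{n} X^\top \mathbf{1}, \qquad
C = (X - \mathbf{1} t^\top)^\top (X - \mathbf{1} t^\top), \qquad
\mathcal{F}(X) = \left\{ (R, t) : R = [\pm v_1, \dots, \pm v_d] \right\}

\rho_1(g)^{-1} X = (X - \mathbf{1} t^\top)\, R, \qquad
\rho_2(g)\, Y = Y R^\top + \mathbf{1} t^\top
```

The stochastic variant from FAENet (Duval et al., 2023) replaces the sum with a single frame drawn at random, which is what `stochastic = True` selects in the usage example.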

Install

$ pip install frame-averaging-pytorch

Usage

import torch
from frame_averaging_pytorch import FrameAverage

# contrived neural network

net = torch.nn.Linear(3, 3)

# wrap the network with FrameAverage

net = FrameAverage(
    net,
    dim = 3,           # defaults to 3 for spatial, but can be any value
    stochastic = True  # whether to use stochastic variant from FAENet (one frame sampled at random)
)

# pass your input to the network as usual

points = torch.randn(2, 4, 1024, 3)

out = net(points)

out.shape # (2, 4, 1024, 3)

# frame averaging is handled automatically; call the wrapped network exactly as you would the original one
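To make the averaging concrete, here is a minimal NumPy sketch of the same idea for a single (n, d) point cloud, assuming PCA frames with distinct covariance eigenvalues and a network whose output transforms like its input. `pca_frames`, `frame_average`, and the toy `net` are hypothetical helpers written for illustration; they are not the internals or API of `frame_averaging_pytorch`, and batching over leading dimensions (such as the `(2, 4)` in the example above) is omitted.

```python
import itertools

import numpy as np


def pca_frames(x):
    """Centroid and all 2**d sign-flipped PCA frames of an (n, d) point cloud.

    Assumes the covariance has distinct eigenvalues (true for generic inputs).
    """
    t = x.mean(axis=0)
    centered = x - t
    _, v = np.linalg.eigh(centered.T @ centered)  # columns of v are eigenvectors
    d = x.shape[1]
    signs = np.array(list(itertools.product((1.0, -1.0), repeat=d)))  # (2**d, d)
    frames = v[None, :, :] * signs[:, None, :]  # scale column j by sign j
    return t, frames  # shapes (d,), (2**d, d, d)


def frame_average(net, x, stochastic=False, rng=None):
    """Hypothetical frame-averaged forward pass for an (n, d) -> (n, d) network.

    Canonicalize the input with each frame, run the net, map the output back,
    then average. With stochastic=True only one randomly sampled frame is used.
    """
    t, frames = pca_frames(x)
    if stochastic:
        if rng is None:
            rng = np.random.default_rng()
        frames = frames[rng.integers(len(frames))][None]  # keep a single frame
    outs = [net((x - t) @ r) @ r.T + t for r in frames]
    return np.mean(outs, axis=0)


# Toy network that is deliberately NOT rotation or translation equivariant.
rng = np.random.default_rng(0)
w = rng.normal(size=(3, 3))


def net(p):
    return np.tanh(p @ w)


x = rng.normal(size=(64, 3))
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
s = rng.normal(size=3)  # random translation

lhs = frame_average(net, x @ q.T + s)  # transform the input, then average
rhs = frame_average(net, x) @ q.T + s  # average, then transform the output
print(np.allclose(lhs, rhs))  # expected: True
```

The full average passes this check because rotating and translating the input carries every frame along with it, so the set of canonicalized inputs seen by `net` is unchanged. With `stochastic=True` only one sampled frame is used, so the output is equivariant only in expectation over the sampled frame, in exchange for a single forward pass instead of 2**d.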

Citations

@article{Puny2021FrameAF,
    title   = {Frame Averaging for Invariant and Equivariant Network Design},
    author  = {Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.03336},
    url     = {https://api.semanticscholar.org/CorpusID:238419638}
}

@article{Duval2023FAENetFA,
    title   = {FAENet: Frame Averaging Equivariant GNN for Materials Modeling},
    author  = {Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.05577},
    url     = {https://api.semanticscholar.org/CorpusID:258564608}
}



Download files


Source distribution: frame_averaging_pytorch-0.0.4.tar.gz (220.0 kB)

Built distribution (Python 3): frame_averaging_pytorch-0.0.4-py3-none-any.whl (3.6 kB)

File details

Hashes for frame_averaging_pytorch-0.0.4.tar.gz

Algorithm     Hash digest
SHA256        f149df14aa8346dfe372dfb0222bcd7ac28beee91c5b74d6f365fb26e929c123
MD5           fade5ea01ae775c517f5d5cd7de84b85
BLAKE2b-256   53906fb508f7e6cc4fcb67fa62361510d968dda9fc32caeb11ded45f0f2c3dda


Hashes for frame_averaging_pytorch-0.0.4-py3-none-any.whl

Algorithm     Hash digest
SHA256        3536a80e5aefc1c3f09dd569cb8268fe3e44d02e9333617d7bf4a5997d49ab8d
MD5           3034fa95a627a835f411812df0a398b4
BLAKE2b-256   7807a360342fea0f5b6fb762fd4696b90c178b01ddbe2d11d46410e1ab994c9a

To verify a download, compare a locally computed SHA256 digest of the file with the value listed above.
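A minimal sketch of that check using only Python's standard `hashlib`; `sha256_of`, `verify`, and `EXPECTED` are illustrative names, and the downloaded file is assumed to keep its original PyPI file name.

```python
import hashlib
from pathlib import Path

# SHA256 digests published for release 0.0.4 (copied from the tables above).
EXPECTED = {
    "frame_averaging_pytorch-0.0.4.tar.gz": "f149df14aa8346dfe372dfb0222bcd7ac28beee91c5b74d6f365fb26e929c123",
    "frame_averaging_pytorch-0.0.4-py3-none-any.whl": "3536a80e5aefc1c3f09dd569cb8268fe3e44d02e9333617d7bf4a5997d49ab8d",
}


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path):
    """True if the file's SHA256 matches the digest published for its name."""
    path = Path(path)
    return sha256_of(path) == EXPECTED[path.name]


# Usage, after downloading the wheel into the current directory:
# print(verify("frame_averaging_pytorch-0.0.4-py3-none-any.whl"))
```

Alternatively, pip's hash-checking mode (`--hash=sha256:...` entries in a requirements file) performs the same comparison during installation.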
