Gradient Agreement Filtering - Pytorch

Implementation of Gradient Agreement Filtering, from Chaubard et al. of Stanford, adapted to single-machine microbatches, in PyTorch.

The official repository that does filtering for macrobatches across machines is here

Install

$ pip install GAF-microbatch-pytorch

Usage

import torch

# mock network

from torch import nn

net = nn.Sequential(
    nn.Linear(512, 256),
    nn.SiLU(),
    nn.Linear(256, 128)
)

# import the gradient agreement filtering (GAF) wrapper

from GAF_microbatch_pytorch import GAFWrapper

# just wrap your neural net

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97
)

# your batch of data

x = torch.randn(16, 1024, 512)

# forward and backwards as usual

out = gaf_net(x)

out.sum().backward()

# gradients are now filtered by the set threshold, comparing per-sample gradients within the batch, as in the paper
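
To make the threshold more concrete, here is a minimal sketch of the agreement test between two gradients, measured by cosine distance. It is an illustration of the idea only, not this package's internals: the helpers `flat_grad` and `cosine_distance` are hypothetical names, the tiny network is a stand-in, and the accept rule (distance at or below the threshold) is an assumption about how `filter_distance_thres` is applied.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# small stand-in network, unrelated to the one wrapped above
net = nn.Sequential(nn.Linear(8, 4), nn.SiLU(), nn.Linear(4, 2))

def flat_grad(model, x):
    # hypothetical helper: gradient of a scalar loss w.r.t. all parameters, as one flat vector
    model.zero_grad()
    model(x).sum().backward()
    return torch.cat([p.grad.flatten() for p in model.parameters()])

def cosine_distance(a, b):
    # hypothetical helper: 1 - cosine similarity (0 = same direction, 2 = opposite)
    return 1. - F.cosine_similarity(a, b, dim = 0)

# gradients from two different samples
g1 = flat_grad(net, torch.randn(1, 8))
g2 = flat_grad(net, torch.randn(1, 8))

threshold = 0.97  # same value as filter_distance_thres in the usage above

dist = cosine_distance(g1, g2)

# assumed rule: the two gradients "agree" when their distance is within the threshold
agree = dist.item() <= threshold
```

With a threshold close to 1, only gradients that are nearly orthogonal or pointing in opposing directions would be rejected under this rule.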

You can supply your own gradient filtering method as a Callable[[Tensor], Tensor] via the filter_gradients_fn kwarg, like so:

def filtering_fn(grads):
    # make your big discovery here
    return grads
 
gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = filtering_fn
)
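
As a starting point for such a function, the sketch below keeps only the per-sample gradients that point in roughly the same direction as the batch mean and zeroes the rest. It rests on assumptions that are not confirmed here: that the tensor passed in has samples along its first dimension, and that returning a tensor of the same shape with rejected rows zeroed is acceptable to the wrapper. The name `keep_agreeing_with_mean` and the `min_cos_sim` parameter are hypothetical, so check the library source before relying on this.

```python
import torch
import torch.nn.functional as F

def keep_agreeing_with_mean(grads, min_cos_sim = 0.):
    # assumption: grads has shape (batch, ...), one gradient per sample
    batch = grads.shape[0]
    flat = grads.reshape(batch, -1)

    # direction of the average gradient across the batch
    mean = flat.mean(dim = 0, keepdim = True)

    # cosine similarity of each sample's gradient to the mean, shape (batch,)
    sim = F.cosine_similarity(flat, mean, dim = -1)

    # 1. for samples to keep, 0. for samples to reject
    keep = (sim > min_cos_sim).to(grads.dtype)

    # broadcast the mask over all trailing dimensions and zero rejected samples
    mask = keep.reshape(batch, *([1] * (grads.dim() - 1)))
    return grads * mask
```

It has the Callable[[Tensor], Tensor] shape described above, so it could be passed as `filter_gradients_fn = keep_agreeing_with_mean` in place of `filtering_fn`.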

Todo

  • replicate cifar results on single machine
  • allow for excluding certain parameters from being filtered

Citations

@inproceedings{Chaubard2024BeyondGA,
    title   = {Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
    author  = {Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:274992650}
}
