Gradient Agreement Filtering - PyTorch
Implementation of Gradient Agreement Filtering (GAF), from Chaubard et al. of Stanford, adapted to microbatches on a single machine, in PyTorch.
The official repository, which filters macrobatches across machines, is here.
Install
```bash
$ pip install GAF-microbatch-pytorch
```
Usage
```python
import torch
from torch import nn

# mock network

net = nn.Sequential(
    nn.Linear(512, 256),
    nn.SiLU(),
    nn.Linear(256, 128)
)

# import the gradient agreement filtering (GAF) wrapper

from GAF_microbatch_pytorch import GAFWrapper

# just wrap your neural net

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97
)

# your batch of data (16 samples; the last dimension, 512, matches the first Linear)

x = torch.randn(16, 1024, 512)

# forward and backward as usual

out = gaf_net(x)
out.sum().backward()

# gradients should now be filtered by the set threshold, comparing per-sample gradients within the batch, as in the paper
```
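The usage above only shows the wrapper. To make "comparing per-sample gradients within the batch" concrete, here is a minimal standalone sketch, not GAFWrapper's actual internals. It computes per-sample gradients with `torch.func` (PyTorch 2.x), measures pairwise cosine distance (1 minus cosine similarity, the agreement measure used in the paper), and averages only the samples that agree. The helper names (`sample_loss`, `per_sample_grads`, `agreement_filter`) and the exact keep rule are our own assumptions.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

# tiny stand-in network; any nn.Module should work the same way
net = nn.Sequential(nn.Linear(8, 4), nn.SiLU(), nn.Linear(4, 2))
params = {name: p.detach() for name, p in net.named_parameters()}

def sample_loss(params, x):
    # x is a single sample; add a batch dim of 1 for the module call
    out = functional_call(net, params, (x.unsqueeze(0),))
    return out.sum()

def per_sample_grads(params, xs):
    # gradient of the loss for each sample; every tensor gains a leading batch dim
    return vmap(grad(sample_loss), in_dims=(None, 0))(params, xs)

def agreement_filter(grads, thres):
    # grads: (batch, *param_shape), one gradient per sample for one parameter
    batch = grads.shape[0]
    flat = grads.reshape(batch, -1)
    normed = torch.nn.functional.normalize(flat, dim=-1)
    cos_dist = 1. - normed @ normed.t()  # (batch, batch), ranges over [0, 2]
    # simplified rule (assumption): keep a sample if its cosine distance to
    # at least one *other* sample is below the threshold
    others = ~torch.eye(batch, dtype=torch.bool)
    keep = ((cos_dist < thres) & others).any(dim=-1)
    if not keep.any():
        # no agreement at all: return a zero gradient, i.e. skip this step
        return torch.zeros_like(grads[0])
    return grads[keep].mean(dim=0)

xs = torch.randn(16, 8)
filtered = {name: agreement_filter(g, thres=0.97)
            for name, g in per_sample_grads(params, xs).items()}
```

Each entry of `filtered` has the shape of its parameter and could be written into `p.grad` before an optimizer step. Samples whose gradients point in opposing directions (cosine distance near 2) never pass a threshold below 1.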
You can supply your own gradient filtering method, a `Callable[[Tensor], Tensor]`, through the `filter_gradients_fn` keyword argument, like so:
```python
def filtering_fn(grads):
    # make your big discovery here
    return grads

gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = filtering_fn
)
```
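As a hypothetical illustration of a custom `filter_gradients_fn`, the sketch below drops per-sample gradients that point away from the batch mean. It assumes, without confirmation from this README, that `grads` holds one gradient per sample along a leading batch dimension and that the wrapper expects a tensor of the same shape back (the identity example above returns its input unchanged, which is consistent with that). The name `mean_agreement_filter` and its keep rule are our own, not the paper's criterion.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def mean_agreement_filter(grads: Tensor) -> Tensor:
    # assumption: grads is (batch, *param_shape), one gradient per sample
    batch = grads.shape[0]
    flat = grads.reshape(batch, -1)
    mean_dir = flat.mean(dim=0, keepdim=True)               # (1, numel)
    cos_sim = F.cosine_similarity(flat, mean_dir, dim=-1)   # (batch,)
    keep = (cos_sim > 0).to(grads.dtype)                    # 1. keep, 0. drop
    # broadcast the per-sample mask back over the parameter dimensions
    return grads * keep.view(batch, *([1] * (grads.dim() - 1)))
```

It would be passed the same way as the identity example: `GAFWrapper(net = net, filter_gradients_fn = mean_agreement_filter)`. Zeroing rejected samples keeps the output shape unchanged; whether the wrapper then averages over the batch dimension is not stated here.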
Todo
- replicate the CIFAR results on a single machine
- allow for excluding certain parameters from being filtered
Citations
```bibtex
@inproceedings{Chaubard2024BeyondGA,
    title  = {Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
    author = {Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:274992650}
}
```
Download files
File details
Details for the file gaf_microbatch_pytorch-0.0.5.tar.gz.
File metadata
- Download URL: gaf_microbatch_pytorch-0.0.5.tar.gz
- Upload date:
- Size: 142.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b992d740e13e0e3e6c748fd1c68d9fe2c23476daf944035de01d2e8c181ac04c |
| MD5 | bc2021d24b8c11333f1f76c83f160fae |
| BLAKE2b-256 | a622f712b70479414f3518d120582bcfef38959776c4ac9d00bf790afc5281ee |
File details
Details for the file gaf_microbatch_pytorch-0.0.5-py3-none-any.whl.
File metadata
- Download URL: gaf_microbatch_pytorch-0.0.5-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.9.21
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2d55554f27f2a7a925d9ecc681a4b8484c9d5eb6e39d81b0d25c54910ad9f49e |
| MD5 | a28224e1d4113b8709e9bd4db605c59b |
| BLAKE2b-256 | 034f440dfa1dd529e81f4937dfe0986bed7fe3e1116df175d685dbfec4a11a56 |
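To check a downloaded file against the SHA256 digests listed above, here is a small standard-library sketch using `hashlib`. The function names are our own; `pip hash <file>` computes the same SHA256 digest from the command line.

```python
import hashlib
from pathlib import Path

# SHA256 digests published for version 0.0.5
EXPECTED_SHA256 = {
    "gaf_microbatch_pytorch-0.0.5.tar.gz":
        "b992d740e13e0e3e6c748fd1c68d9fe2c23476daf944035de01d2e8c181ac04c",
    "gaf_microbatch_pytorch-0.0.5-py3-none-any.whl":
        "2d55554f27f2a7a925d9ecc681a4b8484c9d5eb6e39d81b0d25c54910ad9f49e",
}

def sha256_of(path, chunk_size=1 << 16):
    # read in chunks so large archives are never fully loaded into memory
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_sha256(path, expected):
    # hexdigest() is lowercase, so normalize the expected value too
    return sha256_of(path) == expected.strip().lower()
```

For an intact download, `verify_sha256("gaf_microbatch_pytorch-0.0.5.tar.gz", EXPECTED_SHA256["gaf_microbatch_pytorch-0.0.5.tar.gz"])` should return `True`.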