Implementation of Gradient Agreement Filtering (GAF), from Chaubard et al. of Stanford, but done for single-machine microbatches, in PyTorch.

The official repository, which does the filtering for macrobatches across machines, is here
```bash
$ pip install GAF-microbatch-pytorch
```
```python
import torch
from torch import nn

# mock network

net = nn.Sequential(
    nn.Linear(512, 256),
    nn.SiLU(),
    nn.Linear(256, 128)
)

# import the gradient agreement filtering (GAF) wrapper

from GAF_microbatch_pytorch import GAFWrapper

# just wrap your neural net

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97
)

# your batch of data

x = torch.randn(16, 1024, 512)

# forward and backward as usual

out = gaf_net(x)
out.sum().backward()

# per-sample gradients within the batch are now filtered
# by the cosine-distance threshold set above, as in the paper
```
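To make the filtering criterion concrete, below is a minimal sketch of cosine-distance agreement filtering. This is illustrative only, not the library's internal code: it assumes per-sample gradients have been flattened and stacked along the first dimension, and it compares each sample's gradient against the batch mean (the paper filters by pairwise cosine distance between gradients).

```python
import torch
import torch.nn.functional as F

# illustrative sketch: per_sample_grads is assumed to be (batch, dim),
# the flattened gradient of each sample in the batch

def cosine_distance_filter(per_sample_grads, thres = 0.97):
    mean_grad = per_sample_grads.mean(dim = 0, keepdim = True)

    # cosine distance (1 - cosine similarity) of each sample's gradient to the batch mean

    dist = 1. - F.cosine_similarity(per_sample_grads, mean_grad, dim = -1)

    agree = dist <= thres

    # if no sample agrees, return a zero gradient

    if not agree.any():
        return torch.zeros_like(per_sample_grads[0])

    # average only the agreeing per-sample gradients

    return per_sample_grads[agree].mean(dim = 0)
```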
You can supply your own gradient filtering method as a `Callable[[Tensor], Tensor]` with the `filter_gradients_fn` keyword argument, like so
```python
def filtering_fn(grads):
    # make your big discovery here
    return grads

gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = filtering_fn
)
```
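As a concrete (hypothetical) example, one could swap in an elementwise sign-agreement filter, which keeps only the gradient components whose sign matches the per-sample majority. This assumes the callable receives per-sample gradients stacked along the first dimension, an assumption not documented above, and it is a different technique from the paper's cosine-distance filtering.

```python
import torch

# hypothetical alternative filter: elementwise sign agreement
# assumes grads is shaped (batch, *param_shape)

def sign_agreement_fn(grads):
    # majority vote over the sign of each gradient component across the batch
    majority_sign = grads.sign().sum(dim = 0, keepdim = True).sign()

    # zero out components that disagree with the majority sign
    agree = grads.sign() == majority_sign
    return grads * agree

gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = sign_agreement_fn
)
```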
To set the state of all `GAFWrapper` instances within a network, use `set_filter_gradients_`
```python
from GAF_microbatch_pytorch import set_filter_gradients_

set_filter_gradients_(net, False) # turn filtering on / off

# or perhaps change the filter threshold on some schedule

set_filter_gradients_(net, True, 0.98)
```
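For instance, a training loop could anneal the threshold over the course of training. The schedule below is purely hypothetical, reusing `gaf_net` and `x` from the usage example above.

```python
from GAF_microbatch_pytorch import set_filter_gradients_

num_steps = 1000

for step in range(num_steps):
    # hypothetical schedule: linearly relax the threshold from 0.9 to 0.99
    thres = 0.9 + 0.09 * (step / num_steps)
    set_filter_gradients_(gaf_net, True, thres)

    out = gaf_net(x)
    out.sum().backward()
    # optimizer step / zero grad as usual
```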
- replicate CIFAR results on a single machine
- allow for excluding certain parameters from being filtered
```bibtex
@inproceedings{Chaubard2024BeyondGA,
    title  = {Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
    author = {Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:274992650}
}
```