lucidrains/GAF-microbatch-pytorch

Gradient Agreement Filtering - Pytorch

An implementation of Gradient Agreement Filtering (GAF), from Chaubard et al. of Stanford, adapted to microbatches on a single machine, in PyTorch.

The official repository that does filtering for macrobatches across machines is here

Install

$ pip install GAF-microbatch-pytorch

Usage

import torch

# mock network

from torch import nn

net = nn.Sequential(
    nn.Linear(512, 256),
    nn.SiLU(),
    nn.Linear(256, 128)
)

# import the gradient agreement filtering (GAF) wrapper

from GAF_microbatch_pytorch import GAFWrapper

# just wrap your neural net

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97
)

# your batch of data

x = torch.randn(16, 1024, 512)

# forward and backwards as usual

out = gaf_net(x)

out.sum().backward()

# the gradients are now filtered: per-sample gradients within the batch are compared against the set threshold, as in the paper
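To make the idea of filtering by agreement concrete, here is a minimal sketch that operates on a matrix of flattened per-sample gradients. It is not this library's code: the function name `agreement_filter`, the pairwise cosine distance comparison, and the rule "keep a sample if at least one other sample lies within the threshold" are all assumptions chosen for illustration, and the wrapper's actual rule may differ.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def agreement_filter(grads: Tensor, dist_thres: float = 0.97) -> Tensor:
    """Hypothetical helper, not part of GAF_microbatch_pytorch.

    grads: (batch, num_params) flattened per-sample gradients.
    Returns a (num_params,) gradient averaged over the samples that agree.
    """
    normed = F.normalize(grads, dim=-1)

    # pairwise cosine distance, 0 = same direction, 2 = opposite direction
    cos_dist = 1.0 - normed @ normed.t()

    # a sample trivially agrees with itself, so mask out the diagonal
    eye = torch.eye(grads.shape[0], dtype=torch.bool, device=grads.device)
    agrees = (cos_dist < dist_thres) & ~eye

    # assumed rule: keep a sample if any other sample is within the threshold
    keep = agrees.any(dim=-1)

    if not keep.any():
        # nothing agrees, so skip the update entirely
        return torch.zeros_like(grads[0])

    return grads[keep].mean(dim=0)

# two samples pointing roughly the same way, one pointing the opposite way
sample_grads = torch.tensor([[1.0, 0.0], [1.0, 0.1], [-1.0, 0.0]])
filtered = agreement_filter(sample_grads, dist_thres=0.5)
```

In this toy case the third sample is dropped and the result is the mean of the first two.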

You can supply your own gradient filtering method, a Callable[[Tensor], Tensor], through the filter_gradients_fn kwarg, like so:

def filtering_fn(grads):
    # make your big discovery here
    return grads
 
gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = filtering_fn
)
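As one possible starting point for a custom filter, the sketch below zeroes out per-sample gradients whose norm is far above the batch median. It assumes that `grads` carries the per-sample gradients along its first dimension and that the function should return a tensor of the same shape; the README states neither, so check the shapes before relying on it. The name `drop_outlier_grads` and the `max_ratio` parameter are hypothetical.

```python
import torch
from torch import Tensor

def drop_outlier_grads(grads: Tensor, max_ratio: float = 3.0) -> Tensor:
    """Hypothetical custom filter, assuming grads has shape (batch, ...)."""
    batch = grads.shape[0]

    # one norm per sample, computed over all remaining dimensions
    norms = grads.reshape(batch, -1).norm(dim=-1)

    # keep samples whose norm is at most max_ratio times the median norm
    keep = norms <= max_ratio * norms.median()

    # broadcast the (batch,) mask over the trailing dimensions
    mask = keep.reshape(-1, *([1] * (grads.dim() - 1))).to(grads.dtype)
    return grads * mask

# four per-sample gradients, the last one a hundred times larger
example = torch.ones(4, 2, 3)
example[3] *= 100
cleaned = drop_outlier_grads(example)
```

Since `max_ratio` has a default, the function still matches Callable[[Tensor], Tensor] and could be passed as `filter_gradients_fn = drop_outlier_grads`.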

To turn filtering on or off, and optionally change the threshold, for every GAFWrapper within a network, use set_filter_gradients_

from GAF_microbatch_pytorch import set_filter_gradients_

set_filter_gradients_(net, False) # turning on / off

# or perhaps filter thresholds on some schedule

set_filter_gradients_(net, True, 0.98)
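A threshold schedule could be as simple as a linear ramp evaluated once per training step. The helper below is a sketch only: `threshold_at`, its start and end values, and the direction of the ramp are assumptions for illustration, not recommendations from the paper or this library.

```python
def threshold_at(step: int, total_steps: int, start: float = 0.97, end: float = 1.0) -> float:
    """Hypothetical linear schedule for the filter distance threshold."""
    if total_steps <= 1:
        return end

    # fraction of training completed, clamped to [0, 1]
    frac = min(max(step / (total_steps - 1), 0.0), 1.0)
    return start + frac * (end - start)

# inside a training loop, using the library call shown above:
#
#   set_filter_gradients_(net, True, threshold_at(step, total_steps))

schedule = [threshold_at(step, 11) for step in range(11)]
```

Keeping the schedule as a plain function of the step makes it easy to swap in a cosine or step-wise variant later.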

Todo

  • replicate the CIFAR results on a single machine
  • allow certain parameters to be excluded from filtering

Citations

@inproceedings{Chaubard2024BeyondGA,
    title   = {Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
    author  = {Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:274992650}
}
