functorch mode computes the forward twice #521

Open
kshitij12345 opened this issue Oct 13, 2022 · 3 comments
Labels: enhancement (New feature or request)

kshitij12345 commented Oct 13, 2022

In functorch mode, Opacus saves the activations from previous layers and uses them to compute the per-sample gradients with functorch. However, functorch.grad ends up re-running the layer's forward pass in addition to the backward pass, which hurts performance in functorch mode.

Reference to the code that uses functorch.grad (a reconstruction of the surrounding recipe follows the snippet):

ft_compute_grad = grad(compute_loss_stateless_model)
# Note that the vmap is done on the first dimension, regardless of batch_first
# This is because the activations and backprops given by the GradSampleModule
# are always batch_first=True
layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))

per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)
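
For context, prepare_layer builds compute_loss_stateless_model around a functional version of the layer. A minimal sketch of that recipe (reconstructed here for illustration, not verbatim Opacus code) shows where the extra forward pass comes from:

import torch
from functorch import make_functional, grad, vmap

layer = torch.nn.Linear(128, 128)
flayer, params = make_functional(layer)

def compute_loss_stateless_model(params, activations, backprops):
    # The layer's forward runs again here, inside grad(...), even though
    # GradSampleModule already ran it once during the model's forward pass.
    output = flayer(params, activations.unsqueeze(0))
    return (output * backprops.unsqueeze(0)).sum()

ft_compute_grad = grad(compute_loss_stateless_model)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))

activations = torch.randn(8, 128)  # cached layer inputs, one row per sample
backprops = torch.randn(8, 128)    # cached grad of the loss w.r.t. the layer outputs
per_sample_grads = ft_compute_sample_grad(params, activations, backprops)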


We can apply the following patch (for Linear) to improve performance in functorch mode. The idea is similar to what the hooks approach already does for nn.Linear (an einsum sketch of that follows the patch).

diff --git a/opacus/grad_sample/functorch.py b/opacus/grad_sample/functorch.py
index 9777950..45021bb 100644
--- a/opacus/grad_sample/functorch.py
+++ b/opacus/grad_sample/functorch.py
@@ -1,5 +1,6 @@
 from opacus.layers.dp_rnn import RNNLinear
-
+import torch
+import functorch
 
 def prepare_layer(layer, batch_first=True):
     """
@@ -48,6 +49,14 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
         activations: the input to the layer
         backprops: the  gradient of the loss w.r.t. outputs of the layer
     """
+    if isinstance(layer, torch.nn.Linear):
+        ret = {}
+        if layer.weight.requires_grad:
+            ret[layer.weight] = functorch.vmap(torch.mm)(backprops, activations)
+        if layer.bias is not None and layer.bias.requires_grad:
+            ret[layer.bias] = functorch.vmap(torch.sum)(backprops)
+        return ret
+
     parameters = list(layer.parameters())
     if not hasattr(layer, "ft_compute_sample_grad"):
         prepare_layer(layer)
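
For comparison, the hook-based path computes per-sample gradients for nn.Linear directly from the cached activations and backprops with einsums, with no extra forward or backward pass. A rough sketch (the function name and shapes here are illustrative: activations of shape [batch, ..., in_features], backprops of shape [batch, ..., out_features]):

import torch

def linear_grad_sample_sketch(layer, activations, backprops):
    # One outer product per sample, summed over any extra (e.g. sequence) dims.
    ret = {}
    if layer.weight.requires_grad:
        ret[layer.weight] = torch.einsum("n...i,n...j->nij", backprops, activations)
    if layer.bias is not None and layer.bias.requires_grad:
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
    return ret

The patch above mirrors this idea inside the functorch path, so only the per-sample matmul/sum is evaluated and the layer's forward is never re-run.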

Before Patch

Per-sample-grads without functorch
compute_grad_opacus(gs_lin_mod, data, targets)
  Median: 1.64 ms
  IQR:    0.06 ms (1.61 to 1.67)
  61 measurements, 100 runs per measurement, 1 thread

Per-sample-grads with functorch
compute_grad_opacus(gs_lin_mod_functorch, data, targets)
  Median: 2.16 ms
  IQR:    0.08 ms (2.09 to 2.16)
  47 measurements, 100 runs per measurement, 1 thread

After Patch

Per-sample-grads without functorch
compute_grad_opacus(gs_lin_mod, data, targets)
  Median: 1.60 ms
  IQR:    0.02 ms (1.60 to 1.61)
  63 measurements, 100 runs per measurement, 1 thread

Per-sample-grads with functorch
compute_grad_opacus(gs_lin_mod_functorch, data, targets)
  Median: 1.39 ms
  IQR:    0.01 ms (1.39 to 1.40)
  72 measurements, 100 runs per measurement, 1 thread
Benchmark Script

from opacus.utils.module_utils import trainable_modules
import copy
from opacus.grad_sample import GradSampleModule
from torch.utils.benchmark import Timer
import torch
torch.manual_seed(42)


batch_size = 8
N = 128
device = 'cpu'
net = torch.nn.Linear(N, N, bias=False).to(device)
net_2 = copy.deepcopy(net)


def compute_loss(x, y):
    return (x - y).sum()

gs_lin_mod = GradSampleModule(net, force_functorch=False)
gs_lin_mod_functorch = GradSampleModule(net_2, force_functorch=True)


def compute_grad_opacus(net, sample, target):
    prediction = net(sample)
    loss = compute_loss(prediction, target)
    loss.backward()

    modules = list(trainable_modules(net))
    _, module = modules[0]
    grad_samples = module.weight.grad_sample
    o = grad_samples
    net.zero_grad()
    return (o,)


# Make sure that outputs match
for _ in range(5):
    d = torch.randn(batch_size, N, N, device=device, requires_grad=False)
    t = torch.randn(batch_size, N, N, device=device, requires_grad=False)
    o_ = compute_grad_opacus(gs_lin_mod, d, t)
    o2_ = compute_grad_opacus(gs_lin_mod_functorch, d, t)
    for o, o2 in zip(o_, o2_):
        torch.testing.assert_close(o, o2)

data = torch.randn(batch_size, N, N, device=device, requires_grad=False)
targets = torch.randn(batch_size, N, N, device=device, requires_grad=False)

# Benchmark
without_functorch = Timer(
    stmt="compute_grad_opacus(gs_lin_mod, data, targets)", globals=globals())
with_functorch = Timer(
    stmt="compute_grad_opacus(gs_lin_mod_functorch, data, targets)", globals=globals())

no_functorch_timing = without_functorch.blocked_autorange(min_run_time=10)
print(f'Per-sample-grads without functorch {no_functorch_timing}')

functorch_timing = with_functorch.blocked_autorange(min_run_time=10)
print(f'Per-sample-grads with functorch {functorch_timing}')

cc: @zou3519

kshitij12345 changed the title from "functorch mode recomputes the forward twice" to "functorch mode computes the forward twice" on Oct 13, 2022

zou3519 commented Oct 13, 2022

My understanding is that functorch mode is designed to handle arbitrary modules. Since we (as library developers) know what the backward pass of linear is, we know it is possible to avoid re-calling the forward pass.

For arbitrary user modules, we don't know what the backward pass is, so we cannot directly vmap over it. functorch doesn't know what the backward pass is until the forward pass gets executed using functorch, so it needs to recompute the forward pass.

I'm wondering if there's a way to avoid computing the forward pass twice (once in the original execution of the module, and once in the grad_sample computation).


zou3519 commented Oct 13, 2022

I believe something like the following works, though functorch does not officially support it (yet). The idea is:

  • somehow figure out how to run the forward pass using functorch
  • pass some state to the backward pass
  • during the backward pass, use functorch along with that state to compute per-sample gradients

import torch
import torch.nn.functional as F
from functorch import vmap, vjp, grad
import functools

torch.manual_seed(0)
x = torch.randn(2, 3)
w = torch.randn(3, 3)

fn = torch.matmul

backdoor = []

def sample_grad_call(fn, x, w):
  def inner(x, w):
    res, vjp_fn = vjp(functools.partial(fn, x), w)
    backdoor.append(vjp_fn)
    return res

  res = vmap(inner, (0, None))(x, w)
  return res

def compute_grad_sample(grad_out):
  def inner(grad_out, dummy):
    return backdoor[-1](grad_out)

  grad_sample = vmap(inner)(grad_out, x)
  return grad_sample

# Somehow replace the module forward pass with the following
y = sample_grad_call(fn, x, w)

grad_y = torch.ones_like(y)

# And then replace the module backward pass with the following
result, = compute_grad_sample(grad_y)

# Here's a correctness check
w.requires_grad_()
expected0, = torch.autograd.grad(fn(x[0], w).sum(), w)
expected1, = torch.autograd.grad(fn(x[1], w).sum(), w)
expected = torch.stack([expected0, expected1])
# Compare with the per-sample gradients computed via the vjp "backdoor" above
# (assuming the unofficial trick behaves as intended):
torch.testing.assert_close(result, expected)

On the functorch side, we'll try to hack up a POC integrating this approach with Opacus.

alexandresablayrolles self-assigned this on Oct 21, 2022
alexandresablayrolles (Contributor) commented:

Thanks @zou3519 for jumping in. I second this:

functorch mode is designed to handle arbitrary modules.

For linear modules we use einsums (either directly or through ExpandedWeights). Note that it is also possible to have the entire model be functional and pass the grad samples to Opacus (this is the "no_op" GradSampleModule; a minimal sketch of that path follows below).
I agree that doing the forward twice is a bit wasteful, but the hope is that the extra cost is isolated to a single layer, so the gains from going back to one forward pass would be rather small.
Thanks @zou3519 for proposing this solution though; I'm happy to follow up to see if we can integrate it with Opacus.
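
For readers unfamiliar with that fully functional path, here is a minimal sketch of per-sample gradients for a whole model via functorch's standard make_functional/grad/vmap recipe (the Opacus-side "no_op" wiring is not shown, and the model and shapes are placeholders):

import torch
import torch.nn.functional as F
from functorch import make_functional, grad, vmap

model = torch.nn.Linear(16, 4)
fmodel, params = make_functional(model)

def per_sample_loss(params, sample, target):
    # Each call sees a single sample; vmap re-adds the batch dimension.
    out = fmodel(params, sample.unsqueeze(0))
    return F.mse_loss(out, target.unsqueeze(0))

data, targets = torch.randn(8, 16), torch.randn(8, 4)
per_sample_grads = vmap(grad(per_sample_loss), in_dims=(None, 0, 0))(params, data, targets)
# per_sample_grads mirrors params, with a leading batch dimension of 8 on each tensor.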

alexandresablayrolles added the "enhancement" (New feature or request) label on Oct 21, 2022