Per sample grad correctness util #532

Closed

Changes from 4 commits

Commits (38):
b8a2ff7  support empty batches in memory manager and optimizer (Oct 25, 2022)
2e1b9d7  restore warning (Oct 25, 2022)
df9d1ab  disable functorch test for 1.13+ (Oct 25, 2022)
0268fa1  Merge branch 'main' of github.com:pytorch/opacus into ffuuugor_522 (Oct 26, 2022)
b952c2a  0-batch tests (Oct 27, 2022)
5c7fc6f  lint (Oct 27, 2022)
64f08ad  EW test fix (Oct 27, 2022)
df7c355  docstring up (Oct 27, 2022)
b64b06a  Implement per sample grads util and refactor code (Oct 27, 2022)
9b4d5ee  Merge branch 'pytorch:main' into per-sample-grad-correctness-util (psolikov, Oct 27, 2022)
3f9f9cd  Add docs and refactor (Oct 28, 2022)
765b84e  Apply code style fixes (Oct 28, 2022)
16477ac  Merge branch 'main' into per-sample-grad-correctness-util (psolikov, Oct 28, 2022)
585be68  Fix flake8 errors (Oct 28, 2022)
82c8f52  Implement per sample grads util and refactor code (Oct 27, 2022)
1d957fa  Fixed issue with missing argument in MNIST example (#520) (Oct 27, 2022)
4e3a979  Add docs and refactor (Oct 28, 2022)
3d0a5db  Apply code style fixes (Oct 28, 2022)
36dd386  Functorch gradients: investigation and fix (#510) (Oct 28, 2022)
c06ebec  Fix flake8 errors (Oct 28, 2022)
5168e20  Add type hints (Oct 31, 2022)
b71fb30  Refactor (Oct 31, 2022)
cdcae86  Update docstrings (Oct 31, 2022)
ab1d6a7  Fix reduction modes for EW (Oct 31, 2022)
206a042  Rebase on #530, separate utils tests, refactor (Nov 1, 2022)
f9a35de  Optimize imports (Nov 1, 2022)
a8aac48  Merge remote-tracking branch 'origin/per-sample-grad-correctness-util… (Nov 1, 2022)
f7880d8  Fix test (Nov 1, 2022)
1ae50cb  Add utility description to tutorial (Nov 1, 2022)
8c67f40  Fix grad samples test (Nov 7, 2022)
0faf661  Fixed isort warnings (Nov 7, 2022)
a95c95a  Fix grad samples zero batch test (Nov 7, 2022)
786a093  Skip functorch test when unavailable (Nov 7, 2022)
42866b1  Merge branch 'main' into per-sample-grad-correctness-util (Nov 8, 2022)
6402c18  Fix merge (Nov 8, 2022)
0b9a1ec  Isort fix (Nov 8, 2022)
05fdb4f  Fix docstring (Nov 8, 2022)
fd8fbde  Merge branch 'main' into per-sample-grad-correctness-util (psolikov, Nov 17, 2022)
264 changes: 10 additions & 254 deletions opacus/tests/grad_samples/common.py
@@ -13,20 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import unittest
from typing import Dict, List, Union
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opacus.grad_sample import wrap_model
from opacus.utils.module_utils import trainable_parameters
from opacus.utils.packed_sequences import compute_seq_lengths
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
from torch.nn.utils.rnn import PackedSequence
from torch.testing import assert_allclose

from opacus.utils.per_sample_gradients_utils import (
compute_grad_samples_microbatch_and_opacus,
)


def expander(x, factor: int = 2):
return x * factor
@@ -36,181 +35,12 @@ def shrinker(x, factor: int = 2):
return max(1, x // factor)  # avoid returning 0 for x == 1


class ModelWithLoss(nn.Module):
"""
To test the gradients of a module, we need to have a loss.
This module makes it easy to get a loss from any nn.Module, and automatically generates
a target y vector for it in the forward (of all zeros of the correct size).
This reduces boilerplate while testing.
"""

supported_reductions = ["mean", "sum"]

def __init__(self, module: nn.Module, loss_reduction: str = "mean"):
"""
Instantiates this module.

Args:
module: The nn.Module you want to test.
loss_reduction: What reduction to apply to the loss. Defaults to "mean".

Raises:
ValueError: If ``loss_reduction`` is not among those supported.
"""
super().__init__()
self.wrapped_module = module

if loss_reduction not in self.supported_reductions:
raise ValueError(
f"Passed loss_reduction={loss_reduction}. Only {self.supported_reductions} supported."
)
self.criterion = nn.L1Loss(reduction=loss_reduction)

def forward(self, x):
x = self.wrapped_module(x)
if type(x) is PackedSequence:
loss = _compute_loss_packedsequences(self.criterion, x)
else:
y = torch.zeros_like(x)
loss = self.criterion(x, y)
return loss


def clone_module(module: nn.Module) -> nn.Module:
"""
Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is
just easier to serialize the model to a BytesIO and read it from there.

Args:
module: The module to clone

Returns:
The clone of ``module``
"""
with io.BytesIO() as bytesio:
torch.save(module, bytesio)
bytesio.seek(0)
module_copy = torch.load(bytesio)
return module_copy


class GradSampleHooks_test(unittest.TestCase):
"""
Set of common testing utils. It is meant to be subclassed by your test.
See other tests as an example of how this is done.
"""

def compute_microbatch_grad_sample(
self,
x: Union[torch.Tensor, List[torch.Tensor]],
module: nn.Module,
batch_first=True,
loss_reduction="mean",
) -> Dict[str, torch.tensor]:
"""
Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients
with batch_size set to 1, and manually accumulating them. This is our reference for testing
as this method is obviously correct, but slow.

Args:
x: The tensor in input to the ``module``
module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
batch_first: Whether batch size is the first dimension (as opposed to the second).
Defaults to True.

Returns:
Dictionary mapping parameter_name -> per-sample-gradient for that parameter
"""
torch.use_deterministic_algorithms(True)
torch.manual_seed(0)
np.random.seed(0)

module = ModelWithLoss(clone_module(module), loss_reduction)

for _, p in trainable_parameters(module):
p.microbatch_grad_sample = []

if not batch_first and type(x) is not list:
# This allows us to iterate with x_i
x = x.transpose(0, 1)

# Invariant: x is [B, T, ...]

for x_i in x:
# x_i is [T, ...]
x_i = x_i.unsqueeze(
0 if batch_first else 1
) # x_i of size [1, T, ...] if batch_first, else [T, 1, ...]
module.zero_grad()
loss_i = module(x_i)
loss_i.backward()
for p in module.parameters():
p.microbatch_grad_sample.append(p.grad.detach().clone())

for _, p in trainable_parameters(module):
if batch_first:
p.microbatch_grad_sample = torch.stack(
p.microbatch_grad_sample, dim=0 # [B, T, ...]
)
else:
p.microbatch_grad_sample = torch.stack(
p.microbatch_grad_sample, dim=1 # [T, B, ...]
).transpose(
0, 1
) # Opacus's semantics is that grad_samples are ALWAYS batch_first: [B, T, ...]

microbatch_grad_samples = {
name: p.microbatch_grad_sample
for name, p in trainable_parameters(module.wrapped_module)
}
return microbatch_grad_samples

def compute_opacus_grad_sample(
self,
x: Union[torch.Tensor, PackedSequence],
module: nn.Module,
batch_first=True,
loss_reduction="mean",
grad_sample_mode="hooks",
) -> Dict[str, torch.tensor]:
"""
Runs Opacus to compute per-sample gradients and return them for testing purposes.

Args:
x: The tensor in input to the ``module``
module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
batch_first: Whether batch size is the first dimension (as opposed to the second).
Defaults to True.
loss_reduction: What reduction to apply to the loss. Defaults to "mean".

Returns:
Dictionary mapping parameter_name -> per-sample-gradient for that parameter
"""
torch.use_deterministic_algorithms(True)
torch.manual_seed(0)
np.random.seed(0)

gs_module = wrap_model(
model=clone_module(module),
grad_sample_mode=grad_sample_mode,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
grad_sample_module = ModelWithLoss(gs_module, loss_reduction)

grad_sample_module.zero_grad()
loss = grad_sample_module(x)
loss.backward()

opacus_grad_samples = {
name: p.grad_sample
for name, p in trainable_parameters(
grad_sample_module.wrapped_module._module
)
}

return opacus_grad_samples

def run_test(
self,
x: Union[torch.Tensor, PackedSequence],
Expand All @@ -228,7 +58,6 @@ def run_test(

for grad_sample_mode in grad_sample_modes:
for loss_reduction in ["sum", "mean"]:

with self.subTest(
grad_sample_mode=grad_sample_mode, loss_reduction=loss_reduction
):
@@ -262,34 +91,17 @@ def run_test_with_reduction(
rtol=10e-5,
grad_sample_mode="hooks",
):
if type(x) is PackedSequence:
x_unpacked = _unpack_packedsequences(x)
microbatch_grad_samples = self.compute_microbatch_grad_sample(
x_unpacked,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
else:
microbatch_grad_samples = self.compute_microbatch_grad_sample(
x, module, batch_first=batch_first, loss_reduction=loss_reduction
)

opacus_grad_samples = self.compute_opacus_grad_sample(
(
microbatch_grad_samples,
opacus_grad_samples,
) = compute_grad_samples_microbatch_and_opacus(
x,
module,
batch_first=batch_first,
loss_reduction=loss_reduction,
grad_sample_mode=grad_sample_mode,
)

if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
raise ValueError(
"Keys not matching! "
f"Keys only in microbatch: {microbatch_grad_samples.keys() - opacus_grad_samples.keys()}; "
f"Keys only in Opacus: {opacus_grad_samples.keys() - microbatch_grad_samples.keys()}"
)

self.check_shapes(microbatch_grad_samples, opacus_grad_samples, loss_reduction)
self.check_values(
microbatch_grad_samples, opacus_grad_samples, loss_reduction, atol, rtol
@@ -358,59 +170,3 @@ def check_values(
f"A total of {len(failed)} values do not match "
f"for loss_reduction={loss_reduction}: \n\t{failed_str}"
)


def _unpack_packedsequences(X: PackedSequence) -> List[torch.Tensor]:
r"""
Produces a list of tensors from X (PackedSequence) such that this list was used to create X with batch_first=True

Args:
X: A PackedSequence from which the output list of tensors will be produced.

Returns:
unpacked_data: The list of tensors produced from X.
"""

X_padded = pad_packed_sequence(X)
X_padded = X_padded[0].permute((1, 0, 2))

if X.sorted_indices is not None:
X_padded = X_padded[X.sorted_indices]

seq_lens = compute_seq_lengths(X.batch_sizes)
unpacked_data = [0] * len(seq_lens)
for idx, length in enumerate(seq_lens):
unpacked_data[idx] = X_padded[idx][:length, :]

return unpacked_data


def _compute_loss_packedsequences(
criterion: nn.L1Loss, x: PackedSequence
) -> torch.Tensor:
r"""
Computes the loss differently depending on the criterion's reduction. For 'sum' reduction, the loss is
computed exactly as for non-packed data. For 'mean' reduction, x (a PackedSequence) is first unpacked into
the list of tensors that was used to create it with batch_first=True, the criterion is applied to each
sequence sample, and the resulting per-sequence losses are averaged.

Args:
criterion: An L1 loss function with reduction either set to 'sum' or 'mean'.
x: Data in the form of a PackedSequence.

Returns:
A loss variable, reduced either using summation or averaging from L1 errors.
"""

if criterion.reduction == "sum":
y = torch.zeros_like(x[0])
return criterion(x[0], y)
elif criterion.reduction == "mean":
x = _unpack_packedsequences(x)
loss_sum = 0
for x_i in x:
y_i = torch.zeros_like(x_i)
loss_sum += criterion(x_i, y_i)
loss_mean = loss_sum / len(x)
return loss_mean
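
For context, the two removed helper methods are now backed by a single utility in opacus.utils.per_sample_gradients_utils. The sketch below mirrors the call site in run_test_with_reduction above and shows how the helper might be used on its own; the model, input shape, and tolerances are illustrative assumptions, not part of this diff.

import torch
import torch.nn as nn

from opacus.utils.per_sample_gradients_utils import (
    compute_grad_samples_microbatch_and_opacus,
)

# Toy model and a batch of 8 samples (assumed for illustration).
module = nn.Linear(16, 2)
x = torch.randn(8, 16)

# Returns two dicts mapping parameter name -> per-sample gradients:
# one from the microbatch (batch_size=1) reference, one from Opacus.
microbatch_grad_samples, opacus_grad_samples = compute_grad_samples_microbatch_and_opacus(
    x,
    module,
    batch_first=True,
    loss_reduction="mean",
    grad_sample_mode="hooks",
)

# Per-sample gradients carry a leading batch dimension, so the two
# results can be compared parameter by parameter.
for name, microbatch_gs in microbatch_grad_samples.items():
    assert torch.allclose(
        microbatch_gs, opacus_grad_samples[name], atol=1e-4, rtol=1e-3
    )
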
19 changes: 18 additions & 1 deletion opacus/tests/grad_samples/conv1d_test.py
@@ -21,6 +21,10 @@
from hypothesis import given, settings

from .common import GradSampleHooks_test, expander, shrinker
from ...utils.per_sample_gradients_utils import (
check_per_sample_gradients_are_correct,
get_grad_sample_modes,
)


class Conv1d_test(GradSampleHooks_test):
@@ -34,6 +38,7 @@ class Conv1d_test(GradSampleHooks_test):
padding=st.sampled_from([0, 1, 2, "same", "valid"]),
dilation=st.integers(1, 2),
groups=st.integers(1, 12),
test_or_check=st.integers(1, 2),
)
@settings(deadline=10000)
def test_conv1d(
@@ -47,6 +52,7 @@ def test_conv1d(
padding: int,
dilation: int,
groups: int,
test_or_check: int,
):

if padding == "same" and stride != 1:
@@ -67,4 +73,15 @@
dilation=dilation,
groups=groups,
)
self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
if test_or_check == 1:
self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
if test_or_check == 2:
for grad_sample_mode in get_grad_sample_modes(use_ew=True):
assert check_per_sample_gradients_are_correct(
x,
conv,
batch_first=True,
atol=10e-5,
rtol=10e-4,
grad_sample_mode=grad_sample_mode,
)
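
The same check can also be run outside the hypothesis harness. A minimal sketch, assuming the API used in the test above (check_per_sample_gradients_are_correct returning a boolean and get_grad_sample_modes(use_ew=...) listing the supported grad-sample modes); the layer, input shape, and tolerances below are illustrative assumptions.

import torch
import torch.nn as nn

from opacus.utils.per_sample_gradients_utils import (
    check_per_sample_gradients_are_correct,
    get_grad_sample_modes,
)

# Illustrative layer and input: [batch, channels, width] for Conv1d.
conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
x = torch.randn(16, 4, 24)

# use_ew=True also includes the ExpandedWeights-based mode where the
# installed torch version supports it (same flag as in the test above).
for grad_sample_mode in get_grad_sample_modes(use_ew=True):
    assert check_per_sample_gradients_are_correct(
        x,
        conv,
        batch_first=True,
        atol=10e-5,
        rtol=10e-4,
        grad_sample_mode=grad_sample_mode,
    )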