diff --git a/opacus/data_loader.py b/opacus/data_loader.py
index 4feaaf94..0f45bbfb 100644
--- a/opacus/data_loader.py
+++ b/opacus/data_loader.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple, Type, Union
 
 import torch
 from opacus.utils.uniform_sampler import (
@@ -29,7 +29,10 @@
 
 
 def wrap_collate_with_empty(
-    collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
+    *,
+    collate_fn: Optional[_collate_fn_t],
+    sample_empty_shapes: Sequence[Tuple],
+    dtypes: Sequence[Union[torch.dtype, Type]],
 ):
     """
     Wraps given collate function to handle empty batches.
@@ -49,12 +52,15 @@
     def collate(batch):
         if len(batch) > 0:
             return collate_fn(batch)
         else:
-            return [torch.zeros(x) for x in sample_empty_shapes]
+            return [
+                torch.zeros(shape, dtype=dtype)
+                for shape, dtype in zip(sample_empty_shapes, dtypes)
+            ]
 
     return collate
 
 
-def shape_safe(x: Any):
+def shape_safe(x: Any) -> Tuple:
     """
     Exception-safe getter for ``shape`` attribute
@@ -67,6 +73,19 @@ def shape_safe(x: Any):
     return x.shape if hasattr(x, "shape") else ()
 
 
+def dtype_safe(x: Any) -> Union[torch.dtype, Type]:
+    """
+    Exception-safe getter for ``dtype`` attribute
+
+    Args:
+        x: any object
+
+    Returns:
+        ``x.dtype`` if attribute exists, type of x otherwise
+    """
+    return x.dtype if hasattr(x, "dtype") else type(x)
+
+
 class DPDataLoader(DataLoader):
     """
     DataLoader subclass that always does Poisson sampling and supports empty batches
@@ -143,7 +162,8 @@ def __init__(
             sample_rate=sample_rate,
             generator=generator,
         )
-        sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
+        sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
+        dtypes = [dtype_safe(x) for x in dataset[0]]
 
         if collate_fn is None:
             collate_fn = default_collate
@@ -156,7 +176,11 @@ def __init__(
             dataset=dataset,
             batch_sampler=batch_sampler,
             num_workers=num_workers,
-            collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes),
+            collate_fn=wrap_collate_with_empty(
+                collate_fn=collate_fn,
+                sample_empty_shapes=sample_empty_shapes,
+                dtypes=dtypes,
+            ),
             pin_memory=pin_memory,
             timeout=timeout,
             worker_init_fn=worker_init_fn,
diff --git a/opacus/grad_sample/README.md b/opacus/grad_sample/README.md
index 1a78499a..7827680f 100644
--- a/opacus/grad_sample/README.md
+++ b/opacus/grad_sample/README.md
@@ -74,6 +74,7 @@ Please note that these are known limitations and we plan to improve Expanded Wei
 | `batch_first=False` | ✅ Supported | Not supported | ✅ Supported |
 | Recurrent networks | ✅ Supported | Not supported | ✅ Supported |
 | Padding `same` in Conv | ✅ Supported | Not supported | ✅ Supported |
+| Empty Poisson batches | ✅ Supported | Not supported | Not supported |
 
 † Note, that performance differences are unstable and can vary a lot depending on the exact model and batch size. Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers.
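The `dtypes` plumbing above matters because `torch.zeros(shape)` always materializes `float32`: before this change, an empty Poisson batch handed float tensors to pipelines that expect integer inputs, such as token ids for `nn.Embedding`. A minimal standalone sketch of what the wrapped collate now returns, with hypothetical shapes and dtypes for an `(input_ids, label)` dataset:

```python
import torch
import torch.nn as nn

# Hypothetical values, as shape_safe()/dtype_safe() would derive them
# from the first dataset sample.
sample_empty_shapes = [(0, 10), (0,)]
dtypes = [torch.int64, torch.float32]

# Same comprehension as the patched collate above.
empty_batch = [
    torch.zeros(shape, dtype=dtype)
    for shape, dtype in zip(sample_empty_shapes, dtypes)
]

emb = nn.Embedding(num_embeddings=100, embedding_dim=16)
out = emb(empty_batch[0])  # int64 indices: works, output shape (0, 10, 16)
print(out.shape)
# A plain torch.zeros((0, 10)) would be float32, and nn.Embedding would
# reject it (it requires Long indices).
```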
diff --git a/opacus/grad_sample/conv.py b/opacus/grad_sample/conv.py
index abc044e1..16f2d95b 100644
--- a/opacus/grad_sample/conv.py
+++ b/opacus/grad_sample/conv.py
@@ -42,6 +42,14 @@ def compute_conv_grad_sample(
     """
     activations = activations[0]
     n = activations.shape[0]
+    if n == 0:
+        # Empty batch
+        ret = {}
+        ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
+        if layer.bias is not None and layer.bias.requires_grad:
+            ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
+        return ret
+
     # get activations and backprops in shape depending on the Conv layer
     if type(layer) == nn.Conv2d:
         activations = unfold2d(
diff --git a/opacus/grad_sample/embedding.py b/opacus/grad_sample/embedding.py
index 476a66e7..9e206a6a 100644
--- a/opacus/grad_sample/embedding.py
+++ b/opacus/grad_sample/embedding.py
@@ -40,6 +40,10 @@ def compute_embedding_grad_sample(
         torch.backends.cudnn.deterministic = True
 
         batch_size = activations.shape[0]
+        if batch_size == 0:
+            ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
+            return ret
+
         index = (
             activations.unsqueeze(-1)
             .expand(*activations.shape, layer.embedding_dim)
diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py
index 46a414d9..1afb3ce5 100644
--- a/opacus/optimizers/optimizer.py
+++ b/opacus/optimizers/optimizer.py
@@ -394,13 +394,17 @@ def clip_and_accumulate(self):
         Stores clipped and aggregated gradients into `p.summed_grad```
         """
 
-        per_param_norms = [
-            g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
-        ]
-        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
-        per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
-            max=1.0
-        )
+        if len(self.grad_samples[0]) == 0:
+            # Empty batch
+            per_sample_clip_factor = torch.zeros((0,))
+        else:
+            per_param_norms = [
+                g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
+            ]
+            per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
+            per_sample_clip_factor = (
+                self.max_grad_norm / (per_sample_norms + 1e-6)
+            ).clamp(max=1.0)
 
         for p in self.params:
             _check_processed_flag(p.grad_sample)
diff --git a/opacus/tests/batch_memory_manager_test.py b/opacus/tests/batch_memory_manager_test.py
index 08bca477..14bdbea7 100644
--- a/opacus/tests/batch_memory_manager_test.py
+++ b/opacus/tests/batch_memory_manager_test.py
@@ -37,8 +37,7 @@ class BatchMemoryManagerTest(unittest.TestCase):
     GSM_MODE = "hooks"
 
     def setUp(self) -> None:
-        self.data_size = 100
-        self.batch_size = 10
+        self.data_size = 256
         self.inps = torch.randn(self.data_size, 5)
         self.tgts = torch.randn(
             self.data_size,
@@ -46,11 +45,11 @@ def setUp(self) -> None:
 
         self.dataset = TensorDataset(self.inps, self.tgts)
 
-    def _init_training(self, **data_loader_kwargs):
+    def _init_training(self, batch_size=10, **data_loader_kwargs):
         model = Model()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         data_loader = DataLoader(
-            self.dataset, batch_size=self.batch_size, **data_loader_kwargs
+            self.dataset, batch_size=batch_size, **data_loader_kwargs
        )
 
         return model, optimizer, data_loader
@@ -58,16 +57,22 @@ def _init_training(self, **data_loader_kwargs):
     @given(
         num_workers=st.integers(0, 4),
         pin_memory=st.booleans(),
+        batch_size=st.sampled_from([8, 16, 64]),
+        max_physical_batch_size=st.sampled_from([4, 8]),
     )
     @settings(deadline=10000)
     def test_basic(
         self,
         num_workers: int,
         pin_memory: bool,
+        batch_size: int,
+        max_physical_batch_size: int,
     ):
+        batches_per_step = max(1, batch_size // max_physical_batch_size)
         model, optimizer, data_loader = self._init_training(
             num_workers=num_workers,
             pin_memory=pin_memory,
+            batch_size=batch_size,
         )
 
         privacy_engine = PrivacyEngine()
@@ -80,22 +85,19 @@ def test_basic(
             poisson_sampling=False,
             grad_sample_mode=self.GSM_MODE,
         )
-        max_physical_batch_size = 3
         with BatchMemoryManager(
             data_loader=data_loader,
             max_physical_batch_size=max_physical_batch_size,
             optimizer=optimizer,
         ) as new_data_loader:
-            self.assertEqual(
-                len(data_loader), len(data_loader.dataset) // self.batch_size
-            )
+            self.assertEqual(len(data_loader), len(data_loader.dataset) // batch_size)
             self.assertEqual(
                 len(new_data_loader),
                 len(data_loader.dataset) // max_physical_batch_size,
             )
             weights_before = torch.clone(model._module.fc.weight)
             for i, (x, y) in enumerate(new_data_loader):
-                self.assertTrue(x.shape[0] <= 3)
+                self.assertTrue(x.shape[0] <= max_physical_batch_size)
 
                 out = model(x)
                 loss = (y - out).mean()
@@ -104,7 +106,63 @@ def test_basic(
                 optimizer.step()
                 optimizer.zero_grad()
 
-                if i % 4 < 3:
+                if (i + 1) % batches_per_step > 0:
                     self.assertTrue(
                         torch.allclose(model._module.fc.weight, weights_before)
                     )
+                else:
+                    self.assertFalse(
+                        torch.allclose(model._module.fc.weight, weights_before)
+                    )
+                    weights_before = torch.clone(model._module.fc.weight)
+
+    @given(
+        num_workers=st.integers(0, 4),
+        pin_memory=st.booleans(),
+    )
+    @settings(deadline=10000)
+    def test_empty_batch(
+        self,
+        num_workers: int,
+        pin_memory: bool,
+    ):
+        batch_size = 2
+        max_physical_batch_size = 10
+        torch.manual_seed(30)
+
+        model, optimizer, data_loader = self._init_training(
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            batch_size=batch_size,
+        )
+
+        privacy_engine = PrivacyEngine()
+        model, optimizer, data_loader = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=data_loader,
+            noise_multiplier=0.0,
+            max_grad_norm=1e5,
+            poisson_sampling=True,
+            grad_sample_mode=self.GSM_MODE,
+        )
+        with BatchMemoryManager(
+            data_loader=data_loader,
+            max_physical_batch_size=max_physical_batch_size,
+            optimizer=optimizer,
+        ) as new_data_loader:
+            weights_before = torch.clone(model._module.fc.weight)
+            for i, (x, y) in enumerate(new_data_loader):
+                self.assertTrue(x.shape[0] <= max_physical_batch_size)
+
+                out = model(x)
+                loss = (y - out).mean()
+
+                loss.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+
+                if len(x) == 0:
+                    self.assertTrue(
+                        torch.allclose(model._module.fc.weight, weights_before)
+                    )
@@ -174,3 +232,13 @@
 )
 class BatchMemoryManagerTestWithExpandedWeights(BatchMemoryManagerTest):
     GSM_MODE = "ew"
+
+    def test_empty_batch(self):
+        pass
+
+
+@unittest.skipIf(
+    torch.__version__ >= API_CUTOFF_VERSION, "not supported in this torch version"
+)
+class BatchMemoryManagerTestWithFunctorch(BatchMemoryManagerTest):
+    GSM_MODE = "functorch"
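On the `clip_and_accumulate` change further up: the short-circuit is needed because `reshape(len(g), -1)` is ambiguous for a tensor with zero elements and raises, while a zero-length clip factor keeps the downstream per-sample accumulation well defined. A standalone sketch with hypothetical shapes; the einsum here stands in for the accumulation done elsewhere in `DPOptimizer`, outside this diff:

```python
import torch

g = torch.randn(0, 3, 5)  # per-sample gradients for an empty batch
try:
    g.reshape(len(g), -1).norm(2, dim=-1)  # the non-empty code path
except RuntimeError as e:
    print(e)  # cannot reshape tensor of 0 elements into shape [0, -1] ...

# The patched path: a zero-length clip factor still contracts cleanly over
# the (empty) sample dimension, yielding an all-zero aggregated gradient,
# so the step reduces to noise addition only.
per_sample_clip_factor = torch.zeros((0,))
summed = torch.einsum("i,i...", per_sample_clip_factor, g)
assert summed.shape == (3, 5) and summed.abs().sum() == 0
```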
diff --git a/opacus/tests/grad_samples/common.py b/opacus/tests/grad_samples/common.py
index be5e220f..25ffbdc9 100644
--- a/opacus/tests/grad_samples/common.py
+++ b/opacus/tests/grad_samples/common.py
@@ -15,7 +15,7 @@
 
 import io
 import unittest
-from typing import Dict, List, Union
+from typing import Dict, Iterable, List, Tuple, Union
 
 import numpy as np
 import torch
@@ -36,6 +36,13 @@ def shrinker(x, factor: int = 2):
     return max(1, x // factor)  # if avoid returning 0 for x == 1
 
 
+def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
+    if type(batch) is torch.Tensor:
+        return batch.numel() == 0
+    else:
+        return batch[0].numel() == 0
+
+
 class ModelWithLoss(nn.Module):
     """
     To test the gradients of a module, we need to have a loss.
@@ -221,7 +228,7 @@ def compute_opacus_grad_sample(
 
     def run_test(
         self,
-        x: Union[torch.Tensor, PackedSequence],
+        x: Union[torch.Tensor, PackedSequence, Tuple],
         module: nn.Module,
         batch_first=True,
         atol=10e-6,
@@ -235,7 +242,9 @@ def run_test(
         except ImportError:
             grad_sample_modes = ["hooks"]
 
-        if type(module) is nn.EmbeddingBag:
+        if type(module) is nn.EmbeddingBag or (
+            type(x) is not PackedSequence and is_batch_empty(x)
+        ):
             grad_sample_modes = ["hooks"]
 
         for grad_sample_mode in grad_sample_modes:
@@ -277,6 +286,14 @@ def run_test_with_reduction(
         grad_sample_mode="hooks",
         chunk_method=iter,
     ):
+        opacus_grad_samples = self.compute_opacus_grad_sample(
+            x,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            grad_sample_mode=grad_sample_mode,
+        )
+
         if type(x) is PackedSequence:
             x_unpacked = _unpack_packedsequences(x)
             microbatch_grad_samples = self.compute_microbatch_grad_sample(
@@ -285,7 +302,7 @@ def run_test_with_reduction(
                 batch_first=batch_first,
                 loss_reduction=loss_reduction,
             )
-        else:
+        elif not is_batch_empty(x):
             microbatch_grad_samples = self.compute_microbatch_grad_sample(
                 x,
                 module,
@@ -293,14 +310,9 @@ def run_test_with_reduction(
                 batch_first=batch_first,
                 loss_reduction=loss_reduction,
                 chunk_method=chunk_method,
             )
-
-        opacus_grad_samples = self.compute_opacus_grad_sample(
-            x,
-            module,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-            grad_sample_mode=grad_sample_mode,
-        )
+        else:
+            # We've checked opacus can handle 0-sized batches; microbatching makes no sense here
+            return
 
         if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
             raise ValueError(
diff --git a/opacus/tests/grad_samples/conv1d_test.py b/opacus/tests/grad_samples/conv1d_test.py
index 179e496e..2576f159 100644
--- a/opacus/tests/grad_samples/conv1d_test.py
+++ b/opacus/tests/grad_samples/conv1d_test.py
@@ -25,7 +25,7 @@
 
 class Conv1d_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.sampled_from([1, 3, 32]),
         W=st.integers(6, 10),
         out_channels_mapper=st.sampled_from([expander, shrinker]),
@@ -67,4 +67,6 @@ def test_conv1d(
             dilation=dilation,
             groups=groups,
         )
-        self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
+        self.run_test(
+            x, conv, batch_first=True, atol=10e-5, rtol=10e-4, ew_compatible=N > 0
+        )
diff --git a/opacus/tests/grad_samples/conv2d_test.py b/opacus/tests/grad_samples/conv2d_test.py
index f27ad158..6d9a5b33 100644
--- a/opacus/tests/grad_samples/conv2d_test.py
+++ b/opacus/tests/grad_samples/conv2d_test.py
@@ -29,7 +29,7 @@
 
 class Conv2d_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.sampled_from([1, 3, 32]),
         H=st.integers(11, 17),
         W=st.integers(11, 17),
@@ -73,7 +73,7 @@ def test_conv2d(
             groups=groups,
         )
         is_ew_compatible = (
-            padding != "same"
+            padding != "same" and N > 0
         )  # TODO add support for padding = 'same' with EW
 
         # Test regular GSM
@@ -86,7 +86,7 @@ def test_conv2d(
             ew_compatible=is_ew_compatible,
         )
 
-        if padding != "same":
+        if padding != "same" and N > 0:
             # Test 'convolution as a backward' GSM
             # 'convolution as a backward' doesn't support padding=same
             conv2d_gsm = GradSampleModule.GRAD_SAMPLERS[nn.Conv2d]
diff --git a/opacus/tests/grad_samples/conv3d_test.py b/opacus/tests/grad_samples/conv3d_test.py
index afa01b4b..e50909e2 100644
--- a/opacus/tests/grad_samples/conv3d_test.py
+++ b/opacus/tests/grad_samples/conv3d_test.py
@@ -25,7 +25,7 @@
 
 class Conv3d_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.sampled_from([1, 3, 32]),
         D=st.integers(3, 6),
         H=st.integers(6, 10),
@@ -71,7 +71,7 @@ def test_conv3d(
             groups=groups,
         )
         is_ew_compatible = (
-            dilation == 1 and padding != "same"
+            dilation == 1 and padding != "same" and N > 0
         )  # TODO add support for padding = 'same' with EW
         self.run_test(
             x,
diff --git a/opacus/tests/grad_samples/embedding_test.py b/opacus/tests/grad_samples/embedding_test.py
index bb768450..a4afb37e 100644
--- a/opacus/tests/grad_samples/embedding_test.py
+++ b/opacus/tests/grad_samples/embedding_test.py
@@ -23,7 +23,7 @@
 
 class Embedding_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         T=st.integers(1, 5),
         Q=st.integers(1, 4),
         R=st.integers(1, 2),
@@ -56,4 +56,4 @@ def test_input_across_dims(
 
         emb = nn.Embedding(V, D)
         x = torch.randint(low=0, high=V, size=size)
-        self.run_test(x, emb, batch_first=batch_first)
+        self.run_test(x, emb, batch_first=batch_first, ew_compatible=N > 0)
diff --git a/opacus/tests/grad_samples/group_norm_test.py b/opacus/tests/grad_samples/group_norm_test.py
index 2f4bbaff..e3836b93 100644
--- a/opacus/tests/grad_samples/group_norm_test.py
+++ b/opacus/tests/grad_samples/group_norm_test.py
@@ -30,7 +30,7 @@ class GroupNorm_test(GradSampleHooks_test):
     """
 
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.integers(1, 8),
         H=st.integers(5, 10),
         W=st.integers(4, 8),
@@ -54,4 +54,4 @@ def test_3d_input_groups(
 
         x = torch.randn([N, C, H, W])
         norm = nn.GroupNorm(num_groups=num_groups, num_channels=C, affine=True)
-        self.run_test(x, norm, batch_first=True)
+        self.run_test(x, norm, batch_first=True, ew_compatible=N > 0)
diff --git a/opacus/tests/grad_samples/linear_test.py b/opacus/tests/grad_samples/linear_test.py
index 3b23f3ef..e856e9d3 100644
--- a/opacus/tests/grad_samples/linear_test.py
+++ b/opacus/tests/grad_samples/linear_test.py
@@ -23,7 +23,7 @@
 
 class Linear_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         Z=st.integers(1, 4),
         H=st.integers(1, 3),
         W=st.integers(10, 17),
@@ -57,4 +57,4 @@ def test_input_bias(
         x = torch.randn(x_shape)
         if not batch_first:
             x = x.transpose(0, 1)
-        self.run_test(x, linear, batch_first=batch_first)
+        self.run_test(x, linear, batch_first=batch_first, ew_compatible=N > 0)
diff --git a/opacus/tests/grad_samples/sequence_bias_test.py b/opacus/tests/grad_samples/sequence_bias_test.py
index b61ffc66..ec36d74b 100644
--- a/opacus/tests/grad_samples/sequence_bias_test.py
+++ b/opacus/tests/grad_samples/sequence_bias_test.py
@@ -23,7 +23,7 @@
 
 class SequenceBias_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         T=st.integers(10, 20),
         D=st.integers(4, 8),
         batch_first=st.booleans(),
diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py
index ad2eb6f3..ed5b1f7a 100644
--- a/opacus/tests/privacy_engine_test.py
+++ b/opacus/tests/privacy_engine_test.py
@@ -811,6 +811,20 @@ def _init_model(self):
         return SampleConvNet()
 
 
+class PrivacyEngineConvNetEmptyBatchTest(PrivacyEngineConvNetTest):
+    def setUp(self):
+        super().setUp()
+
+        # This will trigger multiple empty batches with Poisson sampling enabled
+        self.BATCH_SIZE = 1
+
+    def test_checkpoints(self):
+        pass
+
+    def test_noise_level(self):
+        pass
+
+
 class PrivacyEngineConvNetFrozenTest(BasePrivacyEngineTest, unittest.TestCase):
     def _init_data(self):
         ds = FakeData(
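For intuition on why `BATCH_SIZE = 1` above reliably triggers empty batches: Poisson sampling includes each of the `N` examples independently with probability `sample_rate = expected_batch_size / N`, so with an expected batch size of 1 a logical batch is empty with probability `(1 - 1/N)^N`, roughly `e^-1 ≈ 37%`. A quick sketch of the sampling scheme (hypothetical `N`, plain tensors rather than the Opacus sampler):

```python
import torch

N, sample_rate, trials = 100, 1 / 100, 1_000
empty = 0
for _ in range(trials):
    mask = torch.rand(N) < sample_rate  # independent per-example coin flips
    empty += int(mask.sum() == 0)
print(empty / trials)  # ~0.37, i.e. (1 - 1/N) ** N
```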
diff --git a/opacus/utils/batch_memory_manager.py b/opacus/utils/batch_memory_manager.py
index f7e3b65f..8f757e5d 100644
--- a/opacus/utils/batch_memory_manager.py
+++ b/opacus/utils/batch_memory_manager.py
@@ -53,6 +53,11 @@ def __init__(
 
     def __iter__(self):
         for batch_idxs in self.sampler:
+            if len(batch_idxs) == 0:
+                self.optimizer.signal_skip_step(do_skip=False)
+                yield []
+                continue
+
             split_idxs = np.array_split(
                 batch_idxs, math.ceil(len(batch_idxs) / self.max_batch_size)
             )
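Putting the pieces together, a minimal end-to-end sketch of the behavior this patch enables (hypothetical model and sizes; `make_private` is called with the same arguments as in the tests above). With Poisson sampling and a small expected batch size, some logical batches come back empty, and the training loop should survive them:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 5), torch.randn(64, 1))
data_loader = DataLoader(dataset, batch_size=1)  # expected batch size of 1

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    poisson_sampling=True,
)

for x, y in data_loader:     # x sometimes has shape (0, 5)
    out = model(x)
    loss = (y - out).mean()  # NaN on an empty batch, but gradients reach the
    loss.backward()          # weights only through the empty tensor, so they stay zero
    optimizer.step()         # i.e. a noise-only DP-SGD update
    optimizer.zero_grad()
```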