From b8a2ff79b8cbb63f4f224d9848f666ddce904ed8 Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Tue, 25 Oct 2022 14:12:54 +0100
Subject: [PATCH 1/9] support empty batches in memory manager and optimizer

---
 opacus/data_loader.py                     | 25 ++++++-
 opacus/grad_sample/README.md              |  1 +
 opacus/grad_sample/conv.py                |  8 +++
 opacus/optimizers/optimizer.py            | 18 +++--
 opacus/privacy_engine.py                  | 11 +--
 opacus/tests/batch_memory_manager_test.py | 85 ++++++++++++++++++++---
 opacus/tests/privacy_engine_test.py       | 14 ++++
 opacus/utils/batch_memory_manager.py      |  5 ++
 8 files changed, 142 insertions(+), 25 deletions(-)

diff --git a/opacus/data_loader.py b/opacus/data_loader.py
index 4feaaf94..8b200d49 100644
--- a/opacus/data_loader.py
+++ b/opacus/data_loader.py
@@ -29,7 +29,9 @@


 def wrap_collate_with_empty(
-    collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
+    collate_fn: Optional[_collate_fn_t],
+    sample_empty_shapes: Sequence[torch.Size],
+    dtypes: Sequence[torch.dtype],
 ):
     """
     Wraps given collate function to handle empty batches.
@@ -49,7 +51,10 @@ def collate(batch):
         if len(batch) > 0:
             return collate_fn(batch)
         else:
-            return [torch.zeros(x) for x in sample_empty_shapes]
+            return [
+                torch.zeros(shape, dtype=dtype)
+                for shape, dtype in zip(sample_empty_shapes, dtypes)
+            ]

     return collate

@@ -67,6 +72,19 @@ def shape_safe(x: Any):
     return x.shape if hasattr(x, "shape") else ()


+def dtype_safe(x: Any):
+    """
+    Exception-safe getter for ``dtype`` attribute
+
+    Args:
+        x: any object
+
+    Returns:
+        ``x.shape`` if attribute exists, empty tuple otherwise
+    """
+    return x.dtype if hasattr(x, "dtype") else type(x)
+
+
 class DPDataLoader(DataLoader):
     """
     DataLoader subclass that always does Poisson sampling and supports empty batches
@@ -144,6 +162,7 @@ def __init__(
             generator=generator,
         )
         sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
+        dtypes = [dtype_safe(x) for x in dataset[0]]
         if collate_fn is None:
             collate_fn = default_collate

@@ -156,7 +175,7 @@ def __init__(
             dataset=dataset,
             batch_sampler=batch_sampler,
             num_workers=num_workers,
-            collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes),
+            collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes, dtypes),
             pin_memory=pin_memory,
             timeout=timeout,
             worker_init_fn=worker_init_fn,
diff --git a/opacus/grad_sample/README.md b/opacus/grad_sample/README.md
index 1a78499a..e3eed52a 100644
--- a/opacus/grad_sample/README.md
+++ b/opacus/grad_sample/README.md
@@ -74,6 +74,7 @@ Please note that these are known limitations and we plan to improve Expanded Wei
 | `batch_first=False` | ✅ Supported | Not supported | ✅ Supported |
 | Recurrent networks | ✅ Supported | Not supported | ✅ Supported |
 | Padding `same` in Conv | ✅ Supported | Not supported | ✅ Supported |
+| Empty poisson batches | ✅ Supported | Not supported | ✅ Supported |

 † Note, that performance differences are unstable and can vary a lot depending on the exact model and batch size. Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers.
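A note on the collate change above: when Poisson sampling draws zero examples, the wrapped collate function fabricates a batch of zero-length tensors whose trailing dimensions and dtypes match a real sample, so the forward pass still type-checks. A minimal sketch of the resulting behaviour, assuming a hypothetical dataset of float32 feature vectors of shape (5,) paired with int64 labels (the shapes and dtypes here are illustrative, not taken from the tests):

    import torch

    # What the wrapped collate returns for an empty batch, given
    # sample_empty_shapes/dtypes captured from dataset[0] (values assumed here).
    sample_empty_shapes = [(0, 5), (0,)]
    dtypes = [torch.float32, torch.int64]

    empty_batch = [
        torch.zeros(shape, dtype=dtype)
        for shape, dtype in zip(sample_empty_shapes, dtypes)
    ]
    assert empty_batch[0].shape == (0, 5)        # zero samples, feature dims intact
    assert empty_batch[1].dtype == torch.int64   # labels keep their integer dtype

Without the dtypes argument, torch.zeros would default to float32 and integer label tensors would come out with the wrong dtype on empty batches, which is why the signature grows a second sequence.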
diff --git a/opacus/grad_sample/conv.py b/opacus/grad_sample/conv.py
index 2b8e6299..4014d80e 100644
--- a/opacus/grad_sample/conv.py
+++ b/opacus/grad_sample/conv.py
@@ -41,6 +41,14 @@ def compute_conv_grad_sample(
         backprops: Backpropagations
     """
     n = activations.shape[0]
+    if n == 0:
+        # Empty batch
+        ret = {}
+        ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
+        if layer.bias is not None and layer.bias.requires_grad:
+            ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
+        return ret
+
     # get activations and backprops in shape depending on the Conv layer
     if type(layer) == nn.Conv2d:
         activations = unfold2d(
diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py
index 46a414d9..1afb3ce5 100644
--- a/opacus/optimizers/optimizer.py
+++ b/opacus/optimizers/optimizer.py
@@ -394,13 +394,17 @@ def clip_and_accumulate(self):
         Stores clipped and aggregated gradients into `p.summed_grad```
         """

-        per_param_norms = [
-            g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
-        ]
-        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
-        per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
-            max=1.0
-        )
+        if len(self.grad_samples[0]) == 0:
+            # Empty batch
+            per_sample_clip_factor = torch.zeros((0,))
+        else:
+            per_param_norms = [
+                g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
+            ]
+            per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
+            per_sample_clip_factor = (
+                self.max_grad_norm / (per_sample_norms + 1e-6)
+            ).clamp(max=1.0)

         for p in self.params:
             _check_processed_flag(p.grad_sample)
diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index 4b46d337..e2414335 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -138,11 +138,12 @@ def __init__(self, *, accountant: str = "rdp", secure_mode: bool = False):
             self.secure_rng = csprng.create_random_device_generator("/dev/urandom")
         else:
-            warnings.warn(
-                "Secure RNG turned off. This is perfectly fine for experimentation as it allows "
-                "for much faster training performance, but remember to turn it on and retrain "
-                "one last time before production with ``secure_mode`` turned on."
-            )
+            # warnings.warn(
+            #     "Secure RNG turned off. This is perfectly fine for experimentation as it allows "
+            #     "for much faster training performance, but remember to turn it on and retrain "
+            #     "one last time before production with ``secure_mode`` turned on."
+            # )
+            pass

     def _prepare_optimizer(
         self,
diff --git a/opacus/tests/batch_memory_manager_test.py b/opacus/tests/batch_memory_manager_test.py
index 08bca477..7a57b573 100644
--- a/opacus/tests/batch_memory_manager_test.py
+++ b/opacus/tests/batch_memory_manager_test.py
@@ -37,8 +37,7 @@ class BatchMemoryManagerTest(unittest.TestCase):
     GSM_MODE = "hooks"

     def setUp(self) -> None:
-        self.data_size = 100
-        self.batch_size = 10
+        self.data_size = 256
         self.inps = torch.randn(self.data_size, 5)
         self.tgts = torch.randn(
             self.data_size,
@@ -46,11 +45,11 @@ def setUp(self) -> None:

         self.dataset = TensorDataset(self.inps, self.tgts)

-    def _init_training(self, **data_loader_kwargs):
+    def _init_training(self, batch_size=10, **data_loader_kwargs):
         model = Model()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         data_loader = DataLoader(
-            self.dataset, batch_size=self.batch_size, **data_loader_kwargs
+            self.dataset, batch_size=batch_size, **data_loader_kwargs
         )

         return model, optimizer, data_loader
@@ -58,16 +57,22 @@ def _init_training(self, **data_loader_kwargs):
     @given(
         num_workers=st.integers(0, 4),
         pin_memory=st.booleans(),
+        batch_size=st.sampled_from([8, 16, 64]),
+        max_physical_batch_size=st.sampled_from([4, 8]),
     )
     @settings(deadline=10000)
     def test_basic(
         self,
         num_workers: int,
         pin_memory: bool,
+        batch_size: int,
+        max_physical_batch_size: int,
     ):
+        batches_per_step = max(1, batch_size // max_physical_batch_size)
         model, optimizer, data_loader = self._init_training(
             num_workers=num_workers,
             pin_memory=pin_memory,
+            batch_size=batch_size,
         )

         privacy_engine = PrivacyEngine()
@@ -80,22 +85,19 @@ def test_basic(
             poisson_sampling=False,
             grad_sample_mode=self.GSM_MODE,
         )
-        max_physical_batch_size = 3
         with BatchMemoryManager(
             data_loader=data_loader,
             max_physical_batch_size=max_physical_batch_size,
             optimizer=optimizer,
         ) as new_data_loader:
-            self.assertEqual(
-                len(data_loader), len(data_loader.dataset) // self.batch_size
-            )
+            self.assertEqual(len(data_loader), len(data_loader.dataset) // batch_size)
             self.assertEqual(
                 len(new_data_loader),
                 len(data_loader.dataset) // max_physical_batch_size,
             )
             weights_before = torch.clone(model._module.fc.weight)
             for i, (x, y) in enumerate(new_data_loader):
-                self.assertTrue(x.shape[0] <= 3)
+                self.assertTrue(x.shape[0] <= max_physical_batch_size)

                 out = model(x)
                 loss = (y - out).mean()
@@ -104,7 +106,63 @@ def test_basic(
                 optimizer.step()
                 optimizer.zero_grad()

-                if i % 4 < 3:
+                if (i + 1) % batches_per_step > 0:
                     self.assertTrue(
                         torch.allclose(model._module.fc.weight, weights_before)
                     )
                 else:
                     self.assertFalse(
                         torch.allclose(model._module.fc.weight, weights_before)
                     )
                     weights_before = torch.clone(model._module.fc.weight)

+    @given(
+        num_workers=st.integers(0, 4),
+        pin_memory=st.booleans(),
+    )
+    @settings(deadline=10000)
+    def test_empty_batch(
+        self,
+        num_workers: int,
+        pin_memory: bool,
+    ):
+        batch_size = 2
+        max_physical_batch_size = 10
+        torch.manual_seed(30)
+
+        model, optimizer, data_loader = self._init_training(
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            batch_size=batch_size,
+        )
+
+        privacy_engine = PrivacyEngine()
+        model, optimizer, data_loader = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=data_loader,
+            noise_multiplier=0.0,
+            max_grad_norm=1e5,
+            poisson_sampling=True,
+            grad_sample_mode=self.GSM_MODE,
+        )
+        with BatchMemoryManager(
+            data_loader=data_loader,
+            max_physical_batch_size=max_physical_batch_size,
+            optimizer=optimizer,
+        ) as new_data_loader:
+            weights_before = torch.clone(model._module.fc.weight)
+            for i, (x, y) in enumerate(new_data_loader):
+                self.assertTrue(x.shape[0] <= max_physical_batch_size)
+
+                out = model(x)
+                loss = (y - out).mean()
+
+                loss.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+
+                if len(x) == 0:
                     self.assertTrue(
                         torch.allclose(model._module.fc.weight, weights_before)
                     )
                 else:
                     self.assertFalse(
                         torch.allclose(model._module.fc.weight, weights_before)
                     )
                     weights_before = torch.clone(model._module.fc.weight)
@@ -174,3 +232,10 @@ def test_equivalent_to_one_batch(self):
 )
 class BatchMemoryManagerTestWithExpandedWeights(BatchMemoryManagerTest):
     GSM_MODE = "ew"
+
+    def test_empty_batch(self):
+        pass
+
+
+class BatchMemoryManagerTestWithFunctorch(BatchMemoryManagerTest):
+    GSM_MODE = "functorch"
diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py
index 90af717a..aede7578 100644
--- a/opacus/tests/privacy_engine_test.py
+++ b/opacus/tests/privacy_engine_test.py
@@ -805,6 +805,20 @@ def _init_model(
         return SampleConvNet()


+class PrivacyEngineConvNetEmptyBatchTest(PrivacyEngineConvNetTest):
+    def setUp(self):
+        super().setUp()
+
+        # This will trigger multiple empty batches with poisson sampling enabled
+        self.BATCH_SIZE = 1
+
+    def test_checkpoints(self):
+        pass
+
+    def test_noise_level(self):
+        pass
+
+
 class PrivacyEngineConvNetFrozenTest(BasePrivacyEngineTest, unittest.TestCase):
     def _init_data(self):
         ds = FakeData(
diff --git a/opacus/utils/batch_memory_manager.py b/opacus/utils/batch_memory_manager.py
index f7e3b65f..8f757e5d 100644
--- a/opacus/utils/batch_memory_manager.py
+++ b/opacus/utils/batch_memory_manager.py
@@ -53,6 +53,11 @@ def __init__(

     def __iter__(self):
         for batch_idxs in self.sampler:
+            if len(batch_idxs) == 0:
+                self.optimizer.signal_skip_step(do_skip=False)
+                yield []
+                continue
+
             split_idxs = np.array_split(
                 batch_idxs, math.ceil(len(batch_idxs) / self.max_batch_size)
             )

From 2e1b9d7adae727e7bf541c9271be5fb99d9b8216 Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Tue, 25 Oct 2022 14:15:23 +0100
Subject: [PATCH 2/9] restore warning

---
 opacus/privacy_engine.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index e2414335..c1826559 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -138,11 +138,11 @@ def __init__(self, *, accountant: str = "rdp", secure_mode: bool = False):
             self.secure_rng = csprng.create_random_device_generator("/dev/urandom")
         else:
-            # warnings.warn(
-            #     "Secure RNG turned off. This is perfectly fine for experimentation as it allows "
-            #     "for much faster training performance, but remember to turn it on and retrain "
-            #     "one last time before production with ``secure_mode`` turned on."
-            # )
+            warnings.warn(
+                "Secure RNG turned off. This is perfectly fine for experimentation as it allows "
+                "for much faster training performance, but remember to turn it on and retrain "
+                "one last time before production with ``secure_mode`` turned on."
+            )
             pass

     def _prepare_optimizer(

From df9d1ab3b5750e428dab8f0b77b0d943b2af6aaf Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Tue, 25 Oct 2022 16:49:32 +0100
Subject: [PATCH 3/9] disable functorch test for 1.13+

---
 opacus/tests/batch_memory_manager_test.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/opacus/tests/batch_memory_manager_test.py b/opacus/tests/batch_memory_manager_test.py
index 7a57b573..14bdbea7 100644
--- a/opacus/tests/batch_memory_manager_test.py
+++ b/opacus/tests/batch_memory_manager_test.py
@@ -237,5 +237,8 @@ def test_empty_batch(self):
         pass


+@unittest.skipIf(
+    torch.__version__ >= API_CUTOFF_VERSION, "not supported in this torch version"
+)
 class BatchMemoryManagerTestWithFunctorch(BatchMemoryManagerTest):
     GSM_MODE = "functorch"

From b952c2ae77f97675fc45fcb089805591f37588bd Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Thu, 27 Oct 2022 14:33:51 +0100
Subject: [PATCH 4/9] 0-batch tests

---
 opacus/grad_sample/README.md                    |  2 +-
 opacus/grad_sample/embedding.py                 |  4 ++++
 opacus/tests/grad_samples/common.py             | 24 ++++++++++++-------
 opacus/tests/grad_samples/conv1d_test.py        |  2 +-
 opacus/tests/grad_samples/conv2d_test.py        |  6 ++---
 opacus/tests/grad_samples/conv3d_test.py        |  2 +-
 opacus/tests/grad_samples/embedding_test.py     |  2 +-
 opacus/tests/grad_samples/group_norm_test.py    |  2 +-
 opacus/tests/grad_samples/linear_test.py        |  2 +-
 .../tests/grad_samples/sequence_bias_test.py    |  2 +-
 10 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/opacus/grad_sample/README.md b/opacus/grad_sample/README.md
index e3eed52a..7827680f 100644
--- a/opacus/grad_sample/README.md
+++ b/opacus/grad_sample/README.md
@@ -74,7 +74,7 @@ Please note that these are known limitations and we plan to improve Expanded Wei
 | `batch_first=False` | ✅ Supported | Not supported | ✅ Supported |
 | Recurrent networks | ✅ Supported | Not supported | ✅ Supported |
 | Padding `same` in Conv | ✅ Supported | Not supported | ✅ Supported |
-| Empty poisson batches | ✅ Supported | Not supported | ✅ Supported |
+| Empty poisson batches | ✅ Supported | Not supported | Not supported |

 † Note, that performance differences are unstable and can vary a lot depending on the exact model and batch size. Numbers above are averaged over benchmarks with small models consisting of convolutional and linear layers.
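For intuition on why these 0-batch tests matter at all: with Poisson sampling every example enters a logical batch independently with probability q = expected_batch_size / dataset_size, so an entire batch is empty with probability (1 - q)^dataset_size, roughly exp(-expected_batch_size). A quick back-of-the-envelope check using the values from test_empty_batch above (data_size=256 from setUp, batch_size=2):

    # Probability that a Poisson-sampled logical batch is empty.
    data_size, expected_batch_size = 256, 2
    q = expected_batch_size / data_size
    print((1 - q) ** data_size)  # ~0.134, i.e. roughly e**-2

Small expected batch sizes therefore produce empty batches regularly, which is exactly the regime the new tests (and the PrivacyEngineConvNetEmptyBatchTest with BATCH_SIZE = 1) exercise.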
diff --git a/opacus/grad_sample/embedding.py b/opacus/grad_sample/embedding.py
index 94b86c4b..f0aa575a 100644
--- a/opacus/grad_sample/embedding.py
+++ b/opacus/grad_sample/embedding.py
@@ -39,6 +39,10 @@ def compute_embedding_grad_sample(
         torch.backends.cudnn.deterministic = True

         batch_size = activations.shape[0]
+        if batch_size == 0:
+            ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
+            return ret
+
         index = (
             activations.unsqueeze(-1)
             .expand(*activations.shape, layer.embedding_dim)
diff --git a/opacus/tests/grad_samples/common.py b/opacus/tests/grad_samples/common.py
index f7fa1eac..5a981ba2 100644
--- a/opacus/tests/grad_samples/common.py
+++ b/opacus/tests/grad_samples/common.py
@@ -226,6 +226,9 @@ def run_test(
         except ImportError:
             grad_sample_modes = ["hooks"]

+        if type(x) is not PackedSequence and x.numel() == 0:
+            grad_sample_modes = ["hooks"]
+
         for grad_sample_mode in grad_sample_modes:
             for loss_reduction in ["sum", "mean"]:

@@ -262,6 +265,14 @@ def run_test_with_reduction(
         rtol=10e-5,
         grad_sample_mode="hooks",
     ):
+        opacus_grad_samples = self.compute_opacus_grad_sample(
+            x,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            grad_sample_mode=grad_sample_mode,
+        )
+
         if type(x) is PackedSequence:
             x_unpacked = _unpack_packedsequences(x)
             microbatch_grad_samples = self.compute_microbatch_grad_sample(
@@ -270,18 +281,13 @@ def run_test_with_reduction(
                 batch_first=batch_first,
                 loss_reduction=loss_reduction,
             )
-        else:
+        elif x.numel() > 0:
             microbatch_grad_samples = self.compute_microbatch_grad_sample(
                 x, module, batch_first=batch_first, loss_reduction=loss_reduction
             )
-
-        opacus_grad_samples = self.compute_opacus_grad_sample(
-            x,
-            module,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-            grad_sample_mode=grad_sample_mode,
-        )
+        else:
+            # We've checked opacus can handle 0-sized batch. Microbatch doesn't make sense
+            return

         if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
             raise ValueError(
diff --git a/opacus/tests/grad_samples/conv1d_test.py b/opacus/tests/grad_samples/conv1d_test.py
index 179e496e..9ad2981a 100644
--- a/opacus/tests/grad_samples/conv1d_test.py
+++ b/opacus/tests/grad_samples/conv1d_test.py
@@ -25,7 +25,7 @@

 class Conv1d_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.sampled_from([1, 3, 32]),
         W=st.integers(6, 10),
         out_channels_mapper=st.sampled_from([expander, shrinker]),
diff --git a/opacus/tests/grad_samples/conv2d_test.py b/opacus/tests/grad_samples/conv2d_test.py
index f27ad158..6d9a5b33 100644
--- a/opacus/tests/grad_samples/conv2d_test.py
+++ b/opacus/tests/grad_samples/conv2d_test.py
@@ -29,7 +29,7 @@

 class Conv2d_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.sampled_from([1, 3, 32]),
         H=st.integers(11, 17),
         W=st.integers(11, 17),
@@ -73,7 +73,7 @@ def test_conv2d(
             groups=groups,
         )
         is_ew_compatible = (
-            padding != "same"
+            padding != "same" and N > 0
         )  # TODO add support for padding = 'same' with EW

         # Test regular GSM
@@ -86,7 +86,7 @@ def test_conv2d(
             ew_compatible=is_ew_compatible,
         )

-        if padding != "same":
+        if padding != "same" and N > 0:
             # Test 'convolution as a backward' GSM
             # 'convolution as a backward' doesn't support padding=same
             conv2d_gsm = GradSampleModule.GRAD_SAMPLERS[nn.Conv2d]
diff --git a/opacus/tests/grad_samples/conv3d_test.py b/opacus/tests/grad_samples/conv3d_test.py
index afa01b4b..1647ecef 100644
--- a/opacus/tests/grad_samples/conv3d_test.py
+++ b/opacus/tests/grad_samples/conv3d_test.py
@@ -25,7 +25,7 @@

 class Conv3d_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.sampled_from([1, 3, 32]),
         D=st.integers(3, 6),
         H=st.integers(6, 10),
diff --git a/opacus/tests/grad_samples/embedding_test.py b/opacus/tests/grad_samples/embedding_test.py
index ff02a130..e803f2dd 100644
--- a/opacus/tests/grad_samples/embedding_test.py
+++ b/opacus/tests/grad_samples/embedding_test.py
@@ -23,7 +23,7 @@

 class Embedding_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         T=st.integers(1, 5),
         Q=st.integers(1, 4),
         R=st.integers(1, 2),
diff --git a/opacus/tests/grad_samples/group_norm_test.py b/opacus/tests/grad_samples/group_norm_test.py
index 2f4bbaff..4f32f0e0 100644
--- a/opacus/tests/grad_samples/group_norm_test.py
+++ b/opacus/tests/grad_samples/group_norm_test.py
@@ -30,7 +30,7 @@ class GroupNorm_test(GradSampleHooks_test):
     """

     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         C=st.integers(1, 8),
         H=st.integers(5, 10),
         W=st.integers(4, 8),
diff --git a/opacus/tests/grad_samples/linear_test.py b/opacus/tests/grad_samples/linear_test.py
index 3b23f3ef..82ce7409 100644
--- a/opacus/tests/grad_samples/linear_test.py
+++ b/opacus/tests/grad_samples/linear_test.py
@@ -23,7 +23,7 @@

 class Linear_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         Z=st.integers(1, 4),
         H=st.integers(1, 3),
         W=st.integers(10, 17),
diff --git a/opacus/tests/grad_samples/sequence_bias_test.py b/opacus/tests/grad_samples/sequence_bias_test.py
index b61ffc66..ec36d74b 100644
--- a/opacus/tests/grad_samples/sequence_bias_test.py
+++ b/opacus/tests/grad_samples/sequence_bias_test.py
@@ -23,7 +23,7 @@

 class SequenceBias_test(GradSampleHooks_test):
     @given(
-        N=st.integers(1, 4),
+        N=st.integers(0, 4),
         T=st.integers(10, 20),
         D=st.integers(4, 8),
         batch_first=st.booleans(),

From 5c7fc6ff0e0cac98d0bb44b0e1836962b8014da3 Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Thu, 27 Oct 2022 14:36:33 +0100
Subject: [PATCH 5/9] lint

---
 opacus/privacy_engine.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index c1826559..4b46d337 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -143,7 +143,6 @@ def __init__(self, *, accountant: str = "rdp", secure_mode: bool = False):
                 "for much faster training performance, but remember to turn it on and retrain "
                 "one last time before production with ``secure_mode`` turned on."
             )
-            pass

     def _prepare_optimizer(
         self,

From 64f08adff89645a51b116dd516870b9676ab9728 Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Thu, 27 Oct 2022 15:37:31 +0100
Subject: [PATCH 6/9] EW test fix

---
 opacus/tests/grad_samples/conv1d_test.py     | 4 +++-
 opacus/tests/grad_samples/conv3d_test.py     | 2 +-
 opacus/tests/grad_samples/embedding_test.py  | 2 +-
 opacus/tests/grad_samples/group_norm_test.py | 2 +-
 opacus/tests/grad_samples/linear_test.py     | 2 +-
 5 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/opacus/tests/grad_samples/conv1d_test.py b/opacus/tests/grad_samples/conv1d_test.py
index 9ad2981a..2576f159 100644
--- a/opacus/tests/grad_samples/conv1d_test.py
+++ b/opacus/tests/grad_samples/conv1d_test.py
@@ -67,4 +67,6 @@ def test_conv1d(
             dilation=dilation,
             groups=groups,
         )
-        self.run_test(x, conv, batch_first=True, atol=10e-5, rtol=10e-4)
+        self.run_test(
+            x, conv, batch_first=True, atol=10e-5, rtol=10e-4, ew_compatible=N > 0
+        )
diff --git a/opacus/tests/grad_samples/conv3d_test.py b/opacus/tests/grad_samples/conv3d_test.py
index 1647ecef..e50909e2 100644
--- a/opacus/tests/grad_samples/conv3d_test.py
+++ b/opacus/tests/grad_samples/conv3d_test.py
@@ -71,7 +71,7 @@ def test_conv3d(
             groups=groups,
         )
         is_ew_compatible = (
-            dilation == 1 and padding != "same"
+            dilation == 1 and padding != "same" and N > 0
         )  # TODO add support for padding = 'same' with EW
         self.run_test(
             x,
diff --git a/opacus/tests/grad_samples/embedding_test.py b/opacus/tests/grad_samples/embedding_test.py
index e803f2dd..e0142d36 100644
--- a/opacus/tests/grad_samples/embedding_test.py
+++ b/opacus/tests/grad_samples/embedding_test.py
@@ -56,4 +56,4 @@ def test_input_across_dims(

         emb = nn.Embedding(V, D)
         x = torch.randint(low=0, high=V - 1, size=size)
-        self.run_test(x, emb, batch_first=batch_first)
+        self.run_test(x, emb, batch_first=batch_first, ew_compatible=N > 0)
diff --git a/opacus/tests/grad_samples/group_norm_test.py b/opacus/tests/grad_samples/group_norm_test.py
index 4f32f0e0..e3836b93 100644
--- a/opacus/tests/grad_samples/group_norm_test.py
+++ b/opacus/tests/grad_samples/group_norm_test.py
@@ -54,4 +54,4 @@ def test_3d_input_groups(

         x = torch.randn([N, C, H, W])
         norm = nn.GroupNorm(num_groups=num_groups, num_channels=C, affine=True)
-        self.run_test(x, norm, batch_first=True)
+        self.run_test(x, norm, batch_first=True, ew_compatible=N > 0)
diff --git a/opacus/tests/grad_samples/linear_test.py b/opacus/tests/grad_samples/linear_test.py
index 82ce7409..e856e9d3 100644
--- a/opacus/tests/grad_samples/linear_test.py
+++ b/opacus/tests/grad_samples/linear_test.py
@@ -57,4 +57,4 @@ def test_input_bias(
         x = torch.randn(x_shape)
         if not batch_first:
             x = x.transpose(0, 1)
-        self.run_test(x, linear, batch_first=batch_first)
+        self.run_test(x, linear, batch_first=batch_first, ew_compatible=N > 0)

From df7c355064a94fe894eba2c10fe5290fac480388 Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Thu, 27 Oct 2022 15:59:24 +0100
Subject: [PATCH 7/9] docstring up

---
 opacus/data_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/opacus/data_loader.py b/opacus/data_loader.py
index 8b200d49..884eed88 100644
--- a/opacus/data_loader.py
+++ b/opacus/data_loader.py
@@ -80,7 +80,7 @@ def dtype_safe(x: Any):
         x: any object

     Returns:
-        ``x.shape`` if attribute exists, empty tuple otherwise
+        ``x.dtype`` if attribute exists, type of x otherwise
     """
     return x.dtype if hasattr(x, "dtype") else type(x)

From 338097cf99e0e970e90b654985df60422b071c20 Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Mon, 31 Oct 2022 18:41:20 +0000
Subject: [PATCH 8/9] typing improvement

---
 opacus/data_loader.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/opacus/data_loader.py b/opacus/data_loader.py
index 884eed88..0f45bbfb 100644
--- a/opacus/data_loader.py
+++ b/opacus/data_loader.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple, Type, Union

 import torch
 from opacus.utils.uniform_sampler import (
@@ -29,9 +29,10 @@


 def wrap_collate_with_empty(
+    *,
     collate_fn: Optional[_collate_fn_t],
-    sample_empty_shapes: Sequence[torch.Size],
-    dtypes: Sequence[torch.dtype],
+    sample_empty_shapes: Sequence[Tuple],
+    dtypes: Sequence[Union[torch.dtype, Type]],
 ):
     """
     Wraps given collate function to handle empty batches.
@@ -59,7 +60,7 @@ def collate(batch):
     return collate


-def shape_safe(x: Any):
+def shape_safe(x: Any) -> Tuple:
     """
     Exception-safe getter for ``shape`` attribute

@@ -72,7 +73,7 @@ def shape_safe(x: Any):
     return x.shape if hasattr(x, "shape") else ()


-def dtype_safe(x: Any):
+def dtype_safe(x: Any) -> Union[torch.dtype, Type]:
     """
     Exception-safe getter for ``dtype`` attribute

@@ -161,7 +162,7 @@ def __init__(
             sample_rate=sample_rate,
             generator=generator,
         )
-        sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
+        sample_empty_shapes = [(0, *shape_safe(x)) for x in dataset[0]]
         dtypes = [dtype_safe(x) for x in dataset[0]]
         if collate_fn is None:
             collate_fn = default_collate
@@ -175,7 +176,11 @@ def __init__(
             dataset=dataset,
             batch_sampler=batch_sampler,
             num_workers=num_workers,
-            collate_fn=wrap_collate_with_empty(collate_fn, sample_empty_shapes, dtypes),
+            collate_fn=wrap_collate_with_empty(
+                collate_fn=collate_fn,
+                sample_empty_shapes=sample_empty_shapes,
+                dtypes=dtypes,
+            ),
             pin_memory=pin_memory,
             timeout=timeout,
             worker_init_fn=worker_init_fn,

From 8b119672ddb11058377f6d0b92d0ecea699f80be Mon Sep 17 00:00:00 2001
From: Igor Shilov
Date: Mon, 7 Nov 2022 16:20:23 +0000
Subject: [PATCH 9/9] merge3

---
 opacus/tests/grad_samples/common.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/opacus/tests/grad_samples/common.py b/opacus/tests/grad_samples/common.py
index 850c97ab..25ffbdc9 100644
--- a/opacus/tests/grad_samples/common.py
+++ b/opacus/tests/grad_samples/common.py
@@ -15,7 +15,7 @@

 import io
 import unittest
-from typing import Dict, List, Union
+from typing import Dict, Iterable, List, Tuple, Union

 import numpy as np
 import torch
@@ -36,6 +36,13 @@ def shrinker(x, factor: int = 2):
     return max(1, x // factor)  # if avoid returning 0 for x == 1


+def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
+    if type(batch) is torch.Tensor:
+        return batch.numel() == 0
+    else:
+        return batch[0].numel() == 0
+
+
 class ModelWithLoss(nn.Module):
     """
     To test the gradients of a module, we need to have a loss.
@@ -221,7 +228,7 @@ def compute_opacus_grad_sample(

     def run_test(
         self,
-        x: Union[torch.Tensor, PackedSequence],
+        x: Union[torch.Tensor, PackedSequence, Tuple],
         module: nn.Module,
         batch_first=True,
         atol=10e-6,
@@ -235,9 +242,9 @@ def run_test(
         except ImportError:
             grad_sample_modes = ["hooks"]

-        if (type(x) is not PackedSequence and x.numel() == 0) or type(
-            module
-        ) is nn.EmbeddingBag:
+        if type(module) is nn.EmbeddingBag or (
+            type(x) is not PackedSequence and is_batch_empty(x)
+        ):
             grad_sample_modes = ["hooks"]

         for grad_sample_mode in grad_sample_modes:
@@ -295,7 +302,7 @@ def run_test_with_reduction(
                 batch_first=batch_first,
                 loss_reduction=loss_reduction,
             )
-        elif x.numel() > 0:
+        elif not is_batch_empty(x):
             microbatch_grad_samples = self.compute_microbatch_grad_sample(
                 x,
                 module,
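To close, the is_batch_empty helper introduced in PATCH 9/9 generalizes the earlier x.numel() == 0 check to modules that take a tuple of input tensors. A small usage sketch mirroring the helper's logic (the example tensors are invented for illustration):

    import torch

    def is_batch_empty(batch):
        # Mirrors the helper added in PATCH 9/9: a bare tensor is checked
        # directly, a tuple/list of per-input tensors via its first element.
        if type(batch) is torch.Tensor:
            return batch.numel() == 0
        return batch[0].numel() == 0

    print(is_batch_empty(torch.zeros(0, 5)))                    # True: 0-sample batch
    print(is_batch_empty((torch.zeros(0, 5), torch.zeros(0))))  # True: tuple input
    print(is_batch_empty(torch.randn(4, 5)))                    # False

This keeps run_test and run_test_with_reduction on the hooks-based code path for empty batches, consistent with the README table marking empty Poisson batches as unsupported for Expanded Weights and functorch.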