Fixed Opacus's Runtime error with an empty batch (issue 612) (#631)
Summary:

In the case of an empty batch, the `clip_and_accumulate` function sets the `per_sample_clip_factor` variable to a tensor of size 0. However, the device was not specified, so the tensor was created on the CPU and later operations mixing it with gradients on another device threw a runtime error. This change adds the `device` argument.

Differential Revision: D53733081
EnayatUllah authored and facebook-github-bot committed Feb 20, 2024
1 parent a7c2853 commit 52d9797
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion opacus/optimizers/optimizer.py
```diff
@@ -396,7 +396,9 @@ def clip_and_accumulate(self):

         if len(self.grad_samples[0]) == 0:
             # Empty batch
-            per_sample_clip_factor = torch.zeros((0,))
+            per_sample_clip_factor = torch.zeros(
+                (0,), device=self.grad_samples[0].device
+            )
         else:
             per_param_norms = [
                 g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
```
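A minimal sketch of the patched branch, runnable on CPU. The helper name `clip_factor_for` and the `max_grad_norm` default are illustrative, not Opacus's actual API; the point is that the empty-batch path now places the zero-length factor on the same device as the per-sample gradients, so downstream ops never mix devices:

```python
import torch


def clip_factor_for(grad_samples, max_grad_norm=1.0):
    """Illustrative stand-in for the patched branch of clip_and_accumulate.

    An empty batch yields a size-0 clip-factor tensor on the gradients'
    own device (the fix); otherwise the per-sample clip factors are
    computed from the per-sample gradient norms as in the diff above.
    """
    if len(grad_samples[0]) == 0:
        # Empty batch: zero-length factor, created on the gradients' device
        # so later operations with the gradients stay on one device.
        return torch.zeros((0,), device=grad_samples[0].device)
    per_param_norms = [
        g.reshape(len(g), -1).norm(2, dim=-1) for g in grad_samples
    ]
    per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
    return (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)


grads = [torch.zeros((0, 4))]  # an empty batch of per-sample gradients
factor = clip_factor_for(grads)
print(tuple(factor.shape), factor.device)  # (0,) on the gradients' device
```

Before the fix, the empty-batch tensor defaulted to the CPU, which only surfaces as an error when training on an accelerator; specifying `device=` keeps both branches consistent.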
