Skip to content

Commit

Permalink
test_add_noise_device
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Jan 8, 2025
1 parent f7fb73e commit ac2b820
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,9 @@ def test_add_noise_device(self):
self.assertEqual(sample.shape, scaled_sample.shape)

noise = torch.randn_like(scaled_sample).to(torch_device)
t = scheduler.timesteps[5][None]
# noised = scheduler.add_noise(scaled_sample, noise, t)
# self.assertEqual(noised.shape, scaled_sample.shape)
t = scheduler.timesteps[5].expand(noise.shape[0])
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)

def test_deprecated_kwargs(self):
for scheduler_class in self.scheduler_classes:
Expand Down

0 comments on commit ac2b820

Please sign in to comment.