Skip to content

Commit

Permalink
test_add_noise_device
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Jan 8, 2025
1 parent 93dcc72 commit adb4238
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,10 @@ def test_add_noise_device(self):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)

noise = torch.randn_like(scaled_sample).to(torch_device)
# t = scheduler.timesteps[5].expand(noise.shape[0])
# noised = scheduler.add_noise(scaled_sample, noise, t)
# self.assertEqual(noised.shape, scaled_sample.shape)
noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)

def test_deprecated_kwargs(self):
for scheduler_class in self.scheduler_classes:
Expand Down

0 comments on commit adb4238

Please sign in to comment.