diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index a34454281a96d..159286bbf69ce 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -722,8 +722,8 @@ def test_add_noise_device(self): scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) - # noise = torch.randn_like(scaled_sample).to(torch_device) - # t = scheduler.timesteps[5][None] + noise = torch.randn_like(scaled_sample).to(torch_device) + t = scheduler.timesteps[5][None] # noised = scheduler.add_noise(scaled_sample, noise, t) # self.assertEqual(noised.shape, scaled_sample.shape)