diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index 159286bbf69ce..0abc43a2c5192 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -723,9 +723,9 @@ def test_add_noise_device(self): self.assertEqual(sample.shape, scaled_sample.shape) noise = torch.randn_like(scaled_sample).to(torch_device) - t = scheduler.timesteps[5][None] - # noised = scheduler.add_noise(scaled_sample, noise, t) - # self.assertEqual(noised.shape, scaled_sample.shape) + t = scheduler.timesteps[5].expand(noise.shape[0]) + noised = scheduler.add_noise(scaled_sample, noise, t) + self.assertEqual(noised.shape, scaled_sample.shape) def test_deprecated_kwargs(self): for scheduler_class in self.scheduler_classes: