Skip to content

Commit

Permalink
test_add_noise_device
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Jan 8, 2025
1 parent b73d57c commit ba8f1df
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
32 changes: 16 additions & 16 deletions .github/workflows/pr_tests_mps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,30 @@ jobs:
fail-fast: false
matrix:
config:
- name: Fast Pipelines MPS tests
framework: pytorch_pipelines
runner: macos-13-xlarge
report: torch_mps_pipelines
- name: Fast Models MPS tests
framework: pytorch_models
runner: macos-13-xlarge
report: torch_mps_models
# - name: Fast Pipelines MPS tests
# framework: pytorch_pipelines
# runner: macos-13-xlarge
# report: torch_mps_pipelines
# - name: Fast Models MPS tests
# framework: pytorch_models
# runner: macos-13-xlarge
# report: torch_mps_models
- name: Fast Schedulers MPS tests
framework: pytorch_schedulers
runner: macos-13-xlarge
report: torch_mps_schedulers
- name: Fast Others MPS tests
framework: pytorch_others
runner: macos-13-xlarge
report: torch_mps_others
# - name: Fast Others MPS tests
# framework: pytorch_others
# runner: macos-13-xlarge
# report: torch_mps_others
# - name: Fast Single File MPS tests
# framework: pytorch_single_file
# runner: macos-13-xlarge
# report: torch_mps_single_file
- name: Fast Lora MPS tests
framework: pytorch_lora
runner: macos-13-xlarge
report: torch_mps_lora
# - name: Fast Lora MPS tests
# framework: pytorch_lora
# runner: macos-13-xlarge
# report: torch_mps_lora
# - name: Fast Quantization MPS tests
# framework: pytorch_quantization
# runner: macos-13-xlarge
Expand Down
30 changes: 15 additions & 15 deletions tests/schedulers/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,21 +711,21 @@ def test_add_noise_device(self):
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.default_num_inference_steps)

sample = self.dummy_sample.to(torch_device)
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
elif scheduler_class == EDMEulerScheduler:
scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
else:
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)

noise = torch.randn_like(scaled_sample).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
# sample = self.dummy_sample.to(torch_device)
# if scheduler_class == CMStochasticIterativeScheduler:
# # Get valid timestep based on sigma_max, which should always be in timestep schedule.
# scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
# scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
# elif scheduler_class == EDMEulerScheduler:
# scaled_sample = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
# else:
# scaled_sample = scheduler.scale_model_input(sample, 0.0)
# self.assertEqual(sample.shape, scaled_sample.shape)

# noise = torch.randn_like(scaled_sample).to(torch_device)
# t = scheduler.timesteps[5][None]
# noised = scheduler.add_noise(scaled_sample, noise, t)
# self.assertEqual(noised.shape, scaled_sample.shape)

def test_deprecated_kwargs(self):
for scheduler_class in self.scheduler_classes:
Expand Down

0 comments on commit ba8f1df

Please sign in to comment.