Skip to content

Commit

Permalink
Fix support for MPS in KDPM2AncestralDiscreteScheduler (#6365)
Browse files Browse the repository at this point in the history
Fix support for MPS

MPS doesn't support float64
  • Loading branch information
adi authored Dec 28, 2023
1 parent 4c483de commit 84d7fae
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ def set_timesteps(
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

timesteps = torch.from_numpy(timesteps).to(device)
if str(device).startswith("mps"):
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)

sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
Expand Down

0 comments on commit 84d7fae

Please sign in to comment.