From 67eee5e4916dcf04aae55b10ea2fa5f4799dde81 Mon Sep 17 00:00:00 2001 From: Adrian Punga Date: Wed, 27 Dec 2023 22:11:48 +0200 Subject: [PATCH] Fix support for MPS MPS doesn't support float64 --- .../schedulers/scheduling_k_dpm_2_ancestral_discrete.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index dbf0984ed503..523b1f4f3b96 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -277,7 +277,11 @@ def set_timesteps( self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) - timesteps = torch.from_numpy(timesteps).to(device) + if str(device).startswith("mps"): + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + timesteps = torch.from_numpy(timesteps).to(device) + sigmas_interpol = sigmas_interpol.cpu() log_sigmas = self.log_sigmas.cpu() timesteps_interpol = np.array(