diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py
index 2784e2f238f6..d6c2683f1d86 100644
--- a/examples/community/fresco_v2v.py
+++ b/examples/community/fresco_v2v.py
@@ -404,10 +404,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index f80b29456c60..1d7a367ecc60 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -2806,10 +2806,11 @@ def get_time_embed(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
index d7f882974a22..4065a854c22d 100644
--- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
+++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
@@ -1031,10 +1031,11 @@ def __call__(
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
index 6b1826a1c92d..7853695f0566 100644
--- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
+++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
@@ -258,10 +258,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
index bd00f6dd1906..1453aaf4362c 100644
--- a/src/diffusers/models/controlnets/controlnet.py
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -740,10 +740,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
index fd599c10b2d7..807cbd339ef9 100644
--- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -671,10 +671,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py
index fc80da76235b..1bf176101c61 100644
--- a/src/diffusers/models/controlnets/controlnet_union.py
+++ b/src/diffusers/models/controlnets/controlnet_union.py
@@ -681,10 +681,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
index 11ad676ec92b..8a8901d82d90 100644
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -1088,10 +1088,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index e488f5897ebc..2b896f89e484 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -915,10 +915,11 @@ def get_time_embed(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index 3081fdc4700c..56739ac24c11 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -624,10 +624,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 6ab3a577b892..d5d98c256357 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -575,10 +575,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index ddc3e41c340d..1c07a0760f62 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -2114,10 +2114,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index 308b9e01c587..172c1e6bbb05 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -402,10 +402,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index 63d3957ae17d..a33e26568772 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -768,10 +768,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 0fd8875a88a1..4d9e50e3a2b4 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1163,10 +1163,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index cf5ebbce2ba8..8aee0fadaf69 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -187,10 +187,11 @@ def __call__(
                 # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                 # This would be a good case for the `match` statement (Python 3.10+)
                 is_mps = latent_model_input.device.type == "mps"
+                is_npu = latent_model_input.device.type == "npu"
                 if isinstance(timesteps, float):
-                    dtype = torch.float32 if is_mps else torch.float64
+                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                 else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                 timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
             elif len(timesteps.shape) == 0:
                 timesteps = timesteps[None].to(latent_model_input.device)
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index 852a2b7b795e..792a0575213e 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -797,10 +797,11 @@ def __call__(
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 52bb6546031d..5b37e9a503a8 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -806,10 +806,11 @@ def __call__(
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor(
                         [current_timestep],
                         dtype=dtype,
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index d927a7961a16..affda7e18add 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -807,10 +807,11 @@ def __call__(
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 46a7337051ef..b550a442fe15 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -907,10 +907,11 @@ def __call__(
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 356ba3a29af3..7f10ee89ee04 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -822,10 +822,11 @@ def __call__(
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
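Note (not part of the diff): every hunk above applies the same device-aware dtype selection when a scalar timestep has to be converted to a tensor. The standalone sketch below restates that pattern for reference only; it assumes, as the change implies, that the "mps" and "npu" backends lack 64-bit float/int tensor support, and the helper name timestep_to_tensor is illustrative rather than a diffusers API.

import torch

def timestep_to_tensor(timestep, device: torch.device) -> torch.Tensor:
    # Illustrative helper: mirrors the dtype selection applied in each hunk above.
    if torch.is_tensor(timestep):
        # 0-dim tensors become a 1-element batch, matching the `elif` branch in the diff.
        return timestep[None].to(device) if timestep.ndim == 0 else timestep.to(device)
    needs_32bit = device.type in ("mps", "npu")  # assumption: these backends lack 64-bit dtypes
    if isinstance(timestep, float):
        dtype = torch.float32 if needs_32bit else torch.float64
    else:
        dtype = torch.int32 if needs_32bit else torch.int64
    return torch.tensor([timestep], dtype=dtype, device=device)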