is_mps is_npu
hlky committed Jan 10, 2025
1 parent 3454384 · commit 8fe6408
Showing 21 changed files with 83 additions and 62 deletions.
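
Every hunk below applies the same refactor: the single is_mps_or_npu flag is split into separate is_mps and is_npu flags, and the timestep dtype selection tests both, so mps and npu devices continue to get 32-bit dtypes while other devices get 64-bit ones. As a reading aid, here is a minimal sketch of the resulting pattern, written as a hypothetical standalone helper named as_timestep_tensor (an assumption for illustration; in the repository the logic stays inlined in each forward/__call__):

import torch

def as_timestep_tensor(timestep, sample):
    # Illustrative helper only, not part of this commit; the commit keeps the
    # logic inlined at every call site shown in the diff below.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # mps (and, with this change, npu) gets 32-bit dtypes; other devices get 64-bit ones.
        is_mps = sample.device.type == "mps"
        is_npu = sample.device.type == "npu"
        if isinstance(timestep, float):
            dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        else:
            dtype = torch.int32 if (is_mps or is_npu) else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)
    return timesteps

# Example: an int timestep on CPU becomes tensor([999]) with dtype torch.int64
# (torch.int32 if sample lived on an mps or npu device).
t = as_timestep_tensor(999, torch.zeros(1))
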
7 changes: 4 additions & 3 deletions examples/community/fresco_v2v.py
@@ -403,11 +403,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions examples/community/matryoshka.py
@@ -2805,11 +2805,12 @@ def get_time_embed(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
@@ -1030,11 +1030,12 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
+    is_mps = latent_model_input.device.type == "mps"
+    is_npu = latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -257,11 +257,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/controlnets/controlnet.py
@@ -739,11 +739,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -670,11 +670,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
5 changes: 3 additions & 2 deletions src/diffusers/models/controlnets/controlnet_union.py
@@ -681,10 +681,11 @@ def forward(
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
     is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/controlnets/controlnet_xs.py
@@ -1087,11 +1087,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -914,11 +914,12 @@ def get_time_embed(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_3d_condition.py
@@ -623,11 +623,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -574,11 +574,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timesteps, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -2113,11 +2113,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -401,11 +401,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -767,11 +767,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
@@ -1162,11 +1162,12 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
+    is_mps = sample.device.type == "mps"
+    is_npu = sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/dit/pipeline_dit.py
@@ -186,11 +186,12 @@ def __call__(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
+    is_mps = latent_model_input.device.type == "mps"
+    is_npu = latent_model_input.device.type == "npu"
     if isinstance(timesteps, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(latent_model_input.device)
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/latte/pipeline_latte.py
@@ -796,11 +796,12 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
+    is_mps = latent_model_input.device.type == "mps"
+    is_npu = latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -805,11 +805,12 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
+    is_mps = latent_model_input.device.type == "mps"
+    is_npu = latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     current_timestep = torch.tensor(
         [current_timestep],
         dtype=dtype,
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -806,11 +806,12 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
+    is_mps = latent_model_input.device.type == "mps"
+    is_npu = latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps_or_npu else torch.float64
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
     else:
-        dtype = torch.int32 if is_mps_or_npu else torch.int64
+        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)