Change fp16 error to warning (#764)
* Swap fp16 error to warning

Also remove the associated test

* Formatting

* warn -> warning

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* make style

Co-authored-by: Patrick von Platen <[email protected]>
apolinario and patrickvonplaten authored Oct 7, 2022
1 parent d3f1a4c commit fdfa7c8
Showing 2 changed files with 6 additions and 15 deletions.
10 changes: 6 additions & 4 deletions src/diffusers/pipeline_utils.py
@@ -169,10 +169,12 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
-                    raise ValueError(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
-                        "due to the lack of support for `float16` operations on those devices in PyTorch. "
-                        "Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
+                    logger.warning(
+                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
+                        " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
+                        " sure to use a `cuda` device to run the pipeline in inference, due to the lack of support for"
+                        " `float16` operations on those devices in PyTorch. Please remove the"
+                        " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
                     )
                 module.to(torch_device)
         return self
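For context, a minimal sketch of the change from the caller's side (the checkpoint name is purely illustrative, and this assumes a diffusers build that includes this commit):

import torch
from diffusers import DDIMPipeline

# Load a pipeline in half precision; "google/ddpm-cifar10-32" is used
# here only as an example checkpoint.
pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", torch_dtype=torch.float16)

# Before this commit, moving an fp16 pipeline to `cpu` or `mps` raised a
# ValueError. With this change the call succeeds and only logs the warning
# shown above; actually running inference in float16 on those devices is
# still expected to fail.
pipe = pipe.to("cpu")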
11 changes: 0 additions & 11 deletions tests/test_pipelines.py
@@ -247,17 +247,6 @@ def to(self, device):

         return extract

-    def test_pipeline_fp16_cpu_error(self):
-        model = self.dummy_uncond_unet
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        pipe = DDIMPipeline(model.half(), scheduler)
-
-        if str(torch_device) in ["cpu", "mps"]:
-            self.assertRaises(ValueError, pipe.to, torch_device)
-        else:
-            # moving the pipeline to GPU should work
-            pipe.to(torch_device)
-
     def test_ddim(self):
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
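The commit drops the test entirely rather than adapting it. If equivalent coverage were wanted, a warning-based version might look like the sketch below. This is an assumption, not part of the commit: it reuses the removed test's fixtures and relies on `unittest.TestCase.assertLogs` capturing records that propagate up to the `diffusers` logger.

def test_pipeline_fp16_cpu_warning(self):
    model = self.dummy_uncond_unet
    scheduler = DDPMScheduler(num_train_timesteps=10)
    pipe = DDIMPipeline(model.half(), scheduler)

    if str(torch_device) in ["cpu", "mps"]:
        # the move should now succeed and emit a warning instead of raising
        with self.assertLogs("diffusers", level="WARNING") as cm:
            pipe.to(torch_device)
        assert "torch_dtype=torch.float16" in cm.output[0]
    else:
        # moving the pipeline to GPU should still work
        pipe.to(torch_device)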
