
Support bfloat16 for Upsample2D (#9480)
* Support bfloat16 for Upsample2D

* Add test and use is_torch_version

* Resolve comments and add decorator

* Simplify require_torch_version_greater_equal decorator

* Run make style

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
3 people authored Oct 2, 2024
1 parent 33fafe3 commit 61d3764
Showing 3 changed files with 34 additions and 6 deletions.
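In user-facing terms, the change makes the following snippet behave the same on any supported PyTorch version (a minimal sketch mirroring the new test added below; the import path is assumed from the file edited in this commit):

import torch

from diffusers.models.upsampling import Upsample2D  # import path assumed from the diff below

sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16)
upsample = Upsample2D(channels=32, use_conv=False)

with torch.no_grad():
    output = upsample(sample)

# Nearest-neighbor upsampling doubles the spatial dimensions and preserves the input dtype,
# either natively (torch >= 2.1) or via the float32 round trip introduced by this commit.
assert output.shape == (1, 32, 64, 64)
assert output.dtype == torch.bfloat16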
src/diffusers/models/upsampling.py (6 additions, 6 deletions)
@@ -19,6 +19,7 @@
 import torch.nn.functional as F

 from ..utils import deprecate
+from ..utils.import_utils import is_torch_version
 from .normalization import RMSNorm


@@ -151,11 +152,10 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
         if self.use_conv_transpose:
             return self.conv(hidden_states)

-        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
-        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
-        # https://github.com/pytorch/pytorch/issues/86679
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
+        # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(torch.float32)

         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -170,8 +170,8 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
         else:
             hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

-        # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
+        # Cast back to original dtype
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(dtype)

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
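Outside the library, the same version-gated cast can be written as a self-contained sketch; packaging.version is used here as a stand-in for diffusers' is_torch_version helper, and the function name is illustrative only:

import torch
import torch.nn.functional as F
from packaging import version


def upsample_nearest_2x(hidden_states: torch.Tensor) -> torch.Tensor:
    # Before PyTorch 2.1 the 'upsample_nearest2d_out_frame' op has no bfloat16 kernel,
    # so round-trip through float32 only on those older versions.
    needs_cast = hidden_states.dtype == torch.bfloat16 and version.parse(torch.__version__).release < (2, 1)
    dtype = hidden_states.dtype
    if needs_cast:
        hidden_states = hidden_states.to(torch.float32)

    hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")

    if needs_cast:
        hidden_states = hidden_states.to(dtype)
    return hidden_states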
src/diffusers/utils/testing_utils.py (12 additions, 0 deletions)
@@ -252,6 +252,18 @@ def require_torch_2(test_case):
     )


+def require_torch_version_greater_equal(torch_version):
+    """Decorator marking a test that requires torch with a specific version or greater."""
+
+    def decorator(test_case):
+        correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
+        return unittest.skipUnless(
+            correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
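For illustration only, the new helper is applied like any other unittest skip decorator; the test class and method names below are hypothetical:

import unittest

from diffusers.utils.testing_utils import require_torch_version_greater_equal


class Bfloat16UpsampleTests(unittest.TestCase):  # hypothetical test class
    @require_torch_version_greater_equal("2.1")
    def test_only_runs_on_torch_2_1_or_newer(self):
        # Skipped automatically when torch < 2.1 or torch is not installed.
        self.assertTrue(True)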
tests/models/test_layers_utils.py (16 additions, 0 deletions)
@@ -27,6 +27,7 @@
 from diffusers.utils.testing_utils import (
     backend_manual_seed,
     require_torch_accelerator_with_fp64,
+    require_torch_version_greater_equal,
     torch_device,
 )

@@ -120,6 +121,21 @@ def test_upsample_default(self):
         expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

+    @require_torch_version_greater_equal("2.1")
+    def test_upsample_bfloat16(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16)
+        upsample = Upsample2D(channels=32, use_conv=False)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
     def test_upsample_with_conv(self):
         torch.manual_seed(0)
         sample = torch.randn(1, 32, 32, 32)
