upsample_bilinear: fix output data-type. (#7111)
ysiraichi authored May 29, 2024
1 parent c7bbdfb commit 468a5c9
Showing 2 changed files with 54 additions and 2 deletions.
48 changes: 48 additions & 0 deletions test/test_operations.py
@@ -2069,6 +2069,54 @@ def test(f, xshape, ishapes):
    for xshape, i0shape, i1shape in cases[f2]:
      test(f2, xshape, (i0shape, i1shape))

  def test_upsample_bilinear_double(self):
    # Originally, the upsample_bilinear implementation (in resize_ops.cpp)
    # was copied from TF. The computation was intentionally done in F32 and
    # not cast back[1]. However, that was not reflected in the returned
    # tensor. Basically, what would happen is:
    #
    #   1. A tensor of a data-type other than F32 is created:
    #      > a = torch.rand(..., dtype=torch.double)
    #
    #   2. upsample_bilinear is called on it:
    #      > r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)
    #
    #   3. The result's data-type shows as torch.float64, but its inner
    #      HLO representation is actually F32.
    #
    #      - This rarely surfaced as an error, since we do data-type
    #        promotion at the HLO level.
    #
    #      - However, when this result is used as the argument of a new HLO
    #        function, XLA expects an F64 tensor, since its torch.Tensor
    #        data-type "is" torch.float64. Since the actual HLO data-type
    #        is F32, XLA raises an error.
    #
    # See more details at [2].
    #
    # [1]: https://github.com/tensorflow/tensorflow/commit/f8b35e00afe09c8606bcb0441a51be8bd38168d2
    # [2]: https://github.com/pytorch/xla/issues/7095

    def foo(x, is_xla=False):
      # Compute upsample_bilinear.
      r = torch.nn.functional.upsample_bilinear(x, scale_factor=2)

      if is_xla:
        # Mark the end of the HLO graph.
        xm.mark_step()

      # Start a new HLO graph using the upsample_bilinear result as
      # one of its arguments.
      return r + 5

    inp = torch.rand(1, 3, 10, 10, dtype=torch.double)
    Xinp = inp.to(xm.xla_device())

    out = foo(inp)
    Xout = foo(Xinp, is_xla=True)

    self.assertEqual(out, Xout.cpu())


class MNISTComparator(nn.Module):

8 changes: 6 additions & 2 deletions torch_xla/csrc/resize_ops.cpp
@@ -21,6 +21,10 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
                       bool is_kernel_bilinear) {
  // Code copied from
  // https://github.com/tensorflow/tensorflow/blob/e51d6ab5730092775d516b18fa4ee85d49602cd8/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc#L477-L672
  //
  // Changes:
  //   - Remove F32 data-type conversion when is_kernel_bilinear.
  //     See: https://github.com/pytorch/xla/issues/7095
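  //
  //     With this change, floating-point inputs are no longer cast to F32:
  //     bilinear resize computes directly in the input's original type
  //     (e.g. F64), so the HLO result type matches the data-type of the
  //     returned torch.Tensor. Integral inputs are still computed in F32
  //     and cast back at the end.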

  // We implement bilinear interpolation and nearest neighbor with a Gather op.
  // For each output pixel, we gather the necessary slices of the input.
@@ -53,7 +57,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
       << "input and output must have the same element type";

   xla::PrimitiveType original_input_type = input_type;
-  if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
+  if (xla::primitive_util::IsIntegralType(input_type)) {
     input = xla::ConvertElementType(input, xla::F32);
     input_type = xla::F32;
   }
@@ -210,7 +214,7 @@ xla::XlaOp BuildResize(xla::XlaOp input, const xla::Shape& output_shape,
   absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
   input = xla::Transpose(input, perm);

-  if (!is_kernel_bilinear && original_input_type != input_type) {
+  if (original_input_type != input_type) {
     input = xla::ConvertElementType(input, original_input_type);
   }
   return input;
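
For reference, a minimal standalone reproduction of the failure mode fixed
here could look like the following. This is a sketch assuming a PyTorch/XLA
install on a pre-fix build; it mirrors the repro in the new test above, and
the exact error message raised by XLA is not reproduced here.

# Repro sketch for https://github.com/pytorch/xla/issues/7095.
import torch
import torch_xla.core.xla_model as xm

a = torch.rand(1, 3, 10, 10, dtype=torch.double).to(xm.xla_device())
r = torch.nn.functional.upsample_bilinear(a, scale_factor=2)

# Cut the HLO graph here, materializing `r`. Before this fix, the buffer
# behind `r` was F32 even though r.dtype reads torch.float64.
xm.mark_step()

# The next graph declares `r` as an F64 parameter (from its torch dtype);
# XLA would reject the underlying F32 buffer with a data-type mismatch.
print((r + 5).cpu())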
