From 39871c80f16046fbdd367839e8c41c70d5287a2d Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Sat, 14 Sep 2024 17:13:43 +0800
Subject: [PATCH] [torch-frontend] update torch-mlir (#450)

* support lowering `aten.upsample_bilinear2d.vec` to `byteir.resize`
---
 .../torch-frontend/third_party/torch-mlir     |  2 +-
 .../Conversion/ConvertTorchToCustomCall.cpp   | 53 +++++++++++++++++++
 .../Conversion/ConvertTorchToStablehloExt.cpp |  2 +-
 .../test_byteir_customcall_ops.py             | 16 +++++-
 .../torch-frontend/python/version.txt         |  2 +-
 5 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/frontends/torch-frontend/third_party/torch-mlir b/frontends/torch-frontend/third_party/torch-mlir
index 43e3118eb..208e5fac6 160000
--- a/frontends/torch-frontend/third_party/torch-mlir
+++ b/frontends/torch-frontend/third_party/torch-mlir
@@ -1 +1 @@
-Subproject commit 43e3118eb91274bc01f8459ad1afed4922d0034f
+Subproject commit 208e5fac64453a1ae77c6dee4cd3e6c41a66ab56
diff --git a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp
index bf12c5072..107cedc26 100644
--- a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp
+++ b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp
@@ -1146,6 +1146,8 @@ class ConvertAtenUpsampleNearest2dVecOp
     Value input = adaptor.getInput();
     RankedTensorType resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op.getResult().getType()));
+
+    // TODO: if the result has a dynamic shape, lower to target_mode=scale.
     if (!resultType.hasStaticShape())
       return failure();
 
@@ -1173,6 +1175,55 @@ class ConvertAtenUpsampleNearest2dVecOp
     return success();
   }
 };
+
+class ConvertAtenUpsampleBilinear2dVecOp
+    : public OpConversionPattern<AtenUpsampleBilinear2dVecOp> {
+public:
+  using OpConversionPattern<AtenUpsampleBilinear2dVecOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenUpsampleBilinear2dVecOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getInput();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getResult().getType()));
+
+    // TODO: if the result has a dynamic shape, lower to target_mode=scale.
+    if (!resultType.hasStaticShape())
+      return failure();
+
+    bool align_corners = false;
+    if (!matchPattern(op.getAlignCorners(),
+                      m_TorchConstantBool(&align_corners))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: align_corners must be a constant bool");
+    }
+    if (!align_corners)
+      return failure();
+
+    std::vector<NamedAttribute> byteir_attrs;
+    byteir_attrs.emplace_back(rewriter.getStringAttr("target_mode"),
+                              rewriter.getStringAttr("size"));
+    byteir_attrs.emplace_back(rewriter.getStringAttr("mode"),
+                              rewriter.getStringAttr("linear"));
+    byteir_attrs.emplace_back(
+        rewriter.getStringAttr("coordinate_transformation_mode"),
+        rewriter.getStringAttr("align_corners"));
+
+    auto attrs = getDefaultAttrs(rewriter);
+    attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
+                       rewriter.getStringAttr(getResizeName()));
+    attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()),
+                       rewriter.getDictionaryAttr(byteir_attrs));
+
+    Value size = rewriter.create<stablehlo::ConstantOp>(
+        op->getLoc(), rewriter.getI64TensorAttr(resultType.getShape()));
+    auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
+        op->getLoc(), TypeRange{resultType}, ValueRange{input, size},
+        ArrayRef<NamedAttribute>(attrs));
+    rewriter.replaceOp(op, customCallOp->getResults());
+    return success();
+  }
+};
 } // namespace
 
 // math ops
@@ -1311,6 +1362,8 @@ class ConvertTorchToCustomCall
     if (validCustomCallOpsSet.contains("byteir.resize")) {
       target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
       patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
+      target.addIllegalOp<AtenUpsampleBilinear2dVecOp>();
+      patterns.add<ConvertAtenUpsampleBilinear2dVecOp>(typeConverter, context);
     }
 
     if (validCustomCallOpsSet.contains("byteir.nll_loss_forward")) {
diff --git a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToStablehloExt.cpp b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToStablehloExt.cpp
index e0e6f3f54..d6dc62705 100644
--- a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToStablehloExt.cpp
+++ b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToStablehloExt.cpp
@@ -70,7 +70,7 @@ struct ConvertAten_IndexPutImplOp
     bool accumulate;
     if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
       return rewriter.notifyMatchFailure(
-          op, "unimplemented: accumulate must be a constant beool");
+          op, "unimplemented: accumulate must be a constant bool");
     }
     if (!accumulate) {
       return op->emitError("accumulate must be true");
diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py b/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py
index 1c7ba6fd3..1b4165056 100644
--- a/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py
+++ b/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py
@@ -268,11 +268,25 @@ def forward(self, x):
         return torch.ops.aten.upsample_nearest2d.vec(x, (11, 25), None)
 
 @pytest.mark.mhlo_tools
-def test_resize():
+def test_resize_nearest():
     inputs = [tu.randn(3, 3, 10, 20)]
     model = torch.jit.trace(UpsampleNearest2dModule(), inputs)
     custom_test_helper(model, inputs, "byteir.resize")
 
+class UpsampleBilinear2dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # FIXME: use torch.nn.functional.interpolate to avoid torch.jit.trace
+        return torch.ops.aten.upsample_bilinear2d.vec(x, (11, 25), True, None)
+
+@pytest.mark.mhlo_tools
+def test_resize_bilinear():
+    inputs = [tu.randn(3, 3, 10, 20)]
+    model = torch.jit.trace(UpsampleBilinear2dModule(), inputs)
+    custom_test_helper(model, inputs, "byteir.resize")
+
 # ==============================================================================
 
 class L2NormalizeModule(torch.nn.Module):
diff --git a/frontends/torch-frontend/torch-frontend/python/version.txt b/frontends/torch-frontend/torch-frontend/python/version.txt
index b966e81a4..3a1f10eae 100644
--- a/frontends/torch-frontend/torch-frontend/python/version.txt
+++ b/frontends/torch-frontend/torch-frontend/python/version.txt
@@ -1 +1 @@
-1.2.4
\ No newline at end of file
+1.2.5
\ No newline at end of file
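
For context, a minimal sketch (not part of the patch) of how the op matched by the new
ConvertAtenUpsampleBilinear2dVecOp pattern typically arises from user code: on recent
PyTorch releases, torch.nn.functional.interpolate with mode="bilinear" traces to the
aten.upsample_bilinear2d.vec overload, which this patch rewrites into a byteir.resize
custom call, but only when the output shape is static and align_corners is a constant
True. The module name and tensor shapes below are illustrative assumptions, not taken
from the repository.

    import torch

    class BilinearResize(torch.nn.Module):
        """Illustrative module, not part of this patch."""
        def forward(self, x):
            # A constant align_corners=True is required; the new pattern
            # returns a match failure for align_corners=False or for a
            # dynamically shaped result.
            return torch.nn.functional.interpolate(
                x, size=(11, 25), mode="bilinear", align_corners=True)

    # Tracing with a fixed-size example input, as the added test does, keeps
    # the output shape static so the pattern's hasStaticShape() check passes.
    traced = torch.jit.trace(BilinearResize(), torch.randn(3, 3, 10, 20))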