[torch-frontend] update torch-mlir (#450)
* support lowering `aten.upsample_bilinear2d.vec` to `byteir.resize`
qingyunqu authored Sep 14, 2024
1 parent 0f09e5b commit 39871c8
Showing 5 changed files with 71 additions and 4 deletions.
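For reference, the newly lowered op can be exercised directly from Python; a minimal sketch using the same shapes as the new test below (the assert is illustrative, not part of this commit):

import torch

x = torch.randn(3, 3, 10, 20)
# aten.upsample_bilinear2d.vec(input, output_size, align_corners, scale_factors)
y = torch.ops.aten.upsample_bilinear2d.vec(x, (11, 25), True, None)
assert tuple(y.shape) == (3, 3, 11, 25)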
@@ -1146,6 +1146,8 @@ class ConvertAtenUpsampleNearest2dVecOp
    Value input = adaptor.getInput();
    RankedTensorType resultType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getResult().getType()));

    // TODO: if the result has a dynamic shape, lower to target_mode=scale
    if (!resultType.hasStaticShape())
      return failure();

@@ -1173,6 +1175,55 @@ class ConvertAtenUpsampleNearest2dVecOp
    return success();
  }
};

class ConvertAtenUpsampleBilinear2dVecOp
    : public OpConversionPattern<AtenUpsampleBilinear2dVecOp> {
public:
  using OpConversionPattern<AtenUpsampleBilinear2dVecOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenUpsampleBilinear2dVecOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getInput();
    RankedTensorType resultType = cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getResult().getType()));

    // TODO: if the result has a dynamic shape, lower to target_mode=scale
    if (!resultType.hasStaticShape())
      return failure();

    bool align_corners = false;
    if (!matchPattern(op.getAlignCorners(),
                      m_TorchConstantBool(&align_corners))) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: align_corners must be a constant bool");
    }
    if (!align_corners)
      return failure();

    std::vector<NamedAttribute> byteir_attrs;
    byteir_attrs.emplace_back(rewriter.getStringAttr("target_mode"),
                              rewriter.getStringAttr("size"));
    byteir_attrs.emplace_back(rewriter.getStringAttr("mode"),
                              rewriter.getStringAttr("linear"));
    byteir_attrs.emplace_back(
        rewriter.getStringAttr("coordinate_transformation_mode"),
        rewriter.getStringAttr("align_corners"));

    auto attrs = getDefaultAttrs(rewriter);
    attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
                       rewriter.getStringAttr(getResizeName()));
    attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()),
                       rewriter.getDictionaryAttr(byteir_attrs));

    Value size = rewriter.create<stablehlo::ConstantOp>(
        op->getLoc(), rewriter.getI64TensorAttr(resultType.getShape()));
    auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
        op->getLoc(), TypeRange{resultType}, ValueRange{input, size},
        ArrayRef<NamedAttribute>(attrs));
    rewriter.replaceOp(op, customCallOp->getResults());
    return success();
  }
};
} // namespace

// math ops
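The pattern above only fires when align_corners is the constant true, since the emitted byteir.resize hardcodes coordinate_transformation_mode="align_corners". A quick sanity sketch of the semantics the lowering must preserve, written against the public PyTorch API (shapes are illustrative, not from this diff):

import torch
import torch.nn.functional as F

x = torch.randn(3, 3, 10, 20)
ref = torch.ops.aten.upsample_bilinear2d.vec(x, (11, 25), True, None)
# F.interpolate with mode="bilinear" dispatches to the same aten kernel
out = F.interpolate(x, size=(11, 25), mode="bilinear", align_corners=True)
assert torch.allclose(ref, out)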
@@ -1311,6 +1362,8 @@ class ConvertTorchToCustomCall
    if (validCustomCallOpsSet.contains("byteir.resize")) {
      target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
      patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
      target.addIllegalOp<AtenUpsampleBilinear2dVecOp>();
      patterns.add<ConvertAtenUpsampleBilinear2dVecOp>(typeConverter, context);
    }

    if (validCustomCallOpsSet.contains("byteir.nll_loss_forward")) {
@@ -70,7 +70,7 @@ struct ConvertAten_IndexPutImplOp
    bool accumulate;
    if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
      return rewriter.notifyMatchFailure(
-          op, "unimplemented: accumulate must be a constant beool");
+          op, "unimplemented: accumulate must be a constant bool");
    }
    if (!accumulate) {
      return op->emitError("accumulate must be true");
@@ -268,11 +268,25 @@ def forward(self, x):
        return torch.ops.aten.upsample_nearest2d.vec(x, (11, 25), None)

@pytest.mark.mhlo_tools
-def test_resize():
+def test_resize_nearest():
    inputs = [tu.randn(3, 3, 10, 20)]
    model = torch.jit.trace(UpsampleNearest2dModule(), inputs)
    custom_test_helper(model, inputs, "byteir.resize")

class UpsampleBilinear2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # FIXME: use torch.nn.functional.interpolate to avoid torch.jit.trace
        return torch.ops.aten.upsample_bilinear2d.vec(x, (11, 25), True, None)

@pytest.mark.mhlo_tools
def test_resize_bilinear():
    inputs = [tu.randn(3, 3, 10, 20)]
    model = torch.jit.trace(UpsampleBilinear2dModule(), inputs)
    custom_test_helper(model, inputs, "byteir.resize")

# ==============================================================================

class L2NormalizeModule(torch.nn.Module):
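As the FIXME in the test above notes, the module could be written against the public API instead of the raw aten op; a hypothetical sketch (the module name is invented here, and this is not part of the commit):

import torch
import torch.nn.functional as F

class UpsampleBilinear2dFunctionalModule(torch.nn.Module):
    def forward(self, x):
        # equivalent to torch.ops.aten.upsample_bilinear2d.vec(x, (11, 25), True, None)
        return F.interpolate(x, size=(11, 25), mode="bilinear", align_corners=True)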
frontends/torch-frontend/torch-frontend/python/version.txt (1 addition, 1 deletion)
@@ -1 +1 @@
-1.2.4
+1.2.5
