Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch-frontend] update torch-mlir #480

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -709,110 +709,7 @@ class ConvertDynamicMaskStitchCustomOp : public OpConversionPattern<CustomOp> {
};
} // namespace

// AtenNllLossForwardOp
// output, weight = torch.aten.nll_loss_forward(input, target)
namespace {
class ConvertAtenNllLossForwardOp
: public OpConversionPattern<AtenNllLossForwardOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
Value target = adaptor.getTarget();
Value weight = adaptor.getWeight();

int64_t reduction;
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction)))
return rewriter.notifyMatchFailure(op, "reduction must be constant");

int64_t ignoreIndex;
if (!matchPattern(op.getIgnoreIndex(), m_TorchConstantInt(&ignoreIndex)))
return rewriter.notifyMatchFailure(op, "ignore_index must be constant");

if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented, the weight operand is not incorporated.");

SmallVector<Value> bufferArgs({input, target});
SmallVector<Type> resultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
resultTypes))) {
return op.emitError("could not convert output types");
}

std::vector<NamedAttribute> byteir_attrs;
byteir_attrs.emplace_back(rewriter.getStringAttr("reduction"),
rewriter.getI64IntegerAttr(reduction));
byteir_attrs.emplace_back(rewriter.getStringAttr("ignore_index"),
rewriter.getI64IntegerAttr(ignoreIndex));

auto attrs = getDefaultAttrs(rewriter);
attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
rewriter.getStringAttr(getNllLossForwardName()));
attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()),
rewriter.getDictionaryAttr(byteir_attrs));

auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
op->getLoc(), resultTypes, bufferArgs, ArrayRef<NamedAttribute>{attrs});
rewriter.replaceOp(op, customCallOp->getResults());
return success();
}
};

// AtenNllLossBackwardOp
// result = nll_loss_backward(grad_output, input, target)
class ConvertAtenNllLossBackwardOp
: public OpConversionPattern<AtenNllLossBackwardOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value grad_out = adaptor.getGradOutput();
Value input = adaptor.getSelf();
Value target = adaptor.getTarget();
Value weight = adaptor.getWeight();
Value total_weight = adaptor.getTotalWeight();

int64_t reduction;
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction)))
return rewriter.notifyMatchFailure(op, "reduction must be constant");

int64_t ignoreIndex;
if (!matchPattern(op.getIgnoreIndex(), m_TorchConstantInt(&ignoreIndex)))
return rewriter.notifyMatchFailure(op, "ignore_index must be constant");

if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented, the weight operand is not incorporated.");

SmallVector<Value> bufferArgs({grad_out, input, target, total_weight});
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));

std::vector<NamedAttribute> byteir_attrs;
byteir_attrs.emplace_back(rewriter.getStringAttr("reduction"),
rewriter.getI64IntegerAttr(reduction));
byteir_attrs.emplace_back(rewriter.getStringAttr("ignore_index"),
rewriter.getI64IntegerAttr(ignoreIndex));

auto attrs = getDefaultAttrs(rewriter);
attrs.emplace_back(rewriter.getStringAttr("call_target_name"),
rewriter.getStringAttr(getNllLossBackwardName()));
attrs.emplace_back(rewriter.getStringAttr(getCustomCallAttrName()),
rewriter.getDictionaryAttr(byteir_attrs));

auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
op->getLoc(), resultType, bufferArgs, ArrayRef<NamedAttribute>{attrs});
rewriter.replaceOp(op, customCallOp->getResults());
return success();
}
};

// torch.operator "byteir.l2_norm"
// operands: input, dims, eps
class ConvertByteIRL2NormOp : public OpConversionPattern<OperatorOp> {
Expand All @@ -838,7 +735,7 @@ class ConvertByteIRL2NormOp : public OpConversionPattern<OperatorOp> {
if (!matchPattern(op.getOperand(1), m_TorchListOfConstantInts(dims))) {
return op.emitError("dims must be a list of int");
}
for (int64_t i = 0; i < dims.size(); i++) {
for (size_t i = 0; i < dims.size(); i++) {
if (dims[i] < 0)
dims[i] += rank;
}
Expand Down Expand Up @@ -1431,15 +1328,6 @@ class ConvertTorchToCustomCall
patterns.add<ConvertAtenUpsampleBilinear2dVecOp>(typeConverter, context);
}

if (validCustomCallOpsSet.contains("byteir.nll_loss_forward")) {
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
}
if (validCustomCallOpsSet.contains("byteir.nll_loss_backward")) {
target.addIllegalOp<AtenNllLossBackwardOp>();
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
}

populateMathToCustomCallPattern(target, typeConverter, patterns,
validCustomCallOpsSet);

Expand Down Expand Up @@ -1489,6 +1377,10 @@ void mlir::populateMathToCustomCallPattern(
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenCoshOp, "math.cosh");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenErfOp, "math.erf");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenTruncOp, "math.trunc");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenExp2Op, "math.exp2");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenCopysignTensorOp, "math.copysign");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenLdexpTensorOp, "math.ldexp");
CONVERT_MATH_TO_CUSTOM_CALL_PATTERN(AtenSignbitOp, "math.signbit");
#undef CONVERT_MATH_TO_CUSTOM_CALL_PATTERN
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import torch_frontend
from torch_frontend import compile, MATH_CUSTOM_OPS

def custom_test_helper(model, inputs, custom_op_name):
mlir_module = compile(model, inputs, "stablehlo", backend_legal_ops=MATH_CUSTOM_OPS)
mlir_str = mlir_module.operation.get_asm()
compare_str = "stablehlo.custom_call @{}".format(custom_op_name)
assert compare_str in mlir_str

# ==============================================================================

class TruncModule(torch.nn.Module):
def forward(self, x):
return torch.trunc(x)

def test_trunc():
custom_test_helper(TruncModule(), [torch.rand(3, 4)], "math.trunc")

# ==============================================================================

class Exp2Module(torch.nn.Module):
def forward(self, x):
return torch.exp2(x)

def test_exp2():
custom_test_helper(Exp2Module(), [torch.rand(3, 4)], "math.exp2")

# ==============================================================================

class CopysignModule(torch.nn.Module):
def forward(self, x, y):
return torch.copysign(x, y)

def test_exp2():
custom_test_helper(CopysignModule(), [torch.rand(3, 4), torch.rand(3, 4)], "math.copysign")

# ==============================================================================

class LdexpModule(torch.nn.Module):
def forward(self, x, y):
return torch.ldexp(x, y)

def test_ldexp():
custom_test_helper(LdexpModule(), [torch.rand(3, 4), torch.rand(3, 4)], "math.ldexp")

# ==============================================================================

class SignbitModule(torch.nn.Module):
def forward(self, x):
return torch.signbit(x)

def test_signbit():
custom_test_helper(SignbitModule(), [torch.rand(3, 4)], "math.signbit")
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,21 @@ def test_attention():
module = compile(model, inputs, "stablehlo")
numerical_test_helper(module, inputs, model(*inputs))

# ==============================================================================

# class NllLossStaticModule(torch.nn.Module):
# # Here the 2nd index is ignored.
# def forward(self, x, y):
# return torch.ops.aten.nll_loss_forward(
# x, target=y, weight=None, reduction=0, ignore_index=2
# )

# def test_nll_loss_forward():
# inputs = [tu.rand(2, 3), tu.randint(low=0, high=3, size=(2,))]
# module = compile(NllLossStaticModule(), inputs, "stablehlo", verbose=True, debug=torch_frontend.DebugType(1))
# numerical_test_helper(module, inputs, model(*inputs))


# ==============================================================================

class VarDimModule(torch.nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,6 @@ def test_flash_attn():
optimized_model = torch.compile(model, backend=flash_attn_compile_fx)
output = optimized_model(q, k, v)
output.sum().backward()

if __name__ == "__main__":
test_flash_attn()
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
"math.cosh",
"math.erf",
"math.trunc",
"math.exp2",
"math.copysign",
"math.ldexp",
"math.signbit",
]

BYTEIR_CUSTOM_OPS = [
Expand All @@ -60,6 +64,7 @@
"byteir.topk": ["aten.topk"],
"byteir.non_zero": ["aten.nonzero"],
"byteir.resize": ["aten.upsample_nearest2d.vec"],
# math custom ops
"math.asin": ["aten.asin"],
"math.asinh": ["aten.asinh"],
"math.sinh": ["aten.sinh"],
Expand All @@ -70,6 +75,10 @@
"math.cosh": ["aten.cosh"],
"math.erf": ["aten.erf"],
"math.trunc": ["aten.trunc"],
"math.exp2": ["aten.exp2"],
"math.copysign": ["aten.copysign.Tensor"],
"math.ldexp": ["aten.ldexp.Tensor"],
"math.signbit": ["aten.signbit"],
# torch.operator
"byteir.flash_attn_fwd": ["byteir.flash_attn_fwd"],
"byteir.flash_attn_kvcache": ["byteir.flash_attn_kvcache"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call="valid-custom-call-ops=byteir.layer_norm,byteir.l2_norm,byteir.softmax,byteir.log_softmax,byteir.nll_loss_forward,byteir.nll_loss_backward,byteir.gelu,byteir.arg_max,byteir.arg_min,byteir.one_hot,byteir.topk,byteir.non_zero,byteir.resize" --canonicalize-ext | FileCheck %s
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call="valid-custom-call-ops=byteir.layer_norm,byteir.l2_norm,byteir.softmax,byteir.log_softmax,byteir.gelu,byteir.arg_max,byteir.arg_min,byteir.one_hot,byteir.topk,byteir.non_zero,byteir.resize" --canonicalize-ext | FileCheck %s
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call --canonicalize-ext | FileCheck %s --check-prefix NONE
// RUN: torch-frontend-opt %s -convert-torch-to-custom-call="valid-custom-call-ops=math.asin" --canonicalize-ext | FileCheck %s --check-prefix MATH

Expand Down Expand Up @@ -188,32 +188,6 @@ func.func @torch.custom.dynamic_mask_stitch(%arg0: !torch.vtensor<[?,?],f32>, %a
// CHECK: byteir_attrs = {}
// CHECH-NOT: torch.custom_op

func.func @torch.aten.nll_loss_forward(%arg0: !torch.vtensor<[8192,50257],f32>, %arg1: !torch.vtensor<[8192],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%none = torch.constant.none
%output, %total_weight = torch.aten.nll_loss_forward %arg0, %arg1, %none, %int1, %int-1 : !torch.vtensor<[8192,50257],f32>, !torch.vtensor<[8192],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
return %output, %total_weight : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @torch.aten.nll_loss_forward
// CHECK: stablehlo.custom_call
// CHECK-SAME: @byteir.nll_loss_forward
// CHECK: byteir_attrs = {ignore_index = -1 : i64, reduction = 1 : i64}
// CHECH-NOT: torch.aten.nll_loss_forward

func.func @torch.aten.nll_loss_backward(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[8192,50257],f32>, %arg2: !torch.vtensor<[8192],si64>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[8192,50257],f32>) {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%none = torch.constant.none
%0 = torch.aten.nll_loss_backward %arg0, %arg1, %arg2, %none, %int1, %int-1, %arg3 : !torch.vtensor<[],f32>, !torch.vtensor<[8192,50257],f32>, !torch.vtensor<[8192],si64>, !torch.none, !torch.int, !torch.int, !torch.vtensor<[],f32> -> !torch.vtensor<[8192,50257],f32>
return %0 : !torch.vtensor<[8192,50257],f32>
}
// CHECK-LABEL: func.func @torch.aten.nll_loss_backward
// CHECK: stablehlo.custom_call
// CHECK-SAME: @byteir.nll_loss_backward
// CHECK: byteir_attrs = {ignore_index = -1 : i64, reduction = 1 : i64}
// CHECH-NOT: torch.aten.nll_loss_backward

func.func @torch.byteir.l2_norm(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
%float9.999990e-13 = torch.constant.float 9.9999999999999998E-13
%int1 = torch.constant.int -1
Expand Down
Loading