From 8196bfd2c50046808b188a7b1d00a65f49996f78 Mon Sep 17 00:00:00 2001 From: effrey-liu <2318266514@qq.com> Date: Wed, 30 Oct 2024 16:02:17 +0800 Subject: [PATCH] fix FileCheck error --- frontend/Python/frontend.py | 7 ++- frontend/Python/graph/operation.py | 8 ++- frontend/Python/ops/func.py | 2 +- frontend/Python/ops/linalg.py | 93 +++++++++++++++++++++++++++++- frontend/Python/ops/utils.py | 1 - tests/Python/test_embedding.py | 2 +- tests/Python/test_expand.py | 5 +- tests/Python/test_ge.py | 15 ++--- 8 files changed, 115 insertions(+), 18 deletions(-) diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 865614da5a..a5bb2b6794 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -165,10 +165,11 @@ def __init__( "cos.default": CosOp, "sin.default": SinOp, "argmax.default": ArgMaxOp, - "split.Tensor":SplitOp, - "max.default":MaxOp, - "gt.Scalar":GtOp, + "split.Tensor": SplitOp, + "max.default": MaxOp, + "gt.Scalar": GtOp, "ge.Scalar": GeOp, + "gt.Tensor": GreaterThanOp, } @property diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 9c2618e3c6..4910cb1773 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -539,4 +539,10 @@ def __init__(self) -> None: class GeOp(Op): def __init__(self) -> None: super().__init__() - self._op_type = OpType.ElementwiseType \ No newline at end of file + self._op_type = OpType.ElementwiseType + + +class GreaterThanOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.BroadcastType \ No newline at end of file diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index a7dcc5e11b..6b35c03410 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -106,7 +106,7 @@ def param_extract( TensorDType.Int64: ir.IntegerType.get_signless(64), } memref_element_type = dtype_mapping[node.tensor_meta["dtype"]] - if(len(node.tensor_meta['shape'])== 0): + if len(node.tensor_meta["shape"]) == 0: output_shape = [1] else: output_shape = list(node.tensor_meta["shape"]) diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index ec67b39a19..e1a5cfb819 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -2044,6 +2044,96 @@ def ge_op( return op +def greater_than_op( + node: GreaterThanOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor greater than operation. + From buddy GreaterThanOp to MLIR arith `constant` operation. + Note: This op, campare two input nodes, and output bool tensor to represent + compare result. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + input2 = symbol_table.get((str(node.args[1]), 0)) + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4) + shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) + shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) + dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, dtype) + output = tensor.EmptyOp(output_shape, dtype) + if len(shp1) < len(shp2): + if int(shp1[-1]) > 1 and shp2[-1] == 1: + generic_map = ir.AffineMap.get_permutation( + [i for i in range(len(shp2) + 1)] + ) + op = linalg.GenericOp( + [tensor_type], + [input1, input2], + [output], + ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + i + for i in range( + len(shp2) - len(shp1), len(shp2) + ) + ] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2) - 1)] + + [len(shp2)] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2))] + ) + ), + ] + ), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * len(shp2) + + [ir.Attribute.parse("#linalg.iterator_type")] + ), + ) + block = ir.Block.create_at_start( + op.region, + [ + ir.RankedTensorType(input2.type).element_type, + ir.RankedTensorType(input2.type).element_type, + dtype, + ], + ) + if ( + str(ir.RankedTensorType(input2.type).element_type).find("i") + != -1 + ): + cmpop = arith.CmpIOp( + value, block.arguments[0], block.arguments[1] + ) + else: + cmpop = arith.CmpFOp( + value, block.arguments[0], block.arguments[1] + ) + block.append(cmpop) + block.append(linalg.YieldOp([cmpop.result])) + + return op + ops_registry = { "MatmulOp": matmul_op, @@ -2080,5 +2170,6 @@ def ge_op( "SplitOp": split_op, "MaxOp": max_op, "GtOp": gt_op, - "GeOp":ge_op, + "GeOp": ge_op, + "GreaterThanOp": greater_than_op, } diff --git a/frontend/Python/ops/utils.py b/frontend/Python/ops/utils.py index 337f5a6b49..dad07bd68c 100644 --- a/frontend/Python/ops/utils.py +++ b/frontend/Python/ops/utils.py @@ -53,4 +53,3 @@ def mlir_element_attr_get(type_name, value): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) case TensorDType.Bool: return ir.IntegerAttr.get(ir.IntegerType.get_signless(1), value) - diff --git a/tests/Python/test_embedding.py b/tests/Python/test_embedding.py index 484bb617b5..c3ae33672e 100644 --- a/tests/Python/test_embedding.py +++ b/tests/Python/test_embedding.py @@ -70,4 +70,4 @@ def foo(weight, indices): # CHECK: %{{.*}} = tosa.reshape # CHECK: return %{{.*}} # CHECK: } -# CHECK: } \ No newline at end of file +# CHECK: } diff --git a/tests/Python/test_expand.py b/tests/Python/test_expand.py index 80642b0840..713bea84fb 100644 --- a/tests/Python/test_expand.py +++ b/tests/Python/test_expand.py @@ -30,9 +30,8 @@ def foo(x, y): # CHECK: module { # CHECK-LABEL: func.func @forward -# CHECK: %{{.*}} = arith.constant -# CHECK: %{{.*}} = tensor.empty -# CHECK: %{{.*}} = linalg.generic +# CHECK: %{{.*}} = "tosa.const" +# CHECK: %{{.*}} = tosa.add # CHECK: return %{{.*}} # CHECK: } # CHECK: } diff --git a/tests/Python/test_ge.py b/tests/Python/test_ge.py index 24e202a18d..95230324c3 100644 --- a/tests/Python/test_ge.py +++ b/tests/Python/test_ge.py @@ -26,10 +26,11 @@ def foo(x, y): graph.lower_to_top_level_ir() print(graph._imported_module) -# CHECK: module { -# CHECK-LABEL: func.func @forward -# CHECK: %{{.*}} = tensor.empty -# CHECK: %{{.*}} = linalg.generic -# CHECK: return %{{.*}} -# CHECK: } -# CHECK: } +# CHECK: "builtin.module"() ({ +# CHECK-LABEL: "func.func"() <{function_type = ({{.*}} -> {{.*}}, sym_name = "forward"} +# CHECK: %{{.*}} = "arith.constant" +# CHECK: %{{.*}} = "tensor.empty" +# CHECK: %{{.*}} = "linalg.generic" +# CHECK: "func.return"(%{{.*}}) : {{.*}} -> () +# CHECK: }) : () -> () +# CHECK: }) : () -> ()