Skip to content

Commit

Permalink
fix FileCheck error
Browse files Browse the repository at this point in the history
  • Loading branch information
effrey-liu committed Oct 30, 2024
1 parent a0685b2 commit 8196bfd
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 18 deletions.
7 changes: 4 additions & 3 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,11 @@ def __init__(
"cos.default": CosOp,
"sin.default": SinOp,
"argmax.default": ArgMaxOp,
"split.Tensor":SplitOp,
"max.default":MaxOp,
"gt.Scalar":GtOp,
"split.Tensor": SplitOp,
"max.default": MaxOp,
"gt.Scalar": GtOp,
"ge.Scalar": GeOp,
"gt.Tensor": GreaterThanOp,
}

@property
Expand Down
8 changes: 7 additions & 1 deletion frontend/Python/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,4 +539,10 @@ def __init__(self) -> None:
class GeOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ElementwiseType
self._op_type = OpType.ElementwiseType


class GreaterThanOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.BroadcastType
2 changes: 1 addition & 1 deletion frontend/Python/ops/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def param_extract(
TensorDType.Int64: ir.IntegerType.get_signless(64),
}
memref_element_type = dtype_mapping[node.tensor_meta["dtype"]]
if(len(node.tensor_meta['shape'])== 0):
if len(node.tensor_meta["shape"]) == 0:
output_shape = [1]
else:
output_shape = list(node.tensor_meta["shape"])
Expand Down
93 changes: 92 additions & 1 deletion frontend/Python/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,96 @@ def ge_op(

return op

def greater_than_op(
node: GreaterThanOp,
symbol_table: Dict[Tuple[str, int], ir.Operation],
):
"""
Import the tensor greater than operation.
From buddy GreaterThanOp to MLIR arith `constant` operation.
Note: This op, campare two input nodes, and output bool tensor to represent
compare result.
Args:
node: Containing information from the input graph node.
symbol_table: A dictionary mapping symbols to their corresponding
operations.
Returns:
op: The operation return the linalg.generic op.
"""
input1 = symbol_table.get((str(node.args[0]), 0))
input2 = symbol_table.get((str(node.args[1]), 0))
output_shape = list(node.tensor_meta["shape"])
dtype = node.tensor_meta["dtype"]
value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4)
shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape)
shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape)
dtype = mlir_element_type_get(dtype)
tensor_type = ir.RankedTensorType.get(output_shape, dtype)
output = tensor.EmptyOp(output_shape, dtype)
if len(shp1) < len(shp2):
if int(shp1[-1]) > 1 and shp2[-1] == 1:
generic_map = ir.AffineMap.get_permutation(
[i for i in range(len(shp2) + 1)]
)
op = linalg.GenericOp(
[tensor_type],
[input1, input2],
[output],
ir.ArrayAttr.get(
[
ir.AffineMapAttr.get(
generic_map.get_submap(
[
i
for i in range(
len(shp2) - len(shp1), len(shp2)
)
]
)
),
ir.AffineMapAttr.get(
generic_map.get_submap(
[i for i in range(0, len(shp2) - 1)]
+ [len(shp2)]
)
),
ir.AffineMapAttr.get(
generic_map.get_submap(
[i for i in range(0, len(shp2))]
)
),
]
),
ir.ArrayAttr.get(
[ir.Attribute.parse("#linalg.iterator_type<parallel>")]
* len(shp2)
+ [ir.Attribute.parse("#linalg.iterator_type<reduction>")]
),
)
block = ir.Block.create_at_start(
op.region,
[
ir.RankedTensorType(input2.type).element_type,
ir.RankedTensorType(input2.type).element_type,
dtype,
],
)
if (
str(ir.RankedTensorType(input2.type).element_type).find("i")
!= -1
):
cmpop = arith.CmpIOp(
value, block.arguments[0], block.arguments[1]
)
else:
cmpop = arith.CmpFOp(
value, block.arguments[0], block.arguments[1]
)
block.append(cmpop)
block.append(linalg.YieldOp([cmpop.result]))

return op


ops_registry = {
"MatmulOp": matmul_op,
Expand Down Expand Up @@ -2080,5 +2170,6 @@ def ge_op(
"SplitOp": split_op,
"MaxOp": max_op,
"GtOp": gt_op,
"GeOp":ge_op,
"GeOp": ge_op,
"GreaterThanOp": greater_than_op,
}
1 change: 0 additions & 1 deletion frontend/Python/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,3 @@ def mlir_element_attr_get(type_name, value):
return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
case TensorDType.Bool:
return ir.IntegerAttr.get(ir.IntegerType.get_signless(1), value)

2 changes: 1 addition & 1 deletion tests/Python/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ def foo(weight, indices):
# CHECK: %{{.*}} = tosa.reshape
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
# CHECK: }
5 changes: 2 additions & 3 deletions tests/Python/test_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ def foo(x, y):

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = arith.constant
# CHECK: %{{.*}} = tensor.empty
# CHECK: %{{.*}} = linalg.generic
# CHECK: %{{.*}} = "tosa.const"
# CHECK: %{{.*}} = tosa.add
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
15 changes: 8 additions & 7 deletions tests/Python/test_ge.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def foo(x, y):
graph.lower_to_top_level_ir()
print(graph._imported_module)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = tensor.empty
# CHECK: %{{.*}} = linalg.generic
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }
# CHECK: "builtin.module"() ({
# CHECK-LABEL: "func.func"() <{function_type = ({{.*}} -> {{.*}}, sym_name = "forward"}
# CHECK: %{{.*}} = "arith.constant"
# CHECK: %{{.*}} = "tensor.empty"
# CHECK: %{{.*}} = "linalg.generic"
# CHECK: "func.return"(%{{.*}}) : {{.*}} -> ()
# CHECK: }) : () -> ()
# CHECK: }) : () -> ()

0 comments on commit 8196bfd

Please sign in to comment.