From 8a0989d3bce68b7621f3e020a52e11765f5369d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cusername=E2=80=9D?= <“email”> Date: Tue, 22 Oct 2024 10:10:57 +0800 Subject: [PATCH] correct the format,and removed some unnecessary changes to the pass. --- .../linalg-transposematmulb-f32.mlir | 2 +- frontend/Python/graph/graph.py | 54 ------------------- frontend/Python/graph/operation.py | 4 -- frontend/Python/graph/transform/fuse_ops.py | 41 -------------- frontend/Python/ops/linalg.py | 19 ------- .../MatMulTransposeBVec.cpp | 2 +- tools/buddy-opt/buddy-opt.cpp | 1 - 7 files changed, 2 insertions(+), 121 deletions(-) diff --git a/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir b/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir index 5f90d3307d..26a4458c53 100644 --- a/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir +++ b/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir @@ -72,4 +72,4 @@ func.func @main(){ call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> () return -} \ No newline at end of file +} diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index 0f000bd00e..ce35693efd 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -164,50 +164,6 @@ def add_node(self, node: Op): self._body.append(node) self.node_table[node.name] = node - def check_deletenode(self, node : Op) -> bool: - #node : graphnode.name - if (not(node.name in self.node_table) ): - raise KeyError("node{0} not in graph".format(node.name)) - - if (len(node._children)==0): - return True - return False; - - def delete_node(self, node: Op,parents : List[Op]): - for i in parents: - i._children.remove(node.name) - node.args.clear() - node.kwargs.clear() - node._children.clear() - self._body.remove(node) - self.node_table.pop(node.name) - - def displace_node(self,node: Op,newnode:Op): - newnode._arguments = node.args - newnode._keyword_arguments = node.kwargs - newnode._tensor_meta = node.tensor_meta - newnode._op_type = node._op_type - - #deal with users/childrens - for i in node._children: - newnode.add_children(i) - users = [self.node_table[i] for i in node._children] - for user in users: - user._parents[user._parents.index(node.name)]=newnode.name - user.args[user.args.index(node.name)]=newnode.name - node._children.clear() - #deal with parents+args - for i in node._parents: - newnode.add_parent(i) - parents = [self.node_table[i] for i in node._parents] - for parent in parents: - parent._children[parent._children.index(node.name)]=newnode.name - node._parents.clear() - #update node table - self._body[self._body.index(node)] = newnode - self.node_table.pop(node.name) - self.node_table[newnode.name] = newnode - def init_op_group(self): """ Initializes operation groups within the graph. @@ -223,16 +179,6 @@ def init_op_group(self): self.group_map_device[subgraph_name] = DeviceType.UNKNOW self.op_groups[subgraph_name] = group - def check_classicfusetype(self,op:Op): - pattern = None - if isinstance(op,MatmulOp): - parentop = [ self.node_table[str(i)] for i in op._parents] - for target in parentop: - if (isinstance(target,PermuteOp) ): - pattern = target,parentop,"transpose+mamtmul2D" - #TODO : other patterns can be fused - return pattern - def fuse_ops(self, pattern_list: List[FunctionType]): """ Fuse operations in the graph based on provided fusion patterns. diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 2e810cf198..0eb31fd961 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -153,10 +153,6 @@ def __init__(self) -> None: super().__init__() self._op_type = OpType.ReduceType -class transpose_Matmul_fusedOp(Op): - def __init__(self) -> None: - super().__init__() - self._op_type = OpType.ReduceType class GetItemOp(Op): def __init__(self) -> None: diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index b727c14367..ac7d34c99c 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -28,45 +28,6 @@ # OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = [] # ANCHOR_OP_TYPE = [] -classicfuse_register = { - "transpose+mamtmul2D": transpose_Matmul_fusedOp -} - -def classic_fuse(graph : Graph): - for op in graph.body: - pattern = graph.check_classicfusetype(op) - if (pattern): - do_classicfusion(graph,op,pattern[0],pattern[1],pattern[2]) - else: - continue - -def do_classicfusion(graph : Graph,node,target : Op,parents : List[Op],pattern : str): - """ - Function to fuse some typical operations into one operation. - Such as transpose + matmul - - Args: - - graph (Graph): The input graph to be simplified. - - Returns: - - None: Modifies the input graph in place. - """ - - fusedop = classicfuse_register.get(pattern)() - fusedop.name = "fused"+node.name - graph.displace_node(node,fusedop) - fusedop.args.pop(fusedop.args.index(target.name)) - fusedop._parents.pop(fusedop._parents.index(target.name)) - fusedop.args.extend(target.args) - fusedop._parents.extend(target._parents) - targets_parent = [graph.node_table[i] for i in target._parents] - for i in targets_parent: - i.add_children(fusedop.name) - target._children.pop(target._children.index(fusedop.name)) - - if(graph.check_deletenode(target)): - graph.delete_node(target,targets_parent) - def simply_fuse(graph: Graph): """ Function to fuse all operations into one graph. @@ -79,8 +40,6 @@ def simply_fuse(graph: Graph): """ new_op_group = [] device = DeviceType.UNKNOW - classic_fuse(graph) - for op in graph.body: if isinstance(op, PlaceholderOp): continue diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index 3def615347..b561b3433a 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1150,24 +1150,6 @@ def matmul_op( op = linalg.matmul(input1, input2, outs=[matmul_result_buffer]) return op -def transpose_matmul_fused_op( - node: transpose_Matmul_fusedOp, - symbol_table:Dict[Tuple[str, int], ir.Operation] - ): - input1 = symbol_table.get((str(node.args[0]),0)) - input2 = symbol_table.get((str(node.args[1]),0)) - - if input1 is None or input2 is None: - return - output_shape = list(node.tensor_meta["shape"]) - dtype = node.tensor_meta["dtype"] - mlir_dtype = mlir_element_type_get(dtype) - tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) - element = mlir_element_attr_get(dtype, 0.0) - attr = ir.DenseElementsAttr.get_splat(tensor_type, element) - result_buffer = arith.ConstantOp(tensor_type, attr).result - op = linalg.matmul_transpose_b(input1, input2, outs=[result_buffer]) - return op def transpose_op( node: TransposeOp, @@ -1986,7 +1968,6 @@ def gt_op(node: GtOp, symbol_table): ops_registry = { "MatmulOp": matmul_op, - "transpose_Matmul_fusedOp": transpose_matmul_fused_op, "ArangeOp": arange_op, "UnsqueezeOp": unsqueeze_op, "ViewOp": view_op, diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp index ae20df537d..4500119d76 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -54,7 +54,7 @@ class MatMulTransposeBVecPattern : public ConversionPattern{ Value B = op->getOperand(1); Value C = op->getOperand(2); - // Get shape of input and output + // Get shape of input and output. ShapedType ATy = A.getType().cast(); Type eleTy = ATy.getElementType(); diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index b24401d405..08e172f8bc 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -80,7 +80,6 @@ void registerLowerSchePass(); void registerFuncBufferizeDynamicOffsetPass(); void registerConvertMemcpyToGPUPass(); void registerLegalizeShmemOutliningPass(); -void registerMatMul_TransposeB_VecPass(); void registerMatMulTransposeBVecPass(); void registerConvertMemcpyToGPUPass(); void registerLegalizeShmemOutliningPass();