From 2776c10484b79323e4eeb6390c7533d32edbda8d Mon Sep 17 00:00:00 2001
From: Wu Xintong <13683168028@163.com>
Date: Fri, 3 Jan 2025 16:27:21 +0800
Subject: [PATCH] [frontend] Update graph for op fusion (#445)

Co-authored-by: zhxzh-2001 <70198007+zhxzh-2001@users.noreply.github.com>
---
 examples/BuddyLeNet/buddy-lenet-import.py   |   2 +-
 examples/BuddyLlama/import-llama2.py        |   2 +-
 frontend/Python/frontend.py                 |  45 +++++----
 frontend/Python/graph/graph.py              | 104 +++++++++++++++++---
 frontend/Python/graph/operation.py          |   6 ++
 frontend/Python/graph/transform/__init__.py |   2 +-
 frontend/Python/graph/transform/fuse_ops.py |  89 ++++++++++++++++-
 frontend/Python/ops/linalg.py               |  25 +++++
 tests/Python/test_permute_matmul_fusion.py  |  40 ++++++++
 9 files changed, 274 insertions(+), 41 deletions(-)
 create mode 100644 tests/Python/test_permute_matmul_fusion.py

diff --git a/examples/BuddyLeNet/buddy-lenet-import.py b/examples/BuddyLeNet/buddy-lenet-import.py
index c787061a55..e4f85f905f 100644
--- a/examples/BuddyLeNet/buddy-lenet-import.py
+++ b/examples/BuddyLeNet/buddy-lenet-import.py
@@ -26,7 +26,7 @@
 
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.graph import GraphDriver
-from buddy.compiler.graph.transform import simply_fuse
+from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
 from buddy.compiler.ops import tosa
 from model import LeNet
 
diff --git a/examples/BuddyLlama/import-llama2.py b/examples/BuddyLlama/import-llama2.py
index d893ee87f6..af89329e62 100644
--- a/examples/BuddyLlama/import-llama2.py
+++ b/examples/BuddyLlama/import-llama2.py
@@ -28,7 +28,7 @@
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.ops import tosa
 from buddy.compiler.graph import GraphDriver
-from buddy.compiler.graph.transform import simply_fuse
+from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
 
 # Retrieve the LLaMA model path from environment variables.
model_path = os.environ.get("LLAMA_MODEL_PATH") diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index f5a17a1c31..c11843eab7 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -165,9 +165,9 @@ def __init__( "cos.default": CosOp, "sin.default": SinOp, "argmax.default": ArgMaxOp, - "split.Tensor":SplitOp, - "max.default":MaxOp, - "gt.Scalar":GtOp, + "split.Tensor": SplitOp, + "max.default": MaxOp, + "gt.Scalar": GtOp, "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp, "ge.Scalar": GeOp, "gt.Tensor": GreaterThanOp, @@ -237,7 +237,9 @@ def _create_node( buddy_node.add_argument(str(input_arg)) buddy_node.add_parent(str(input_arg)) elif isinstance(input_arg, torch.dtype): - buddy_node.add_argument(self._torch_dtype_translate(str(input_arg))) + buddy_node.add_argument( + self._torch_dtype_translate(str(input_arg)) + ) else: buddy_node.add_argument(input_arg) for user in node_users: @@ -294,7 +296,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): nonlocal params_flat func_inputs = [] for i in inputs_pos: - # for inp in _inputs[len(params_flat) :]: + # for inp in _inputs[len(params_flat) :]: inp = _inputs[i] inp_shape = inp.shape inp_dtype = self._torch_dtype_translate(str(inp.dtype)) @@ -308,7 +310,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): fake_params, self._ops_registry, self._func_name, - self._verbose + self._verbose, ) param_nodes = [] buffers_nodes = [] @@ -344,10 +346,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): elif gm_node.op == "output": buddy_node = self._create_node( - gm_node.op, - gm_node.name, - gm_node.args, - node_users + gm_node.op, gm_node.name, gm_node.args, node_users ) elif gm_node.target is operator.getitem: @@ -367,7 +366,11 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): tensor_meta = gm_node.meta.get("tensor_meta") val = gm_node.meta.get("val") # num_returns = len(gm_node.target._schema.returns) - num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) + num_returns = ( + len(val) + if isinstance(val, list) + else len(gm_node.target._schema.returns) + ) if num_returns == 1: node_dtype = self._torch_dtype_translate( str(tensor_meta.dtype) @@ -477,7 +480,7 @@ def get_lib_extension(): def cast_c_ptr(outdata_ptr, memref_ptr): """ - Casts a C pointer (`outdata_ptr`) to the type of another C pointer + Casts a C pointer (`outdata_ptr`) to the type of another C pointer (`memref_ptr`). Args: @@ -488,14 +491,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr): Returns: ctypes.POINTER - A new C pointer with the type of `memref_ptr`, representing the + A new C pointer with the type of `memref_ptr`, representing the same memory location as `outdata_ptr`. Example: outdata = ctypes.pointer(ctypes.c_int()) memref = ctypes.pointer(ctypes.c_float()) casted_ptr = cast_c_ptr(outdata, memref) - # Now `casted_ptr` points to the same memory location as `outdata`, + # Now `casted_ptr` points to the same memory location as `outdata`, but with the type of `memref`. 
""" outdata_addr = ctypes.addressof(outdata_ptr.contents) @@ -504,15 +507,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr): def move_c_ptr(outdata_ptr, memref_ptr): """ - Moves a C pointer (`outdata_ptr`) to the next element in memory, - based on the size of the referenced type in another C pointer + Moves a C pointer (`outdata_ptr`) to the next element in memory, + based on the size of the referenced type in another C pointer (`memref_ptr`). Args: outdata_ptr: ctypes.POINTER The C pointer whose position needs to be moved. memref_ptr: ctypes.POINTER - The reference C pointer whose type determines the size of each + The reference C pointer whose type determines the size of each element for the move. Returns: @@ -535,7 +538,7 @@ def exec_buddy_graph(*args): Returns: List[torch.Tensor] - The result of executing the graph, represented as a list of + The result of executing the graph, represented as a list of output tensors. """ # A list of ctypes pointers representing memory references for input @@ -548,13 +551,13 @@ def exec_buddy_graph(*args): ) for tensor in args ] - # A list of ctypes pointers representing memory references for + # A list of ctypes pointers representing memory references for # output tensors. output_memref = [ ctypes.pointer(ctypes.pointer(graph._output_descriptor())) ] args_memref = output_memref + input_memref - # Invoke the graph's function using the provided execution engine + # Invoke the graph's function using the provided execution engine # and memory references ee.invoke(graph._func_name, *args_memref) @@ -571,7 +574,7 @@ def exec_buddy_graph(*args): # Move to the next element in memory based on the size of the # current output type outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0]) - # Convert each NumPy array to a PyTorch tensor and return the list + # Convert each NumPy array to a PyTorch tensor and return the list # of tensors return [torch.from_numpy(tensor) for tensor in output_tensor] diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index 88c6a85df6..751ddb0066 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -105,7 +105,7 @@ def __init__( fake_params: List[TensorMeta], ops_registry: dict, func_name: str, - verbose=False + verbose=False, ) -> None: """ Initializes the Graph. @@ -164,6 +164,78 @@ def add_node(self, node: Op): self._body.append(node) self.node_table[node.name] = node + def check_delete_node(self, node: Op) -> bool: + """ + Determines if a node exists in the graph and has no child nodes. + + Args: + node (Op): The operation node to check for deletion eligibility. + + Returns: + bool: True if the node exists in the graph and has no children. + """ + if not (node.name in self.node_table): + raise KeyError("node{0} not in graph".format(node.name)) + + if len(node._children) == 0: + return True + return False + + def delete_node(self, node: Op, parents: List[Op]): + """ + Removes a node from the graph and updates its parent nodes accordingly. + + Args: + node (Op): The operation node to be deleted from the graph. + parents (List[Op]): A list of parent operation nodes that reference the node to be deleted. + + Returns: + None + """ + for i in parents: + i._children.remove(node.name) + node.args.clear() + node.kwargs.clear() + node._children.clear() + self._body.remove(node) + self.node_table.pop(node.name) + + def displace_node(self, node: Op, newnode: Op): + """ + Replaces an existing node with a new node in the graph. + + Args: + node (Op): The operation node to be replaced. 
+            newnode (Op): The new operation node that will replace the existing node.
+
+        Returns:
+            None
+        """
+        newnode._arguments = node.args
+        newnode._keyword_arguments = node.kwargs
+        newnode._tensor_meta = node.tensor_meta
+        newnode._op_type = node._op_type
+
+        for i in node._children:
+            newnode.add_children(i)
+        users = [self.node_table[i] for i in node._children]
+        for user in users:
+            if node.name in user._parents:
+                user._parents[user._parents.index(node.name)] = newnode.name
+            user.args[user.args.index(node.name)] = newnode.name
+        node._children.clear()
+        # Rewire the parents and their child references to the new node.
+        for i in node._parents:
+            newnode.add_parent(i)
+        parents = [self.node_table[i] for i in node._parents]
+        for parent in parents:
+            parent._children[parent._children.index(node.name)] = newnode.name
+        node._parents.clear()
+        # Swap the node in the graph body and update the node table.
+        self._body[self._body.index(node)] = newnode
+        self.node_table.pop(node.name)
+        self.node_table[newnode.name] = newnode
+
     def init_op_group(self):
         """
         Initializes operation groups within the graph.
@@ -239,7 +311,7 @@ def lower_to_top_level_ir(self):
             self._inputs,
             self._func_name,
             self._ops_registry,
-            verbose=self._verbose
+            verbose=self._verbose,
         )
         self._imported_module = fx_importer.import_graph()
         outputs = fx_importer.get_output_nodes()
@@ -352,7 +424,7 @@ def __init__(
         func_name: str,
         ops_registry: dict,
         do_param_pack: bool = False,
-        verbose=False
+        verbose=False,
     ):
         """
         Initializes the buddy Graph importer.
@@ -475,27 +547,27 @@ def generated_func(*args):
                 elif isinstance(node, PlaceholderOp):
                     self._import_placeholder(node, args_list)
                 elif isinstance(node, GetItemOp):
-                    self._symbol_table[
-                        (str(node.name), 0)
-                    ] = self._symbol_table[
-                        (str(node.args[0]), node.args[1])
-                    ]
+                    self._symbol_table[(str(node.name), 0)] = (
+                        self._symbol_table[
+                            (str(node.args[0]), node.args[1])
+                        ]
+                    )
                 else:
                     self._import_op(node)
                 new_ops = [op for op in func_op.body.blocks[0].operations]
                 if self._verbose:
-                    print('='*20 + "Graph Node" + "="*20)
+                    print("=" * 20 + "Graph Node" + "=" * 20)
                     print("Node: " + node.name)
                     print("Type: " + str(node._op_type))
                     print("Arguments: " + str(node.args))
                     print("Parents: " + str(node._parents))
                     print("Children: " + str(node._children))
-                    print('-'*20 + "MLIR OPS" + '-'*20)
+                    print("-" * 20 + "MLIR OPS" + "-" * 20)
                     for op in new_ops:
                         if op not in old_ops:
                             print(op)
                     print("")
-                    
+
             return self._symbol_table.get(("output", 0))
 
         return self._module
@@ -544,11 +616,11 @@ def generated_func(*args):
                 elif isinstance(node, PlaceholderOp):
                     self._import_placeholder(node, args_list)
                 elif isinstance(node, GetItemOp):
-                    self._symbol_table[
-                        (str(node.name), 0)
-                    ] = self._symbol_table[
-                        (str(node.args[0]), node.args[1])
-                    ]
+                    self._symbol_table[(str(node.name), 0)] = (
+                        self._symbol_table[
+                            (str(node.args[0]), node.args[1])
+                        ]
+                    )
                 else:
                     self._import_op(node)
 
diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py
index c1a7b09746..218752abc0 100644
--- a/frontend/Python/graph/operation.py
+++ b/frontend/Python/graph/operation.py
@@ -154,6 +154,12 @@ def __init__(self) -> None:
         self._op_type = OpType.ReduceType
 
 
+class TransposeMatmulFusedOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ReduceType
+
+
 class GetItemOp(Op):
     def __init__(self) -> None:
         super().__init__()
diff --git a/frontend/Python/graph/transform/__init__.py b/frontend/Python/graph/transform/__init__.py
index d91e0d06b2..a1e294f8cd 100644
--- a/frontend/Python/graph/transform/__init__.py
+++ b/frontend/Python/graph/transform/__init__.py
@@ -18,5 +18,5 @@
 #
 # ===---------------------------------------------------------------------------
 
-from .fuse_ops import simply_fuse
+from .fuse_ops import simply_fuse, apply_classic_fusion
 from .useless_op_eliminate import maxpool2d_simplify
diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py
index ac7d34c99c..992168aecc 100644
--- a/frontend/Python/graph/transform/fuse_ops.py
+++ b/frontend/Python/graph/transform/fuse_ops.py
@@ -21,12 +21,99 @@
 from .. import Graph
 from ..operation import *
 from .. import DeviceType
+from torch.fx.immutable_collections import immutable_list
+
+classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp}
 
 # TODO: classify op type for op fusion
 # OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
 # OP_TYPE_UNFUSABLE = [OpType.Unfusable, OpType.ConcatType]
 # OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = []
-# ANCHOR_OP_TYPE = [] 
+# ANCHOR_OP_TYPE = []
+
+
+def classic_fuse_check(graph: Graph):
+    """
+    Identifies PermuteOp operations that transpose an operand of a MatmulOp
+    and fuses each matched pair into a single operation in the graph.
+
+    Args:
+        graph (Graph): The computation graph to analyze and optimize.
+
+    Returns:
+        None
+    """
+    for op in graph.body:
+        pattern = None
+        if isinstance(op, MatmulOp):
+            parentop = [graph.node_table[str(i)] for i in op._parents]
+            for target in parentop:
+                if isinstance(target, PermuteOp) and target.args[
+                    1
+                ] == immutable_list([1, 0]):
+                    pattern = target, parentop, "transpose_matmul_fusion"
+            if pattern:
+                transpose_matmul_fusion(
+                    graph, op, pattern[0], pattern[1], pattern[2]
+                )
+
+
+def transpose_matmul_fusion(
+    graph: Graph, node: Op, target: Op, parents: List[Op], pattern: str
+):
+    """
+    Fuses a matched pair of operations into a single fused operation,
+    e.g. transpose + matmul.
+    Args:
+    - graph (Graph): The input graph to be rewritten.
+    - node (Op): The anchor operation (the MatmulOp) to be replaced.
+    - target (Op): The operation (the PermuteOp) folded into the fused op.
+    - parents (List[Op]): The parent operations of the anchor node.
+    - pattern (str): The name of the registered fusion pattern.
+    Returns:
+    - None: Modifies the input graph in place.
+    """
+    fused_op = classicfuse_register.get(pattern)()
+    # Swap the MatmulOp for the fused node in place.
+    fused_op.name = "fused" + node.name
+    graph.displace_node(node, fused_op)
+    fused_op.args.pop(fused_op.args.index(target.name))
+    fused_op._parents.pop(fused_op._parents.index(target.name))
+    fused_op.args.extend(target.args)
+
+    fused_op._parents.extend(target._parents)
+    targets_parent = [graph.node_table[i] for i in target._parents]
+    for i in targets_parent:
+        i.add_children(fused_op.name)
+    target._children.pop(target._children.index(fused_op.name))
+
+    if graph.check_delete_node(target):
+        graph.delete_node(target, targets_parent)
+
+
+def apply_classic_fusion(graph: Graph):
+    """
+    Runs the classic fusion patterns (currently transpose + matmul) and
+    then packs all remaining operations into a single subgraph.
+
+    Args:
+    - graph (Graph): The input graph to be simplified.
+
+    Returns:
+    - None: Modifies the input graph in place.
+ """ + new_op_group = [] + device = DeviceType.UNKNOW + # Run the first round of op fusion + classic_fuse_check(graph) + for op in graph.body: + if isinstance(op, PlaceholderOp): + continue + new_op_group.append(op) + graph.op_groups = {} + graph.op_groups["subgraph0"] = new_op_group + graph.group_map_device = {"subgraph0": device} + def simply_fuse(graph: Graph): """ diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index ec6c827e6c..6bd3a2f318 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1171,6 +1171,26 @@ def matmul_op( return op +def matmul_transpose_b_op( + node: TransposeMatmulFusedOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + input1 = symbol_table.get((str(node.args[0]), 0)) + input2 = symbol_table.get((str(node.args[1]), 0)) + + if input1 is None or input2 is None: + return + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + element = mlir_element_attr_get(dtype, 0.0) + attr = ir.DenseElementsAttr.get_splat(tensor_type, element) + result_buffer = arith.ConstantOp(tensor_type, attr).result + op = linalg.matmul_transpose_b(input1, input2, outs=[result_buffer]) + return op + + def transpose_op( node: TransposeOp, symbol_table: Dict[Tuple[str, int], ir.Operation], @@ -2344,6 +2364,7 @@ def unsafe_index_op( ops_registry = { "MatmulOp": matmul_op, + "TransposeMatmulFusedOp": matmul_transpose_b_op, "ArangeOp": arange_op, "UnsqueezeOp": unsqueeze_op, "ViewOp": view_op, diff --git a/tests/Python/test_permute_matmul_fusion.py b/tests/Python/test_permute_matmul_fusion.py new file mode 100644 index 0000000000..70f120e5a4 --- /dev/null +++ b/tests/Python/test_permute_matmul_fusion.py @@ -0,0 +1,40 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp +from torch._functorch.aot_autograd import aot_autograd_decompositions + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import linalg +from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion + +def foo(m1, m2,map): + tmp = torch.ops.aten.permute(m2,map) + return torch.matmul(m1,tmp) + +m1 = torch.ones([3, 4], dtype=torch.float32) +m2 = torch.ones([3, 4], dtype=torch.float32) +map = (1,0) +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=linalg.ops_registry, + aot_autograd_decomposition=aot_autograd_decompositions, +) + +graphs = dynamo_compiler.importer(foo, m1,m2,map) +assert len(graphs) == 1 +graph = graphs[0] +pattern_list = [apply_classic_fusion] +graphs[0].fuse_ops(pattern_list) + +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = arith.constant +# CHECK: %{{.*}} = linalg.matmul_transpose_b +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: }