From 98d6a54f94309365aa92a7d064949b7aaf06bb79 Mon Sep 17 00:00:00 2001 From: effrey-liu <2318266514@qq.com> Date: Sun, 1 Sep 2024 17:53:30 +0800 Subject: [PATCH] modify fuse_ops --- frontend/Python/graph/graph.py | 3 +- frontend/Python/graph/graph_driver.py | 60 ++++++++++++++++++--- frontend/Python/graph/transform/fuse_ops.py | 10 ++++ 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index eb78c0ff33..99b554024e 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -173,7 +173,8 @@ def init_op_group(self): if isinstance(op, PlaceholderOp): continue group = [op] - subgraph_name = "subgraph{}".format(i) + # subgraph_name = "subgraph{}".format(i) + subgraph_name = op.name self.group_map_device[subgraph_name] = DeviceType.UNKNOW self.op_groups[subgraph_name] = group diff --git a/frontend/Python/graph/graph_driver.py b/frontend/Python/graph/graph_driver.py index 50a8869d5a..dbaa7f2e26 100644 --- a/frontend/Python/graph/graph_driver.py +++ b/frontend/Python/graph/graph_driver.py @@ -40,6 +40,7 @@ class GraphDriver: - _subgraphs_outputs (dict): A dictionary mapping subgraph names to their output op's result. """ + def __init__(self, graph: Graph) -> None: """ Initialize the GraphDriver object with a given computational graph. @@ -94,7 +95,7 @@ def build_subgraph_by_group(self): if isinstance(node, OutputOp): for arg in node.args: output_node.append(arg) - + # Identify outputs for each subgraph for subgraph_name in self._graph.op_groups.keys(): subgraphs_outputs[subgraph_name] = [] @@ -127,11 +128,11 @@ def build_subgraph_by_group(self): if inp in node._parents: placeholder_node.add_children(op.name) subgraph_body.append(placeholder_node) - + # Add operations to subgraph body for op in self._graph.op_groups[subgraph_name]: subgraph_body.append(op) - + # Construct output node output_node = OutputOp() output_node.name = "output" @@ -189,12 +190,12 @@ def construct_main_graph(self, do_param_pack=False): self._graph.node_table[output].tensor_meta["dtype"] ) main_graph.body.append(func_node) - + # Adding placeholder operations from the original graph for op in self._graph.body: if isinstance(op, PlaceholderOp): main_graph.body.append(op) - + # TODO: analysis topology order to sort subgraph call. if len(self._subgraphs) == 1: # Adding CallOp to invoke the single subgraph @@ -215,18 +216,63 @@ def construct_main_graph(self, do_param_pack=False): # Adding GetItemOps to retrieve individual output tensors output_node = OutputOp() - for i, output in enumerate(list(self._subgraphs_outputs.values())[0]): + for i, output in enumerate( + list(self._subgraphs_outputs.values())[0] + ): getitem_node = GetItemOp() getitem_node.add_argument(call_node.name) getitem_node.add_argument(i) getitem_node.name = "getitem{}".format(i) output_node.add_argument(getitem_node.name) main_graph.body.append(getitem_node) - + # Marking the final output of the main graph output_node.name = "output" main_graph.body.append(output_node) + # Importing the main graph + with ir.Location.unknown(ir.Context()): + main_importer = GraphImporter( + main_graph.body, + main_graph._fake_params, + main_graph._inputs, + main_graph._func_name, + main_graph._ops_registry, + do_param_pack, + ) + return main_importer.import_main_graph() + else: + for i in range(len(self._subgraphs) - 1): + # Adding CallOp to invoke the single subgraph + call_node = CallOp() + call_node.call_func_name = list(self._subgraphs.keys())[i] + call_node.name = call_node.call_func_name + call_node.tensor_meta = {"shape": [], "dtype": []} + for inp in list(self._subgraphs_inputs.values())[i]: + call_node.add_argument(inp) + for output in list(self._subgraphs_outputs.values())[i]: + call_node.tensor_meta["shape"].append( + self._graph.node_table[output].tensor_meta["shape"] + ) + call_node.tensor_meta["dtype"].append( + self._graph.node_table[output].tensor_meta["dtype"] + ) + main_graph.body.append(call_node) + + # Adding GetItemOps to retrieve individual output tensors + output_node = OutputOp() + for i, output in enumerate(list(self._subgraphs_outputs.values())[0]): + getitem_node = GetItemOp() + getitem_node.add_argument(call_node.name) + getitem_node.add_argument(i) + getitem_node.name = "getitem{}".format(i) + output_node.add_argument(getitem_node.name) + main_graph.body.append(getitem_node) + + # Marking the final output of the main graph + output_node.name = "output" + main_graph.body.append(output_node) + # Importing the main graph with ir.Location.unknown(ir.Context()): main_importer = GraphImporter( diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index 88b046ef9c..068397d7c3 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -326,6 +326,8 @@ def MergeFromTo(self, child, parent): if child == parent: return parent.num_nodes += child.num_nodes + self._graph.op_groups[parent.name][:0] = self._graph.op_groups[child.name] + del self._graph.op_groups[child.name] child.parent = parent if child.master_ref is not None: assert parent.master_ref is None @@ -386,6 +388,14 @@ def fcond0(kind, issink): node.name, node.num_nodes, node.master_ref.name ) ) + if node.master_ref.name not in self._graph.op_groups: + self._graph.op_groups[node.master_ref.name] = [] + self._graph.group_map_device = { + node.master_ref.name: DeviceType.UNKNOW + } + self._graph.op_groups[node.master_ref.name].append( + self._graph.node_table[node.name] + ) def my_fuse_ops_test(graph: Graph):