Skip to content

Commit

Permalink
modify fuse_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
effrey-liu committed Sep 1, 2024
1 parent 52f923c commit 98d6a54
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
3 changes: 2 additions & 1 deletion frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def init_op_group(self):
if isinstance(op, PlaceholderOp):
continue
group = [op]
subgraph_name = "subgraph{}".format(i)
# subgraph_name = "subgraph{}".format(i)
subgraph_name = op.name
self.group_map_device[subgraph_name] = DeviceType.UNKNOW
self.op_groups[subgraph_name] = group

Expand Down
60 changes: 53 additions & 7 deletions frontend/Python/graph/graph_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class GraphDriver:
- _subgraphs_outputs (dict): A dictionary mapping subgraph names to their
output op's result.
"""

def __init__(self, graph: Graph) -> None:
"""
Initialize the GraphDriver object with a given computational graph.
Expand Down Expand Up @@ -94,7 +95,7 @@ def build_subgraph_by_group(self):
if isinstance(node, OutputOp):
for arg in node.args:
output_node.append(arg)

# Identify outputs for each subgraph
for subgraph_name in self._graph.op_groups.keys():
subgraphs_outputs[subgraph_name] = []
Expand Down Expand Up @@ -127,11 +128,11 @@ def build_subgraph_by_group(self):
if inp in node._parents:
placeholder_node.add_children(op.name)
subgraph_body.append(placeholder_node)

# Add operations to subgraph body
for op in self._graph.op_groups[subgraph_name]:
subgraph_body.append(op)

# Construct output node
output_node = OutputOp()
output_node.name = "output"
Expand Down Expand Up @@ -189,12 +190,12 @@ def construct_main_graph(self, do_param_pack=False):
self._graph.node_table[output].tensor_meta["dtype"]
)
main_graph.body.append(func_node)

# Adding placeholder operations from the original graph
for op in self._graph.body:
if isinstance(op, PlaceholderOp):
main_graph.body.append(op)

# TODO: analysis topology order to sort subgraph call.
if len(self._subgraphs) == 1:
# Adding CallOp to invoke the single subgraph
Expand All @@ -215,18 +216,63 @@ def construct_main_graph(self, do_param_pack=False):

# Adding GetItemOps to retrieve individual output tensors
output_node = OutputOp()
for i, output in enumerate(list(self._subgraphs_outputs.values())[0]):
for i, output in enumerate(
list(self._subgraphs_outputs.values())[0]
):
getitem_node = GetItemOp()
getitem_node.add_argument(call_node.name)
getitem_node.add_argument(i)
getitem_node.name = "getitem{}".format(i)
output_node.add_argument(getitem_node.name)
main_graph.body.append(getitem_node)

# Marking the final output of the main graph
output_node.name = "output"
main_graph.body.append(output_node)

# Importing the main graph
with ir.Location.unknown(ir.Context()):
main_importer = GraphImporter(
main_graph.body,
main_graph._fake_params,
main_graph._inputs,
main_graph._func_name,
main_graph._ops_registry,
do_param_pack,
)
return main_importer.import_main_graph()
else:
for i in range(len(self._subgraphs) - 1):
# Adding CallOp to invoke the single subgraph
call_node = CallOp()
call_node.call_func_name = list(self._subgraphs.keys())[i]
call_node.name = call_node.call_func_name
call_node.tensor_meta = {"shape": [], "dtype": []}
for inp in list(self._subgraphs_inputs.values())[i]:
call_node.add_argument(inp)
for output in list(self._subgraphs_outputs.values())[i]:
call_node.tensor_meta["shape"].append(
self._graph.node_table[output].tensor_meta["shape"]
)
call_node.tensor_meta["dtype"].append(
self._graph.node_table[output].tensor_meta["dtype"]
)
main_graph.body.append(call_node)

# Adding GetItemOps to retrieve individual output tensors
output_node = OutputOp()
for i, output in enumerate(list(self._subgraphs_outputs.values())[0]):
getitem_node = GetItemOp()
getitem_node.add_argument(call_node.name)
getitem_node.add_argument(i)
getitem_node.name = "getitem{}".format(i)
output_node.add_argument(getitem_node.name)
main_graph.body.append(getitem_node)

# Marking the final output of the main graph
output_node.name = "output"
main_graph.body.append(output_node)

# Importing the main graph
with ir.Location.unknown(ir.Context()):
main_importer = GraphImporter(
Expand Down
10 changes: 10 additions & 0 deletions frontend/Python/graph/transform/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ def MergeFromTo(self, child, parent):
if child == parent:
return
parent.num_nodes += child.num_nodes
self._graph.op_groups[parent.name][:0] = self._graph.op_groups[child.name]
del self._graph.op_groups[child.name]
child.parent = parent
if child.master_ref is not None:
assert parent.master_ref is None
Expand Down Expand Up @@ -386,6 +388,14 @@ def fcond0(kind, issink):
node.name, node.num_nodes, node.master_ref.name
)
)
if node.master_ref.name not in self._graph.op_groups:
self._graph.op_groups[node.master_ref.name] = []
self._graph.group_map_device = {
node.master_ref.name: DeviceType.UNKNOW
}
self._graph.op_groups[node.master_ref.name].append(
self._graph.node_table[node.name]
)


def my_fuse_ops_test(graph: Graph):
Expand Down

0 comments on commit 98d6a54

Please sign in to comment.