Skip to content

Commit

Permalink
[frontend] Update graph for op fusion (buddy-compiler#445)
Browse files Browse the repository at this point in the history
---
Co-authored-by: zhxzh-2001 <[email protected]>
  • Loading branch information
WuXintong123 authored and asdf1113 committed Mar 3, 2025
1 parent fdd4b1e commit 731e1b8
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 41 deletions.
2 changes: 1 addition & 1 deletion examples/BuddyLeNet/buddy-lenet-import.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse
from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
from buddy.compiler.ops import tosa
from model import LeNet

Expand Down
2 changes: 1 addition & 1 deletion examples/BuddyLlama/import-llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse
from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion

# Retrieve the LLaMA model path from environment variables.
model_path = os.environ.get("LLAMA_MODEL_PATH")
Expand Down
45 changes: 24 additions & 21 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def __init__(
"cos.default": CosOp,
"sin.default": SinOp,
"argmax.default": ArgMaxOp,
"split.Tensor":SplitOp,
"max.default":MaxOp,
"gt.Scalar":GtOp,
"split.Tensor": SplitOp,
"max.default": MaxOp,
"gt.Scalar": GtOp,
"_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp,
"ge.Scalar": GeOp,
"gt.Tensor": GreaterThanOp,
Expand Down Expand Up @@ -237,7 +237,9 @@ def _create_node(
buddy_node.add_argument(str(input_arg))
buddy_node.add_parent(str(input_arg))
elif isinstance(input_arg, torch.dtype):
buddy_node.add_argument(self._torch_dtype_translate(str(input_arg)))
buddy_node.add_argument(
self._torch_dtype_translate(str(input_arg))
)
else:
buddy_node.add_argument(input_arg)
for user in node_users:
Expand Down Expand Up @@ -294,7 +296,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
nonlocal params_flat
func_inputs = []
for i in inputs_pos:
# for inp in _inputs[len(params_flat) :]:
# for inp in _inputs[len(params_flat) :]:
inp = _inputs[i]
inp_shape = inp.shape
inp_dtype = self._torch_dtype_translate(str(inp.dtype))
Expand All @@ -308,7 +310,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
fake_params,
self._ops_registry,
self._func_name,
self._verbose
self._verbose,
)
param_nodes = []
buffers_nodes = []
Expand Down Expand Up @@ -344,10 +346,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):

elif gm_node.op == "output":
buddy_node = self._create_node(
gm_node.op,
gm_node.name,
gm_node.args,
node_users
gm_node.op, gm_node.name, gm_node.args, node_users
)

elif gm_node.target is operator.getitem:
Expand All @@ -367,7 +366,11 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
tensor_meta = gm_node.meta.get("tensor_meta")
val = gm_node.meta.get("val")
# num_returns = len(gm_node.target._schema.returns)
num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns)
num_returns = (
len(val)
if isinstance(val, list)
else len(gm_node.target._schema.returns)
)
if num_returns == 1:
node_dtype = self._torch_dtype_translate(
str(tensor_meta.dtype)
Expand Down Expand Up @@ -477,7 +480,7 @@ def get_lib_extension():

def cast_c_ptr(outdata_ptr, memref_ptr):
"""
Casts a C pointer (`outdata_ptr`) to the type of another C pointer
Casts a C pointer (`outdata_ptr`) to the type of another C pointer
(`memref_ptr`).
Args:
Expand All @@ -488,14 +491,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr):
Returns:
ctypes.POINTER
A new C pointer with the type of `memref_ptr`, representing the
A new C pointer with the type of `memref_ptr`, representing the
same memory location as `outdata_ptr`.
Example:
outdata = ctypes.pointer(ctypes.c_int())
memref = ctypes.pointer(ctypes.c_float())
casted_ptr = cast_c_ptr(outdata, memref)
# Now `casted_ptr` points to the same memory location as `outdata`,
# Now `casted_ptr` points to the same memory location as `outdata`,
but with the type of `memref`.
"""
outdata_addr = ctypes.addressof(outdata_ptr.contents)
Expand All @@ -504,15 +507,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr):

def move_c_ptr(outdata_ptr, memref_ptr):
"""
Moves a C pointer (`outdata_ptr`) to the next element in memory,
based on the size of the referenced type in another C pointer
Moves a C pointer (`outdata_ptr`) to the next element in memory,
based on the size of the referenced type in another C pointer
(`memref_ptr`).
Args:
outdata_ptr: ctypes.POINTER
The C pointer whose position needs to be moved.
memref_ptr: ctypes.POINTER
The reference C pointer whose type determines the size of each
The reference C pointer whose type determines the size of each
element for the move.
Returns:
Expand All @@ -535,7 +538,7 @@ def exec_buddy_graph(*args):
Returns:
List[torch.Tensor]
The result of executing the graph, represented as a list of
The result of executing the graph, represented as a list of
output tensors.
"""
# A list of ctypes pointers representing memory references for input
Expand All @@ -548,13 +551,13 @@ def exec_buddy_graph(*args):
)
for tensor in args
]
# A list of ctypes pointers representing memory references for
# A list of ctypes pointers representing memory references for
# output tensors.
output_memref = [
ctypes.pointer(ctypes.pointer(graph._output_descriptor()))
]
args_memref = output_memref + input_memref
# Invoke the graph's function using the provided execution engine
# Invoke the graph's function using the provided execution engine
# and memory references
ee.invoke(graph._func_name, *args_memref)

Expand All @@ -571,7 +574,7 @@ def exec_buddy_graph(*args):
# Move to the next element in memory based on the size of the
# current output type
outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0])
# Convert each NumPy array to a PyTorch tensor and return the list
# Convert each NumPy array to a PyTorch tensor and return the list
# of tensors
return [torch.from_numpy(tensor) for tensor in output_tensor]

Expand Down
104 changes: 88 additions & 16 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
fake_params: List[TensorMeta],
ops_registry: dict,
func_name: str,
verbose=False
verbose=False,
) -> None:
"""
Initializes the Graph.
Expand Down Expand Up @@ -164,6 +164,78 @@ def add_node(self, node: Op):
self._body.append(node)
self.node_table[node.name] = node

def check_delete_node(self, node: Op) -> bool:
"""
Determines if a node exists in the graph and has no child nodes.
Args:
node (Op): The operation node to check for deletion eligibility.
Returns:
bool: True if the node exists in the graph and has no children.
"""
if not (node.name in self.node_table):
raise KeyError("node{0} not in graph".format(node.name))

if len(node._children) == 0:
return True
return False

def delete_node(self, node: Op, parents: List[Op]):
"""
Removes a node from the graph and updates its parent nodes accordingly.
Args:
node (Op): The operation node to be deleted from the graph.
parents (List[Op]): A list of parent operation nodes that reference the node to be deleted.
Returns:
None
"""
for i in parents:
i._children.remove(node.name)
node.args.clear()
node.kwargs.clear()
node._children.clear()
self._body.remove(node)
self.node_table.pop(node.name)

def displace_node(self, node: Op, newnode: Op):
"""
Replaces an existing node with a new node in the graph.
Args:
node (Op): The operation node to be replaced.
newnode (Op): The new operation node that will replace the existing node.
Returns:
None
"""
newnode._arguments = node.args
newnode._keyword_arguments = node.kwargs
newnode._tensor_meta = node.tensor_meta
newnode._op_type = node._op_type

for i in node._children:
newnode.add_children(i)
users = [self.node_table[i] for i in node._children]
for user in users:
if node.name in user._parents:
user._parents[user._parents.index(node.name)] = newnode.name
user.args[user.args.index(node.name)] = newnode.name
node._children.clear()
# deal with parents+args
for i in node._parents:
newnode.add_parent(i)
parents = [self.node_table[i] for i in node._parents]
for parent in parents:
parent._children[parent._children.index(node.name)] = newnode.name
node._parents.clear()
# update node table
self._body[self._body.index(node)] = newnode
self.node_table.pop(node.name)
self.node_table[newnode.name] = newnode

def init_op_group(self):
"""
Initializes operation groups within the graph.
Expand Down Expand Up @@ -239,7 +311,7 @@ def lower_to_top_level_ir(self):
self._inputs,
self._func_name,
self._ops_registry,
verbose=self._verbose
verbose=self._verbose,
)
self._imported_module = fx_importer.import_graph()
outputs = fx_importer.get_output_nodes()
Expand Down Expand Up @@ -352,7 +424,7 @@ def __init__(
func_name: str,
ops_registry: dict,
do_param_pack: bool = False,
verbose=False
verbose=False,
):
"""
Initializes the buddy Graph importer.
Expand Down Expand Up @@ -475,27 +547,27 @@ def generated_func(*args):
elif isinstance(node, PlaceholderOp):
self._import_placeholder(node, args_list)
elif isinstance(node, GetItemOp):
self._symbol_table[
(str(node.name), 0)
] = self._symbol_table[
(str(node.args[0]), node.args[1])
]
self._symbol_table[(str(node.name), 0)] = (
self._symbol_table[
(str(node.args[0]), node.args[1])
]
)
else:
self._import_op(node)
new_ops = [op for op in func_op.body.blocks[0].operations]
if self._verbose:
print('='*20 + "Graph Node" + "="*20)
print("=" * 20 + "Graph Node" + "=" * 20)
print("Node: " + node.name)
print("Type: " + str(node._op_type))
print("Arguments: " + str(node.args))
print("Parents: " + str(node._parents))
print("Children: " + str(node._children))
print('-'*20 + "MLIR OPS" + '-'*20)
print("-" * 20 + "MLIR OPS" + "-" * 20)
for op in new_ops:
if op not in old_ops:
print(op)
print("")

return self._symbol_table.get(("output", 0))

return self._module
Expand Down Expand Up @@ -544,11 +616,11 @@ def generated_func(*args):
elif isinstance(node, PlaceholderOp):
self._import_placeholder(node, args_list)
elif isinstance(node, GetItemOp):
self._symbol_table[
(str(node.name), 0)
] = self._symbol_table[
(str(node.args[0]), node.args[1])
]
self._symbol_table[(str(node.name), 0)] = (
self._symbol_table[
(str(node.args[0]), node.args[1])
]
)
else:
self._import_op(node)

Expand Down
6 changes: 6 additions & 0 deletions frontend/Python/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ def __init__(self) -> None:
self._op_type = OpType.ReduceType


class TransposeMatmulFusedOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ReduceType


class GetItemOp(Op):
def __init__(self) -> None:
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion frontend/Python/graph/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
#
# ===---------------------------------------------------------------------------

from .fuse_ops import simply_fuse
from .fuse_ops import simply_fuse, apply_classic_fusion
from .useless_op_eliminate import maxpool2d_simplify
Loading

0 comments on commit 731e1b8

Please sign in to comment.