Update graph for op fusion #445

Merged: 4 commits, Jan 3, 2025

Changes from all commits
2 changes: 1 addition & 1 deletion examples/BuddyLeNet/buddy-lenet-import.py
@@ -26,7 +26,7 @@
 
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.graph import GraphDriver
-from buddy.compiler.graph.transform import simply_fuse
+from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
 from buddy.compiler.ops import tosa
 from model import LeNet
 
2 changes: 1 addition & 1 deletion examples/BuddyLlama/import-llama2.py
@@ -28,7 +28,7 @@
 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.ops import tosa
 from buddy.compiler.graph import GraphDriver
-from buddy.compiler.graph.transform import simply_fuse
+from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
 
 # Retrieve the LLaMA model path from environment variables.
 model_path = os.environ.get("LLAMA_MODEL_PATH")
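Note: both example scripts only change their imports here; the call site that picks the fusion pass is outside the diff context. The sketch below shows how the new pass would be selected, assuming the fuse_ops(pattern_list) driver pattern these example scripts already use; the toy torch.nn.Linear model is a stand-in for LeNet/LLaMA and is not part of this PR.

import torch

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion
from buddy.compiler.ops import tosa

# Toy stand-in for the real models in these examples.
model = torch.nn.Linear(4, 4)
data = torch.randn(1, 4)

dynamo_compiler = DynamoCompiler(primary_registry=tosa.ops_registry)
graphs = dynamo_compiler.importer(model, data)
graph = graphs[0]

# Select the fusion strategy for this graph: the new pattern-based pass,
# or [simply_fuse] for the original whole-graph grouping behavior.
pattern_list = [apply_classic_fusion]
graph.fuse_ops(pattern_list)

driver = GraphDriver(graph)
driver.subgraphs[0].lower_to_top_level_ir()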
45 changes: 24 additions & 21 deletions frontend/Python/frontend.py

Most hunks in this file are mechanical black-style cleanups: spaces after colons in dict literals, wrapped long lines, trailing commas, and trailing-whitespace removal in docstrings and comments. Removed/added line pairs below that look identical differ only in whitespace.

@@ -165,9 +165,9 @@ def __init__(
             "cos.default": CosOp,
             "sin.default": SinOp,
             "argmax.default": ArgMaxOp,
-            "split.Tensor":SplitOp,
-            "max.default":MaxOp,
-            "gt.Scalar":GtOp,
+            "split.Tensor": SplitOp,
+            "max.default": MaxOp,
+            "gt.Scalar": GtOp,
             "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp,
             "ge.Scalar": GeOp,
             "gt.Tensor": GreaterThanOp,
@@ -237,7 +237,9 @@ def _create_node(
                 buddy_node.add_argument(str(input_arg))
                 buddy_node.add_parent(str(input_arg))
             elif isinstance(input_arg, torch.dtype):
-                buddy_node.add_argument(self._torch_dtype_translate(str(input_arg)))
+                buddy_node.add_argument(
+                    self._torch_dtype_translate(str(input_arg))
+                )
             else:
                 buddy_node.add_argument(input_arg)
         for user in node_users:
@@ -294,7 +296,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
             nonlocal params_flat
             func_inputs = []
             for i in inputs_pos:
-                # for inp in _inputs[len(params_flat) :]:
+                # for inp in _inputs[len(params_flat) :]:
                 inp = _inputs[i]
                 inp_shape = inp.shape
                 inp_dtype = self._torch_dtype_translate(str(inp.dtype))
@@ -308,7 +310,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
                 fake_params,
                 self._ops_registry,
                 self._func_name,
-                self._verbose
+                self._verbose,
             )
             param_nodes = []
             buffers_nodes = []
@@ -344,10 +346,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
 
                 elif gm_node.op == "output":
                     buddy_node = self._create_node(
-                        gm_node.op,
-                        gm_node.name,
-                        gm_node.args,
-                        node_users
+                        gm_node.op, gm_node.name, gm_node.args, node_users
                     )
 
                 elif gm_node.target is operator.getitem:
@@ -367,7 +366,11 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
                     tensor_meta = gm_node.meta.get("tensor_meta")
                     val = gm_node.meta.get("val")
                     # num_returns = len(gm_node.target._schema.returns)
-                    num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns)
+                    num_returns = (
+                        len(val)
+                        if isinstance(val, list)
+                        else len(gm_node.target._schema.returns)
+                    )
                     if num_returns == 1:
                         node_dtype = self._torch_dtype_translate(
                             str(tensor_meta.dtype)
@@ -477,7 +480,7 @@ def get_lib_extension():
 
 def cast_c_ptr(outdata_ptr, memref_ptr):
     """
-    Casts a C pointer (`outdata_ptr`) to the type of another C pointer 
+    Casts a C pointer (`outdata_ptr`) to the type of another C pointer
     (`memref_ptr`).
 
     Args:
@@ -488,14 +491,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr):
 
     Returns:
         ctypes.POINTER
-        A new C pointer with the type of `memref_ptr`, representing the 
+        A new C pointer with the type of `memref_ptr`, representing the
         same memory location as `outdata_ptr`.
 
     Example:
         outdata = ctypes.pointer(ctypes.c_int())
         memref = ctypes.pointer(ctypes.c_float())
         casted_ptr = cast_c_ptr(outdata, memref)
-        # Now `casted_ptr` points to the same memory location as `outdata`, 
+        # Now `casted_ptr` points to the same memory location as `outdata`,
         but with the type of `memref`.
     """
     outdata_addr = ctypes.addressof(outdata_ptr.contents)
@@ -504,15 +507,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr):
 
 def move_c_ptr(outdata_ptr, memref_ptr):
     """
-    Moves a C pointer (`outdata_ptr`) to the next element in memory, 
-    based on the size of the referenced type in another C pointer 
+    Moves a C pointer (`outdata_ptr`) to the next element in memory,
+    based on the size of the referenced type in another C pointer
     (`memref_ptr`).
 
     Args:
         outdata_ptr: ctypes.POINTER
             The C pointer whose position needs to be moved.
         memref_ptr: ctypes.POINTER
-            The reference C pointer whose type determines the size of each 
+            The reference C pointer whose type determines the size of each
             element for the move.
 
     Returns:
@@ -535,7 +538,7 @@ def exec_buddy_graph(*args):
 
     Returns:
         List[torch.Tensor]
-            The result of executing the graph, represented as a list of 
+            The result of executing the graph, represented as a list of
             output tensors.
     """
     # A list of ctypes pointers representing memory references for input
@@ -548,13 +551,13 @@ def exec_buddy_graph(*args):
         )
         for tensor in args
     ]
-    # A list of ctypes pointers representing memory references for 
+    # A list of ctypes pointers representing memory references for
     # output tensors.
     output_memref = [
         ctypes.pointer(ctypes.pointer(graph._output_descriptor()))
     ]
     args_memref = output_memref + input_memref
-    # Invoke the graph's function using the provided execution engine 
+    # Invoke the graph's function using the provided execution engine
     # and memory references
     ee.invoke(graph._func_name, *args_memref)
 
@@ -571,7 +574,7 @@ def exec_buddy_graph(*args):
         # Move to the next element in memory based on the size of the
         # current output type
         outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0])
-    # Convert each NumPy array to a PyTorch tensor and return the list 
+    # Convert each NumPy array to a PyTorch tensor and return the list
     # of tensors
     return [torch.from_numpy(tensor) for tensor in output_tensor]
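Note: the cast_c_ptr/move_c_ptr docstrings above describe raw-pointer reinterpretation and stepping, but the helper bodies are partially elided by the diff context. Below is a self-contained ctypes illustration of the same two operations, written against the documented semantics rather than the elided implementations:

import ctypes

# cast_c_ptr semantics: reinterpret outdata's address as a pointer of
# memref's type (same memory location, new pointee type).
outdata = ctypes.pointer(ctypes.c_int(7))
memref = ctypes.pointer(ctypes.c_float())
addr = ctypes.addressof(outdata.contents)
casted_ptr = ctypes.cast(addr, type(memref))

# move_c_ptr semantics: advance by sizeof(pointee) bytes, i.e. to the
# next element of memref's element type.
elem_size = ctypes.sizeof(memref.contents)
moved_ptr = ctypes.cast(addr + elem_size, type(memref))

print(casted_ptr, moved_ptr)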
104 changes: 88 additions & 16 deletions frontend/Python/graph/graph.py
@@ -105,7 +105,7 @@
         fake_params: List[TensorMeta],
         ops_registry: dict,
         func_name: str,
-        verbose=False
+        verbose=False,
     ) -> None:
         """
         Initializes the Graph.
@@ -164,6 +164,78 @@ def add_node(self, node: Op):
         self._body.append(node)
         self.node_table[node.name] = node
 
+    def check_delete_node(self, node: Op) -> bool:
+        """
+        Determines if a node exists in the graph and has no child nodes.
+
+        Args:
+            node (Op): The operation node to check for deletion eligibility.
+
+        Returns:
+            bool: True if the node exists in the graph and has no children.
+        """
+        if not (node.name in self.node_table):
+            raise KeyError("node{0} not in graph".format(node.name))
+
+        if len(node._children) == 0:
+            return True
+        return False
+
+    def delete_node(self, node: Op, parents: List[Op]):
+        """
+        Removes a node from the graph and updates its parent nodes accordingly.
+
+        Args:
+            node (Op): The operation node to be deleted from the graph.
+            parents (List[Op]): A list of parent operation nodes that reference the node to be deleted.
+
+        Returns:
+            None
+        """
+        for i in parents:
+            i._children.remove(node.name)
+        node.args.clear()
+        node.kwargs.clear()
+        node._children.clear()
+        self._body.remove(node)
+        self.node_table.pop(node.name)
+
+    def displace_node(self, node: Op, newnode: Op):
+        """
+        Replaces an existing node with a new node in the graph.
+
+        Args:
+            node (Op): The operation node to be replaced.
+            newnode (Op): The new operation node that will replace the existing node.
+
+        Returns:
+            None
+        """
+        newnode._arguments = node.args
+        newnode._keyword_arguments = node.kwargs
+        newnode._tensor_meta = node.tensor_meta
+        newnode._op_type = node._op_type
+
+        for i in node._children:
+            newnode.add_children(i)
+        users = [self.node_table[i] for i in node._children]
+        for user in users:
+            if node.name in user._parents:
+                user._parents[user._parents.index(node.name)] = newnode.name
+            user.args[user.args.index(node.name)] = newnode.name
+        node._children.clear()
+        # deal with parents+args
+        for i in node._parents:
+            newnode.add_parent(i)
+        parents = [self.node_table[i] for i in node._parents]
+        for parent in parents:
+            parent._children[parent._children.index(node.name)] = newnode.name
+        node._parents.clear()
+        # update node table
+        self._body[self._body.index(node)] = newnode
+        self.node_table.pop(node.name)
+        self.node_table[newnode.name] = newnode
+
     def init_op_group(self):
         """
         Initializes operation groups within the graph.
@@ -239,7 +311,7 @@ def lower_to_top_level_ir(self):
             self._inputs,
             self._func_name,
             self._ops_registry,
-            verbose=self._verbose
+            verbose=self._verbose,
         )
         self._imported_module = fx_importer.import_graph()
         outputs = fx_importer.get_output_nodes()
@@ -352,7 +424,7 @@ def __init__(
         func_name: str,
         ops_registry: dict,
         do_param_pack: bool = False,
-        verbose=False
+        verbose=False,
     ):
         """
         Initializes the buddy Graph importer.
@@ -475,27 +547,27 @@ def generated_func(*args):
                 elif isinstance(node, PlaceholderOp):
                     self._import_placeholder(node, args_list)
                 elif isinstance(node, GetItemOp):
-                    self._symbol_table[
-                        (str(node.name), 0)
-                    ] = self._symbol_table[
-                        (str(node.args[0]), node.args[1])
-                    ]
+                    self._symbol_table[(str(node.name), 0)] = (
+                        self._symbol_table[
+                            (str(node.args[0]), node.args[1])
+                        ]
+                    )
                 else:
                     self._import_op(node)
                 new_ops = [op for op in func_op.body.blocks[0].operations]
                 if self._verbose:
-                    print('='*20 + "Graph Node" + "="*20)
+                    print("=" * 20 + "Graph Node" + "=" * 20)
                     print("Node: " + node.name)
                     print("Type: " + str(node._op_type))
                     print("Arguments: " + str(node.args))
                     print("Parents: " + str(node._parents))
                     print("Children: " + str(node._children))
-                    print('-'*20 + "MLIR OPS" + '-'*20)
+                    print("-" * 20 + "MLIR OPS" + "-" * 20)
                     for op in new_ops:
                         if op not in old_ops:
                             print(op)
                     print("")
 
             return self._symbol_table.get(("output", 0))
 
         return self._module
@@ -544,11 +616,11 @@ def generated_func(*args):
                 elif isinstance(node, PlaceholderOp):
                     self._import_placeholder(node, args_list)
                 elif isinstance(node, GetItemOp):
-                    self._symbol_table[
-                        (str(node.name), 0)
-                    ] = self._symbol_table[
-                        (str(node.args[0]), node.args[1])
-                    ]
+                    self._symbol_table[(str(node.name), 0)] = (
+                        self._symbol_table[
+                            (str(node.args[0]), node.args[1])
+                        ]
+                    )
                 else:
                     self._import_op(node)
 
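Note: displace_node and delete_node are the graph-surgery primitives a pattern-based fusion pass needs. The hypothetical pass below sketches how they might combine to rewrite a transpose-feeding-matmul pair into the fused op added in operation.py. The real matching logic of apply_classic_fusion lives in fuse_ops.py, which is not part of this diff; the name fuse_transpose_matmul, the use of PermuteOp as the transpose node, and assignability of Op.name are assumptions for illustration.

from buddy.compiler.graph import Graph
from buddy.compiler.graph.operation import (
    MatmulOp,
    PermuteOp,  # assumed representation of the transpose node
    TransposeMatmulFusedOp,
)

def fuse_transpose_matmul(graph: Graph):
    """Hypothetical pass: fold (permute -> matmul) into one fused node."""
    for node in list(graph._body):
        if not isinstance(node, MatmulOp):
            continue
        lhs = graph.node_table[node.args[0]]
        if not isinstance(lhs, PermuteOp):
            continue
        fused = TransposeMatmulFusedOp()
        fused.name = "fused_" + node.name
        # Take over the matmul's args, parents, children, and metadata.
        graph.displace_node(node, fused)
        # Bypass the transpose: read from its input directly.
        src = lhs.args[0]
        fused.args[fused.args.index(lhs.name)] = src
        fused._parents[fused._parents.index(lhs.name)] = src
        lhs._children.remove(fused.name)
        graph.node_table[src].add_children(fused.name)
        # Drop the transpose once nothing uses it.
        if graph.check_delete_node(lhs):
            parents = [graph.node_table[p] for p in lhs._parents]
            graph.delete_node(lhs, parents)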
6 changes: 6 additions & 0 deletions frontend/Python/graph/operation.py
@@ -154,6 +154,12 @@ def __init__(self) -> None:
         self._op_type = OpType.ReduceType
 
 
+class TransposeMatmulFusedOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ReduceType
+
+
 class GetItemOp(Op):
     def __init__(self) -> None:
         super().__init__()
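Note: the new fused op registers the same OpType.ReduceType category as the class above it in the hunk (presumably MatmulOp, given the hunk position), so existing grouping and lowering logic that inspects _op_type treats the fused node like a plain matmul. A quick sanity check under that assumption:

from buddy.compiler.graph.operation import MatmulOp, TransposeMatmulFusedOp

# Both ops report the same category, so group-partitioning code does not
# need to special-case the fused node.
assert MatmulOp()._op_type == TransposeMatmulFusedOp()._op_type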
2 changes: 1 addition & 1 deletion frontend/Python/graph/transform/__init__.py
@@ -18,5 +18,5 @@
 #
 # ===---------------------------------------------------------------------------
 
-from .fuse_ops import simply_fuse
+from .fuse_ops import simply_fuse, apply_classic_fusion
 from .useless_op_eliminate import maxpool2d_simplify