diff --git a/examples/BuddyLeNet/.gitignore b/examples/BuddyLeNet/.gitignore index 8ef196d742..6e7b67ebf8 100644 --- a/examples/BuddyLeNet/.gitignore +++ b/examples/BuddyLeNet/.gitignore @@ -3,8 +3,11 @@ log.ll log.s data *.data +*.json +*.dot __pycache__ *.pth lenet.mlir forward.mlir subgraph0.mlir +subgraph1.mlir diff --git a/examples/BuddyLeNet/CMakeLists.txt b/examples/BuddyLeNet/CMakeLists.txt index b765218c68..1902384f92 100644 --- a/examples/BuddyLeNet/CMakeLists.txt +++ b/examples/BuddyLeNet/CMakeLists.txt @@ -1,7 +1,7 @@ add_custom_command( - OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/arg0.data + OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph1.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/arg0.data COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/buddy-lenet-import.py - COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files" + COMMENT "Generating forward.mlir, subgraph0.mlir, subgraph1.mlir and parameter files" ) add_custom_command( @@ -50,13 +50,61 @@ add_custom_command( COMMENT "Building subgraph0.o" VERBATIM) -add_library(LENET STATIC subgraph0.o forward.o) +set(ONE_SHOT_BUFFERIZE_OPTION "bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map") +set(LOWER_TO_NVVM_OPTION "cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin") +add_custom_command( + OUTPUT subgraph1.o + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph1.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt + -one-shot-bufferize=${ONE_SHOT_BUFFERIZE_OPTION} + -buffer-deallocation + -convert-linalg-to-parallel-loops + -canonicalize + -gpu-map-parallel-loops + -convert-parallel-loops-to-gpu + -gpu-kernel-outlining + -canonicalize + -cse | + ${BUDDY_BINARY_DIR}/buddy-opt -convert-memcpy-to-gpu -gpu-async-region -canonicalize | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -llvm-request-c-wrappers --test-lower-to-nvvm=${LOWER_TO_NVVM_OPTION} | + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph1.o + DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph1.mlir + COMMENT "Building subgraph1.o" + VERBATIM) +set(ONE_SHOT_BUFFERIZE_OPTION "bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map") +set(LOWER_TO_NVVM_OPTION "cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin") +add_custom_command( + OUTPUT subgraph1.o + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph1.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt + -one-shot-bufferize=${ONE_SHOT_BUFFERIZE_OPTION} + -buffer-deallocation + -convert-linalg-to-parallel-loops + -canonicalize + -gpu-map-parallel-loops + -convert-parallel-loops-to-gpu + -gpu-kernel-outlining + -canonicalize + -cse | + ${BUDDY_BINARY_DIR}/buddy-opt -convert-memcpy-to-gpu -gpu-async-region -canonicalize | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -llvm-request-c-wrappers --test-lower-to-nvvm=${LOWER_TO_NVVM_OPTION} | + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph1.o + DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph1.mlir + COMMENT "Building subgraph1.o" + VERBATIM) + +add_library(LENET STATIC subgraph0.o subgraph1.o forward.o) SET_TARGET_PROPERTIES(LENET PROPERTIES LINKER_LANGUAGE C) add_executable(buddy-lenet-run buddy-lenet-main.cpp) target_link_directories(buddy-lenet-run PRIVATE ${LLVM_LIBRARY_DIR}) -set(BUDDY_LENET_LIBS LENET mlir_c_runner_utils ${PNG_LIBRARIES}) - +set(BUDDY_LENET_LIBS LENET mlir_c_runner_utils mlir_async_runtime mlir_runner_utils mlir_cuda_runtime BuddyLibDIP ${PNG_LIBRARIES}) target_link_libraries(buddy-lenet-run ${BUDDY_LENET_LIBS}) diff --git a/examples/BuddyLeNet/buddy-lenet-import.py b/examples/BuddyLeNet/buddy-lenet-import.py index 95e76de253..2ef14649e6 100644 --- a/examples/BuddyLeNet/buddy-lenet-import.py +++ b/examples/BuddyLeNet/buddy-lenet-import.py @@ -27,8 +27,14 @@ from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.graph import GraphDriver -from buddy.compiler.graph.transform import simply_fuse -from buddy.compiler.ops import tosa +from buddy.compiler.graph.transform import ( + simply_fuse, + gpu_fuse, + custom_partition, +) +from buddy.compiler.graph.type import DeviceType +from buddy.compiler.ops import tosa, gpu +from buddy.compiler.graph.json_decoder import json_to_graph from model import LeNet # Retrieve the LeNet model path from environment variables. @@ -56,13 +62,17 @@ assert len(graphs) == 1 graph = graphs[0] params = dynamo_compiler.imported_params[graph] -pattern_list = [simply_fuse] -graphs[0].fuse_ops(pattern_list) -driver = GraphDriver(graphs[0]) -driver.subgraphs[0].lower_to_top_level_ir() +pattern_list = [custom_partition] +graph.fuse_ops(pattern_list) path_prefix = os.path.dirname(os.path.abspath(__file__)) +driver = GraphDriver(graph) +driver.subgraphs[0].lower_to_top_level_ir() with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file: print(driver.subgraphs[0]._imported_module, file=module_file) +# Add heterogeneous hardware partition +driver.subgraphs[1].lower_to_top_level_ir() +with open(os.path.join(path_prefix, "subgraph1.mlir"), "w") as module_file: + print(driver.subgraphs[1]._imported_module, file=module_file) with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file: print(driver.construct_main_graph(True), file=module_file) diff --git a/examples/BuddyLeNet/makefile b/examples/BuddyLeNet/makefile index fe87b6da1a..f29fcf0769 100644 --- a/examples/BuddyLeNet/makefile +++ b/examples/BuddyLeNet/makefile @@ -20,6 +20,22 @@ MLIR_ASYNC_RUNTIME := ${LLVM_BUILD_DIR}/lib/libmlir_async_runtime.dylib MTRIPLE := x86_64-apple-darwin endif +buddy-gpu-matmul-lower: + @${BUDDY_OPT} subgraph0.mlir \ + -transform-preload-library="transform-library-paths=transform.mlir" \ + -transform-interpreter="entry-point=codegen" \ + -o log.mlir + +buddy-gpu-matmul: + @${BUDDY_OPT} subgraph0.mlir -transform-preload-library="transform-library-paths=transform.mlir" -transform-interpreter="entry-point=codegen" | \ + ${BUDDY_OPT} --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | \ + ${BUDDY_OPT} -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -linalg-bufferize -convert-linalg-to-affine-loops -affine-loop-fusion -affine-parallelize -lower-affine -canonicalize -func-bufferize -arith-bufferize -tensor-bufferize -buffer-deallocation -finalizing-bufferize -canonicalize | \ + ${BUDDY_OPT} -gpu-launch-sink-index-computations -canonicalize -legalize-shmem-outlining -canonicalize | \ + ${BUDDY_OPT} -convert-memcpy-to-gpu -gpu-async-region -canonicalize | \ + ${BUDDY_OPT} -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm -convert-arith-to-llvm --convert-vector-to-llvm -convert-gpu-to-nvvm='has-redux=1' | \ + ${BUDDY_OPT} -llvm-request-c-wrappers -canonicalize -cse -sccp | \ + ${MLIR_OPT} --test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin" -o matmul-cubin.mlir + buddy-lenet-lower: @${BUDDY_OPT} ./fake-lenet.mlir \ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | \ @@ -124,3 +140,4 @@ buddy-lenet-opt-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 9d8c80f014..210815fb1e 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -45,6 +45,7 @@ from .graph import Graph, TensorDType, TensorMeta from .graph.operation import * from .graph.transform import maxpool2d_simplify +from .graph.type import * class DynamoCompiler: @@ -284,6 +285,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): fake_params, self._ops_registry, self._func_name, + DeviceType.CPU, self._verbose ) for gm_node in _gm.graph.nodes: diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index ce35693efd..ddf50f697c 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -23,9 +23,12 @@ import ctypes import functools import numpy as np +import graphviz +import json import mlir.ir as ir import mlir.dialects.func as func +import mlir.dialects.bufferization as buffer from mlir.passmanager import * from mlir.execution_engine import * from mlir import runtime as rt @@ -105,7 +108,8 @@ def __init__( fake_params: List[TensorMeta], ops_registry: dict, func_name: str, - verbose=False + device: DeviceType = DeviceType.CPU, + verbose=False, ) -> None: """ Initializes the Graph. @@ -124,7 +128,7 @@ def __init__( self._inputs = inputs self.node_table: Dict[str, Op] = {} self._fake_params = fake_params - self.device = "cpu" + self.device = device self._imported_module = None self._verbose = verbose self._ops_registry = ops_registry @@ -171,12 +175,13 @@ def init_op_group(self): Returns: - None """ + group = [] for i, op in enumerate(self._body): - if isinstance(op, PlaceholderOp): + if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp): continue group = [op] subgraph_name = "subgraph{}".format(i) - self.group_map_device[subgraph_name] = DeviceType.UNKNOW + self.group_map_device[subgraph_name] = DeviceType.CPU self.op_groups[subgraph_name] = group def fuse_ops(self, pattern_list: List[FunctionType]): @@ -193,13 +198,10 @@ def fuse_ops(self, pattern_list: List[FunctionType]): # TODO: discuss two fuse strategy # 1. fuse ops adapt for DSA(hardware dependent) # 2. common fuse strategy(hardware independent) - - # Initialize operation groups - self.init_op_group() - # Apply fusion patterns for pattern_func in pattern_list: pattern_func(self) + # Initialize operation groups def perform(self, func_list: List[FunctionType]): """ @@ -239,7 +241,9 @@ def lower_to_top_level_ir(self): self._inputs, self._func_name, self._ops_registry, - verbose=self._verbose + False, + self.device, + verbose=self._verbose, ) self._imported_module = fx_importer.import_graph() outputs = fx_importer.get_output_nodes() @@ -327,6 +331,97 @@ def compile(self): self.lower_to_top_level_ir() self.lower_to_llvm_ir() + def to_dot(self): + """ + Converts a buddy graph to a DOT string for visualization. + + Returns: + str: A DOT string representing the buddy graph for visualization. + """ + dot = graphviz.Digraph(comment="Buddy Graph") + for op in self._body: + for child in op._children: + dot.edge(op._name, child) + for op in self._body: + if isinstance(op, PlaceholderOp): + dot.node( + op._name, shape="ellipse", fillcolor="white", style="filled" + ) + elif isinstance(op, OutputOp): + dot.node( + op._name, shape="ellipse", fillcolor="white", style="filled" + ) + elif isinstance(op, MaxPool2dOp): + dot.node(op._name, shape="box", fillcolor="red", style="filled") + else: + dot.node( + op._name, + shape="box", + fillcolor="deepskyblue", + style="filled", + ) + return str(dot) + + def to_json(self): + """ + Converts a buddy graph to a JSON string. + + Returns: + str: A JSON string representing the buddy graph. + """ + json_str = json.dumps(self, cls=BuddyGraphEncoder) + return json_str + + +class BuddyGraphEncoder(json.JSONEncoder): + """ + Custom JSON encoder for converting Buddy Graph objects to JSON strings. + + This encoder handles encoding of Graph, Op, TensorMeta, OpType, TensorDType, + and DeviceType objects to their JSON representation. + + Returns: + JSONEncoder: A JSON encoder instance for Buddy Graph objects. + """ + + def default(self, obj): + if isinstance(obj, Graph): + node_map_device = {} + for subgraph_name, ops in obj.op_groups.items(): + for op in ops: + node_map_device[op.name] = obj.group_map_device[ + subgraph_name + ] + return { + "graph_name": obj._func_name, + "nodes": obj._body, + "device": obj.device, + "params": obj._fake_params, + "inputs": obj._inputs, + "node_map_device": node_map_device, + } + elif isinstance(obj, Op): + return { + "name": obj._name, + "children": obj._children, + "parents": obj._parents, + "arguments": obj._arguments, + "keyword_arguments": obj._keyword_arguments, + "tensor_meta": obj._tensor_meta, + "type": obj._op_type, + "class": obj.__class__.__name__, + } + elif isinstance(obj, TensorMeta): + return {"shape": obj.shape, "dtype": obj.dtype} + elif isinstance(obj, OpType): + return obj._name_ + elif isinstance(obj, TensorDType): + return obj._name_ + elif isinstance(obj, DeviceType): + return obj._value_ + else: + return super().default(obj) + class GraphImporter: """ @@ -350,7 +445,8 @@ def __init__( func_name: str, ops_registry: dict, do_param_pack: bool = False, - verbose=False + device: DeviceType = DeviceType.CPU, + verbose=False, ): """ Initializes the buddy Graph importer. @@ -365,6 +461,7 @@ def __init__( ops_registry = {} self._symbol_table = {} self._body = body + self._device = device self._func_name = func_name self._params = params self._inputs = inputs @@ -471,27 +568,27 @@ def generated_func(*args): elif isinstance(node, PlaceholderOp): self._import_placeholder(node, args_list) elif isinstance(node, GetItemOp): - self._symbol_table[ - (str(node.name), 0) - ] = self._symbol_table[ - (str(node.args[0]), node.args[1]) - ] + self._symbol_table[(str(node.name), 0)] = ( + self._symbol_table[ + (str(node.args[0]), node.args[1]) + ] + ) else: self._import_op(node) new_ops = [op for op in func_op.body.blocks[0].operations] if self._verbose: - print('='*20 + "Graph Node" + "="*20) + print("=" * 20 + "Graph Node" + "=" * 20) print("Node: " + node.name) print("Type: " + str(node._op_type)) print("Arguments: " + str(node.args)) print("Parents: " + str(node._parents)) print("Children: " + str(node._children)) - print('-'*20 + "MLIR OPS" + '-'*20) + print("-" * 20 + "MLIR OPS" + "-" * 20) for op in new_ops: if op not in old_ops: print(op) print("") - + return self._symbol_table.get(("output", 0)) return self._module @@ -540,11 +637,11 @@ def generated_func(*args): elif isinstance(node, PlaceholderOp): self._import_placeholder(node, args_list) elif isinstance(node, GetItemOp): - self._symbol_table[ - (str(node.name), 0) - ] = self._symbol_table[ - (str(node.args[0]), node.args[1]) - ] + self._symbol_table[(str(node.name), 0)] = ( + self._symbol_table[ + (str(node.args[0]), node.args[1]) + ] + ) else: self._import_op(node) diff --git a/frontend/Python/graph/graph_driver.py b/frontend/Python/graph/graph_driver.py index dd37aa12bd..013a9f6e0b 100644 --- a/frontend/Python/graph/graph_driver.py +++ b/frontend/Python/graph/graph_driver.py @@ -21,6 +21,7 @@ # ===--------------------------------------------------------------------------- from mlir import ir +from collections import deque, defaultdict from .graph import Graph, GraphImporter, TensorMeta from .operation import FuncOp, CallOp, PlaceholderOp, OutputOp, GetItemOp @@ -40,6 +41,7 @@ class GraphDriver: - _subgraphs_outputs (dict): A dictionary mapping subgraph names to their output op's result. """ + def __init__(self, graph: Graph) -> None: """ Initialize the GraphDriver object with a given computational graph. @@ -52,6 +54,11 @@ def __init__(self, graph: Graph) -> None: - None """ self._graph = graph + self._subgraph_dependencies = { + subgraph_name: set() + for subgraph_name in list(self._graph.op_groups.keys()) + } + self._call_table = {} ( self._subgraphs, self._subgraphs_inputs, @@ -94,14 +101,15 @@ def build_subgraph_by_group(self): if isinstance(node, OutputOp): for arg in node.args: output_node.append(arg) - - # Identify outputs for each subgraph + + # Identify outputs for each subgraph and build dependencies between subgraphs for subgraph_name in self._graph.op_groups.keys(): subgraphs_outputs[subgraph_name] = [] for op in self._graph.op_groups[subgraph_name]: for key in subgraphs_inputs.keys(): if op.name in subgraphs_inputs[key]: subgraphs_outputs[subgraph_name].append(op.name) + self._subgraph_dependencies[subgraph_name].add(key) if (op.name in output_node) and ( op.name not in subgraphs_outputs[subgraph_name] ): @@ -112,6 +120,7 @@ def build_subgraph_by_group(self): for subgraph_name in self._graph.op_groups.keys(): subgraph_input = [] subgraph_body = [] + subgraph_device = self._graph.group_map_device[subgraph_name] # Construct input placeholder nodes for inp in subgraphs_inputs[subgraph_name]: @@ -127,11 +136,11 @@ def build_subgraph_by_group(self): if inp in node._parents: placeholder_node.add_children(op.name) subgraph_body.append(placeholder_node) - + # Add operations to subgraph body for op in self._graph.op_groups[subgraph_name]: subgraph_body.append(op) - + # Construct output node output_node = OutputOp() output_node.name = "output" @@ -142,7 +151,12 @@ def build_subgraph_by_group(self): # Create subgraph and add it to the dictionary subgraph = Graph( - subgraph_input, [], self._graph._ops_registry, subgraph_name, verbose=self._graph._verbose + subgraph_input, + [], + self._graph._ops_registry, + subgraph_name, + subgraph_device, + verbose=self._graph._verbose, ) subgraph.body = subgraph_body for op in subgraph_body: @@ -151,6 +165,44 @@ def build_subgraph_by_group(self): return subgraphs, subgraphs_inputs, subgraphs_outputs + def topological_sort_subgraph(self): + """ + Performs topological sorting on the subgraphs based on their dependencies. + + Args: + - graph (Graph): The graph from which subgraphs are constructed. + + Returns: + - list: A list of subgraph names in topological order if the graph is acyclic; otherwise, None. + """ + + # Calculate in degree of each subgraph + in_degree = { + subgraph_name: 0 for subgraph_name in list(self._subgraphs.keys()) + } + for src, dests in self._subgraph_dependencies.items(): + for dest in dests: + in_degree[dest] += 1 + + # Topological sorting + queue = deque([node for node in in_degree if in_degree[node] == 0]) + topo_order = [] + + while queue: + node = queue.popleft() + topo_order.append(node) + for child in self._subgraph_dependencies[node]: + in_degree[child] -= 1 + if in_degree[child] == 0: + queue.append(child) + + # TODO: If the custom subgraph partitioning is illegal, further partition the subgraph to make it valid. + return ( + topo_order + if len(topo_order) == len(list(self._subgraphs.keys())) + else None + ) + def construct_main_graph(self, do_param_pack=False): """ Constructs the main computational graph by incorporating subgraphs' call @@ -172,7 +224,7 @@ def construct_main_graph(self, do_param_pack=False): self._graph._fake_params, self._graph._ops_registry, self._graph._func_name, - self._graph._verbose + self._graph._verbose, ) # Adding FuncOp nodes for each subgraph @@ -189,53 +241,68 @@ def construct_main_graph(self, do_param_pack=False): func_node.tensor_meta["dtype"].append( self._graph.node_table[output].tensor_meta["dtype"] ) - main_graph.body.append(func_node) - + main_graph.add_node(func_node) + # Adding placeholder operations from the original graph for op in self._graph.body: if isinstance(op, PlaceholderOp): - main_graph.body.append(op) - - # TODO: analysis topology order to sort subgraph call. - if len(self._subgraphs) == 1: - # Adding CallOp to invoke the single subgraph + main_graph.add_node(op) + + # Analysis topology order to sort subgraph call. + topo_order = self.topological_sort_subgraph() + if topo_order == None: + print("Error : Graph Partitioning is illegal!") + return None + + # Adding CallOp to invoke the single subgraph + for i, subgraph_name in enumerate(topo_order): call_node = CallOp() - call_node.name = "call0" - call_node.call_func_name = list(self._subgraphs.keys())[0] + call_node.name = "call{}".format(i) + call_node.call_func_name = subgraph_name call_node.tensor_meta = {"shape": [], "dtype": []} - for inp in list(self._subgraphs_inputs.values())[0]: - call_node.add_argument(inp) - for output in list(self._subgraphs_outputs.values())[0]: + for inp in self._subgraphs_inputs[subgraph_name]: + if inp in main_graph.node_table: + call_node.add_argument(inp) + continue + for key, value in self._subgraphs_outputs.items(): + if inp in value: + call_node.add_argument( + arg=self._call_table[key].name, + arg_index=value.index(inp), + ) + break + for output in self._subgraphs_outputs[subgraph_name]: call_node.tensor_meta["shape"].append( self._graph.node_table[output].tensor_meta["shape"] ) call_node.tensor_meta["dtype"].append( self._graph.node_table[output].tensor_meta["dtype"] ) - main_graph.body.append(call_node) + self._call_table[subgraph_name] = call_node + main_graph.add_node(call_node) - # Adding GetItemOps to retrieve individual output tensors - output_node = OutputOp() - for i, output in enumerate(list(self._subgraphs_outputs.values())[0]): - getitem_node = GetItemOp() - getitem_node.add_argument(call_node.name) - getitem_node.add_argument(i) - getitem_node.name = "getitem{}".format(i) - output_node.add_argument(getitem_node.name) - main_graph.body.append(getitem_node) - - # Marking the final output of the main graph - output_node.name = "output" - main_graph.body.append(output_node) - - # Importing the main graph - with ir.Location.unknown(ir.Context()): - main_importer = GraphImporter( - main_graph.body, - main_graph._fake_params, - main_graph._inputs, - main_graph._func_name, - main_graph._ops_registry, - do_param_pack, - ) - return main_importer.import_main_graph() + # Adding GetItemOps to retrieve individual output tensors + output_node = OutputOp() + for i, output in enumerate(self._subgraphs_outputs[topo_order[-1]]): + getitem_node = GetItemOp() + getitem_node.add_argument(call_node.name) + getitem_node.add_argument(i) + getitem_node.name = "getitem{}".format(i) + output_node.add_argument(getitem_node.name) + main_graph.add_node(getitem_node) + + # Marking the final output of the main graph + output_node.name = "output" + main_graph.add_node(output_node) + + # Importing the main graph + with ir.Location.unknown(ir.Context()): + main_importer = GraphImporter( + main_graph.body, + main_graph._fake_params, + main_graph._inputs, + main_graph._func_name, + main_graph._ops_registry, + do_param_pack, + ) + return main_importer.import_main_graph() diff --git a/frontend/Python/graph/json_decoder.py b/frontend/Python/graph/json_decoder.py new file mode 100644 index 0000000000..f3a11440ac --- /dev/null +++ b/frontend/Python/graph/json_decoder.py @@ -0,0 +1,132 @@ +# ===- json_decoder.py --------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# This converts the JSON string representing Buddy Graph into a Graph object. +# +# ===--------------------------------------------------------------------------- +import json +from pathlib import Path + +from .graph import Graph, TensorDType, TensorMeta +from .graph_driver import GraphDriver +from .operation import * +from .type import * + +from ..ops.linalg import ops_registry as linalg_ops_registry +from ..ops.tosa import ops_registry as tosa_ops_registry +from ..ops.math import ops_registry as math_ops_registry +from ..ops.func import ops_registry as func_ops_registry + + +def json_to_graph(json_str): + """ + Converts a buddy graph JSON string to a Graph object. + + Args: + json_str (str): The JSON string representing the buddy graph. + + Returns: + Graph: The Graph object created from the JSON data. + """ + + def json_to_tensormeta(json_data): + """ + Convert JSON data to a TensorMeta object. + + Args: + json_data (dict): JSON data representing a TensorMeta object. + + Returns: + TensorMeta: The TensorMeta object created from the JSON data. + """ + if "shape" in json_data: + shape = json_data["shape"] + dtype = next( + ( + member + for member in TensorDType.__members__.values() + if member.value.upper() == json_data["dtype"].upper() + ), + None, + ) + return TensorMeta(shape, dtype) + return {} + + json_data = json.loads(json_str) + _graph = json_data + graph_name = _graph["graph_name"] + inputs = [] + params = [] + for _input in _graph["inputs"]: + inputs.append(json_to_tensormeta(_input)) + for _param in _graph["params"]: + params.append(json_to_tensormeta(_param)) + ops_registry = {} + ops_registry.update(func_ops_registry) + ops_registry.update(linalg_ops_registry) + ops_registry.update(tosa_ops_registry) + ops_registry.update(math_ops_registry) + graph = Graph(inputs, params, ops_registry, graph_name) + graph.device = _graph["device"] + for _node in _graph["nodes"]: + op_class = _node["class"] + op = globals()[op_class]() + + op._name = _node["name"] + op._children = _node["children"] + op._parents = _node["parents"] + op._arguments = _node["arguments"] + op._keyword_arguments = _node["keyword_arguments"] + op._type = next( + ( + member + for member in OpType.__members__.values() + if member.value == _node["type"] + ), + None, + ) + + # TODO : node attr tensor_meta should be Class TensorMeta + if "shape" not in _node["tensor_meta"]: + op._tensor_meta = _node["tensor_meta"] + else: + op._tensor_meta = { + "shape": _node["tensor_meta"]["shape"], + "dtype": next( + ( + member + for member in TensorDType.__members__.values() + if member.value.upper() + == _node["tensor_meta"]["dtype"].upper() + ), + None, + ), + } + graph.add_node(op) + + for i, device in enumerate(list(set(_graph["node_map_device"].values()))): + subgraph_name = "subgraph{}".format(i) + graph.op_groups[subgraph_name] = [] + graph.group_map_device[subgraph_name] = DeviceType(device) + + for node, op_device in _graph["node_map_device"].items(): + op = graph.node_table[node] + for subgraph_name, group_device in graph.group_map_device.items(): + if op_device == group_device.value: + graph.op_groups[subgraph_name].append(op) + break + + return graph diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 0eb31fd961..0ec7930c25 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -81,13 +81,14 @@ def __init__(self) -> None: """ self._name = None self._arguments = [] + self._args_index = [] self._keyword_arguments = {} self._tensor_meta: Dict = {} self._op_type: OpType = None self._children: List[str] = [] self._parents: List[str] = [] - def add_argument(self, arg): + def add_argument(self, arg, arg_index=0): """ Add an input argument to the operation node. @@ -96,6 +97,7 @@ def add_argument(self, arg): The input argument to be added. """ self._arguments.append(arg) + self._args_index.append(arg_index) def add_parent(self, parent: str): """ @@ -125,6 +127,14 @@ def args(self): def kwargs(self): return self._keyword_arguments + @property + def parents(self): + return self._parents + + @property + def children(self): + return self._children + @property def name(self): return self._name diff --git a/frontend/Python/graph/transform/__init__.py b/frontend/Python/graph/transform/__init__.py index d91e0d06b2..95428b3367 100644 --- a/frontend/Python/graph/transform/__init__.py +++ b/frontend/Python/graph/transform/__init__.py @@ -18,5 +18,5 @@ # # ===--------------------------------------------------------------------------- -from .fuse_ops import simply_fuse +from .fuse_ops import simply_fuse, gpu_fuse, custom_partition from .useless_op_eliminate import maxpool2d_simplify diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index ac7d34c99c..7bfd2e8f98 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -26,11 +26,33 @@ # OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType] # OP_TYPE_UNFUSABLE = [OpType.Unfusable, OpType.ConcatType] # OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = [] -# ANCHOR_OP_TYPE = [] +# ANCHOR_OP_TYPE = [] + def simply_fuse(graph: Graph): """ - Function to fuse all operations into one graph. + Function to fuse all operations into one graph. Set the device type to CPU. + + Args: + - graph (Graph): The input graph to be simplified. + + Returns: + - None: Modifies the input graph in place. + """ + new_op_group = [] + device = DeviceType.CPU + for op in graph.body: + if isinstance(op, PlaceholderOp): + continue + new_op_group.append(op) + graph.op_groups = {} + graph.op_groups["subgraph0"] = new_op_group + graph.group_map_device = {"subgraph0": device} + + +def gpu_fuse(graph: Graph): + """ + Function to fuse all operations into one graph. Set the device type to GPU. Args: - graph (Graph): The input graph to be simplified. @@ -39,7 +61,7 @@ def simply_fuse(graph: Graph): - None: Modifies the input graph in place. """ new_op_group = [] - device = DeviceType.UNKNOW + device = DeviceType.GPU for op in graph.body: if isinstance(op, PlaceholderOp): continue @@ -47,3 +69,27 @@ def simply_fuse(graph: Graph): graph.op_groups = {} graph.op_groups["subgraph0"] = new_op_group graph.group_map_device = {"subgraph0": device} + + +def custom_partition(graph: Graph): + """ + Function to custom subgraph partition. + + Args: + - graph (Graph): The input graph to be simplified. + + Returns: + - None: Modifies the input graph in place. + """ + group = [] + for i, op in enumerate(graph._body): + if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp) or i == 25: + continue + group.append(op) + subgraph_name = "subgraph1" + graph.group_map_device[subgraph_name] = DeviceType.CPU + graph.op_groups[subgraph_name] = group + new_group = [graph._body[25]] + subgraph_name = "subgraph0" + graph.group_map_device[subgraph_name] = DeviceType.GPU + graph.op_groups[subgraph_name] = new_group diff --git a/frontend/Python/graph/transform/useless_op_eliminate.py b/frontend/Python/graph/transform/useless_op_eliminate.py index a99dbe02c6..0d176be2df 100644 --- a/frontend/Python/graph/transform/useless_op_eliminate.py +++ b/frontend/Python/graph/transform/useless_op_eliminate.py @@ -42,13 +42,24 @@ def maxpool2d_simplify(graph: Graph): and getitem_node.args[1] == 0 ): new_node = MaxPool2dOp() - new_node.name = getitem_node.name + new_node.name = node.name.replace("_with_indices", "") for arg in node.args: new_node.add_argument(arg) for parent in node._parents: new_node.add_parent(parent) + parent_node = graph.node_table[parent] + for cindex, child in enumerate(parent_node.children): + if child == node.name: + parent_node.children[cindex] = new_node.name for child in getitem_node._children: new_node.add_children(child) + child_node = graph.node_table[child] + for pindex, parent in enumerate(child_node.parents): + if parent == getitem_node.name: + child_node.parents[pindex] = new_node.name + for aindex, arg in enumerate(child_node.args): + if arg == getitem_node.name: + child_node.args[aindex] = new_node.name new_node.tensor_meta["shape"] = getitem_node.tensor_meta[ "shape" ] diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index a7dcc5e11b..e885809d82 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -59,8 +59,8 @@ def call_op(node: CallOp, symbol_table: Dict[Tuple[str, int], ir.Operation]): From Buddy CallOp to MLIR FUNC call operation. """ arguments = [] - for arg in node.args: - input_node = symbol_table.get((str(arg), 0)) + for i, arg in enumerate(node.args): + input_node = symbol_table.get((str(arg), node._args_index[i])) memref_type = ir.MemRefType(input_node.type) stride = [] shape = memref_type.shape diff --git a/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp b/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp index dd50feccf8..e44f21cb6e 100644 --- a/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp +++ b/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp @@ -18,11 +18,9 @@ // //===---------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeRange.h" @@ -30,9 +28,7 @@ #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/raw_ostream.h" #include #include #include @@ -42,11 +38,8 @@ #include #include -#include -#include -#include -#include -#include +#include + using namespace mlir; using namespace vector; @@ -82,6 +75,9 @@ class ConvertMemcpyToGPUPass void ConvertMemcpyToGPUPass::runOnOperation() { auto funcOp = getOperation(); + if (funcOp.isDeclaration() || funcOp.isExternal()) + return; + // Make sure the gpu function is already outlined. funcOp->walk([&](Operation *nestedOp) { if (auto gpuLaunchOp = dyn_cast(nestedOp)) { @@ -90,8 +86,9 @@ void ConvertMemcpyToGPUPass::runOnOperation() { return WalkResult::advance(); }); - std::set unDeallocatedOperations; + std::vector unDeallocatedValue; OpBuilder builder(funcOp->getContext()); + // Copy all function arguments to gpu, needs deallocation if (processArgs) { builder.setInsertionPointToStart(&(funcOp.getBody().front())); @@ -103,23 +100,11 @@ void ConvertMemcpyToGPUPass::runOnOperation() { auto memrefType = dyn_cast(arg.getType()); auto gpuAllocOp = builder.create( builder.getUnknownLoc(), TypeRange({memrefType}), ValueRange({})); - unDeallocatedOperations.insert(&gpuAllocOp); + unDeallocatedValue.push_back(gpuAllocOp->getResult(0)); auto gpuMemcpyOp = builder.create( gpuAllocOp.getLoc(), TypeRange(), ValueRange(), gpuAllocOp.getResult(0), arg); - // Replace all users with GPU memory - auto users = arg.getUsers(); - std::vector usersVec(users.begin(), users.end()); - for (auto user : usersVec) { - // Don't replace memcpy's operand - if (isa(user)) - continue; - for (size_t j = 0; j < user->getNumOperands(); j++) { - if (user->getOperand(j) == arg) { - user->setOperand(j, gpuAllocOp.getResult(0)); - } - } - } + arg.replaceAllUsesExcept(gpuAllocOp->getResult(0), gpuMemcpyOp); } } @@ -149,19 +134,18 @@ void ConvertMemcpyToGPUPass::runOnOperation() { auto gpuAllocOp = builder.create( allocOp->getLoc(), TypeRange({memrefType}), ValueRange({})); - auto users = result.getUsers(); - std::vector usersVec(users.begin(), users.end()); - for (auto user : usersVec) { - for (size_t j = 0; j < user->getNumOperands(); j++) { - // Only the return value will not have dealloc op - if (auto deallocOp = dyn_cast(user)) { - builder.setInsertionPointAfter(deallocOp); - auto gpuDeallocOp = builder.create( - deallocOp->getLoc(), TypeRange(), ValueRange(), - gpuAllocOp.getResult(0)); - deallocOp->erase(); - } else if (user->getOperand(j) == result) { - user->setOperand(j, gpuAllocOp.getResult(0)); + + for (auto user : llvm::make_early_inc_range(result.getUsers())) { + if (auto deallocOp = dyn_cast(user)) { + builder.setInsertionPointAfter(deallocOp); + builder.create(deallocOp->getLoc(), TypeRange(), + ValueRange(), gpuAllocOp.getResult(0)); + deallocOp->erase(); + } else { + for (auto &opOperand : user->getOpOperands()) { + if (opOperand.is(result)) { + opOperand.set(gpuAllocOp.getResult(0)); + } } } } @@ -175,28 +159,8 @@ void ConvertMemcpyToGPUPass::runOnOperation() { builder.setInsertionPointAfter(copyOp); auto gpuMemcpyOp = builder.create( copyOp->getLoc(), TypeRange(), ValueRange(), dst, src); - { - auto users = src.getUsers(); - std::vector usersVec(users.begin(), users.end()); - for (auto user : usersVec) { - for (size_t j = 0; j < user->getNumOperands(); j++) { - if (user->getOperand(j) == src) { - user->setOperand(j, gpuMemcpyOp.getOperand(1)); - } - } - } - } - { - auto users = dst.getUsers(); - std::vector usersVec(users.begin(), users.end()); - for (auto user : usersVec) { - for (size_t j = 0; j < user->getNumOperands(); j++) { - if (user->getOperand(j) == src) { - user->setOperand(j, gpuMemcpyOp.getOperand(0)); - } - } - } - } + src.replaceAllUsesWith(gpuMemcpyOp->getResult(1)); + dst.replaceAllUsesWith(gpuMemcpyOp->getResult(0)); copyOp->erase(); } // Allocate space on GPU and copy global memrefs to GPU, needs deallocation @@ -206,47 +170,34 @@ void ConvertMemcpyToGPUPass::runOnOperation() { auto memrefType = dyn_cast(result.getType()); auto gpuAllocOp = builder.create( getGlobalOp->getLoc(), TypeRange({memrefType}), ValueRange({})); - unDeallocatedOperations.insert(&gpuAllocOp); + unDeallocatedValue.push_back(gpuAllocOp->getResult(0)); + auto src = result; auto dst = gpuAllocOp->getResult(0); auto gpuMemcpyOp = builder.create( gpuAllocOp->getLoc(), TypeRange(), ValueRange(), dst, src); - { - auto users = src.getUsers(); - std::vector usersVec(users.begin(), users.end()); - for (auto user : usersVec) { - if (isa(user)) - continue; - // TODO: replace with src.replaceAllUsesExcept() - for (size_t j = 0; j < user->getNumOperands(); j++) { - if (user->getOperand(j) == src) { - user->setOperand(j, dst); - } - } - } - } + src.replaceAllUsesExcept(dst, gpuMemcpyOp); } // Copy data back to CPU, deallocate GPU, then return else if (auto returnOp = dyn_cast(nestedOp)) { builder.setInsertionPoint(returnOp); - - for (auto *gpuAllocOp : unDeallocatedOperations) { - auto gpuDeallocOp = builder.create( - builder.getUnknownLoc(), TypeRange(), ValueRange(), - gpuAllocOp->getResult(0)); - } - builder.setInsertionPoint(returnOp); for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) { auto val = returnOp->getOperand(i); - auto memRefType = dyn_cast(val.getType()); - auto allocOp = builder.create(builder.getUnknownLoc(), - memRefType); - auto gpuMemcpyOp = builder.create( - allocOp.getLoc(), TypeRange(), ValueRange(), allocOp->getResult(0), - val); - auto gpuDeallocOp = builder.create( - gpuMemcpyOp->getLoc(), TypeRange(), ValueRange(), val); - returnOp->setOperand(i, allocOp->getResult(0)); + if (auto memrefType = dyn_cast(val.getType())) { + auto allocOp = + builder.create(returnOp->getLoc(), memrefType); + builder.create(allocOp.getLoc(), TypeRange(), + ValueRange(), allocOp->getResult(0), + val); + // FIXME: may be leak memory + // auto gpuDeallocOp = builder.create( + // gpuMemcpyOp->getLoc(), TypeRange(), ValueRange(), val); + returnOp->setOperand(i, allocOp->getResult(0)); + } + } + for (auto value : unDeallocatedValue) { + builder.create(returnOp->getLoc(), TypeRange(), + ValueRange(), value); } } return WalkResult::advance(); diff --git a/tests/Conversion/convert-memcpy-to-gpu.mlir b/tests/Conversion/convert-memcpy-to-gpu.mlir index 63edfd8d02..65e9301e4a 100644 --- a/tests/Conversion/convert-memcpy-to-gpu.mlir +++ b/tests/Conversion/convert-memcpy-to-gpu.mlir @@ -1,22 +1,68 @@ -// RUN: buddy-opt -convert-memcpy-to-gpu -canonicalize %s | FileCheck %s +// RUN: buddy-opt -convert-memcpy-to-gpu="process-args=1" %s | FileCheck %s -// CHECK: %memref = gpu.alloc () : memref<32x32xf32> -// CHECK: %memref_0 = gpu.alloc () : memref<32x32xf32> -// CHECK: gpu.dealloc %memref : memref<32x32xf32> -// CHECK: %alloc = memref.alloc() : memref<32x32xf32> -// CHECK: gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> -// CHECK: gpu.dealloc %memref_0 : memref<32x32xf32> +#map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> module attributes {gpu.container_module} { - func.func @matmul(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> { - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index + memref.global "private" constant @__constant_1x10x10xf32 : memref<1x10x10xf32> = dense<1.000000e+00> {alignment = 64 : i64} + func.func @matmul(%arg0: memref<1x10x10xf32>, %arg1: memref<1x10x10xf32>) -> memref<1x10x10xf32> { + // CHECK: %[[d_arg0:.*]] = gpu.alloc () : memref<1x10x10xf32> + // CHECK-NEXT: gpu.memcpy %[[d_arg0]], %arg0 : memref<1x10x10xf32>, memref<1x10x10xf32> + // CHECK: %[[d_arg1:.*]] = gpu.alloc () : memref<1x10x10xf32> + // CHECK-NEXT: gpu.memcpy %[[d_arg1:.*]], %arg1 : memref<1x10x10xf32>, memref<1x10x10xf32> + %c10 = arith.constant 10 : index + %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - gpu.launch_func @matmul_kernel::@matmul_kernel blocks in (%c1, %c1, %c1) threads in (%c64, %c2, %c1) - return %alloc : memref<32x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: %[[h_global_data:.*]] = memref.get_global @__constant_1x10x10xf32 : memref<1x10x10xf32> + // CHECK: %[[d_global_data:.*]] = gpu.alloc () : memref<1x10x10xf32> + // CHECK: gpu.memcpy %[[d_global_data]], %[[h_global_data]] : memref<1x10x10xf32>, memref<1x10x10xf32> + %0 = memref.get_global @__constant_1x10x10xf32 : memref<1x10x10xf32> + // CHECK: %[[d_alloc0:.*]] = gpu.alloc () : memref<1x10x10xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x10x10xf32> + // CHECK: gpu.launch_func + gpu.launch_func @kernel::@fill blocks in (%c10, %c10, %c1) threads in (%c1, %c1, %c1) args(%c1 : index, %c0 : index, %cst : f32, %alloc : memref<1x10x10xf32>) + // CHECK: gpu.launch_func + // CHECK-SAME: %[[d_arg0]] + // CHECK-SAME: %[[d_arg1]] + // CHECK-SAME: %[[d_alloc0]] + gpu.launch_func @kernel::@matmul blocks in (%c10, %c10, %c1) threads in (%c1, %c1, %c1) args(%c1 : index, %c0 : index, %arg0 : memref<1x10x10xf32>, %arg1 : memref<1x10x10xf32>, %alloc : memref<1x10x10xf32>, %c10 : index) + // CHECK: %[[d_alloc1:.*]] = gpu.alloc () : memref<1x10x10xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x10x10xf32> + // CHECK: gpu.launch_func + gpu.launch_func @kernel::@fill blocks in (%c10, %c10, %c1) threads in (%c1, %c1, %c1) args(%c1 : index, %c0 : index, %cst : f32, %alloc_0 : memref<1x10x10xf32>) + // CHECK: gpu.launch_func + // CHECK-SAME: %[[d_global_data]] + // CHECK-SAME: %[[d_alloc0]] + // CHECK-SAME: %[[d_alloc1]] + gpu.launch_func @kernel::@matmul blocks in (%c10, %c10, %c1) threads in (%c1, %c1, %c1) args(%c1 : index, %c0 : index, %0 : memref<1x10x10xf32>, %alloc : memref<1x10x10xf32>, %alloc_0 : memref<1x10x10xf32>, %c10 : index) + // CHECK: %[[d_result:.*]] = gpu.alloc () : memref<1x10x10xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10x10xf32> + // CHECK: gpu.launch_func + gpu.launch_func @kernel::@fill blocks in (%c10, %c10, %c1) threads in (%c1, %c1, %c1) args(%c1 : index, %c0 : index, %cst : f32, %alloc_1 : memref<1x10x10xf32>) + // CHECK: gpu.launch_func + // CHECK-SAME: %[[d_alloc0]] + // CHECK-SAME: %[[d_alloc1]] + // CHECK-SAME: %[[d_result]] + gpu.launch_func @kernel::@matmul blocks in (%c10, %c10, %c1) threads in (%c1, %c1, %c1) args(%c1 : index, %c0 : index, %alloc : memref<1x10x10xf32>, %alloc_0 : memref<1x10x10xf32>, %alloc_1 : memref<1x10x10xf32>, %c10 : index) + // CHECK: gpu.dealloc %[[d_alloc1]] : memref<1x10x10xf32> + memref.dealloc %alloc_0 : memref<1x10x10xf32> + // CHECK: gpu.dealloc %[[d_alloc0]] : memref<1x10x10xf32> + memref.dealloc %alloc : memref<1x10x10xf32> + + // CHECK: %[[h_alloc:.*]] = memref.alloc() : memref<1x10x10xf32> + // CHECK-NEXT: gpu.memcpy %[[h_alloc]], %[[d_result]] : memref<1x10x10xf32>, memref<1x10x10xf32> + + // CHECK: gpu.dealloc %[[d_arg0]] : memref<1x10x10xf32> + // CHECK: gpu.dealloc %[[d_arg1]] : memref<1x10x10xf32> + // CHECK: gpu.dealloc %[[d_global_data]] : memref<1x10x10xf32> + + // CHECK: return %[[h_alloc]] : memref<1x10x10xf32> + return %alloc_1 : memref<1x10x10xf32> } - gpu.module @matmul_kernel { - gpu.func @matmul_kernel() kernel attributes {gpu.known_block_size = array, gpu.known_grid_size = array} { + gpu.module @kernel { + gpu.func @fill(%arg0: index, %arg1: index, %arg2: f32, %arg3: memref<1x10x10xf32>) kernel attributes {gpu.known_block_size = array} { + gpu.return + } + gpu.func @matmul(%arg0: index, %arg1: index, %arg2: memref<1x10x10xf32>, %arg3: memref<1x10x10xf32>, %arg4: memref<1x10x10xf32>, %arg5: index) kernel attributes {gpu.known_block_size = array} { gpu.return } }