From 2f208b91eef10907fafdd4b83f49bdb8fd7870ec Mon Sep 17 00:00:00 2001 From: weilinquan <58080862+weilinquan@users.noreply.github.com> Date: Fri, 10 May 2024 15:12:07 +0800 Subject: [PATCH] [BuddyLeNet] Fix lenet error and format files (#291) * fix lenet error * format files * update cmake file * fix lenet cmake files --- examples/BuddyLeNet/.gitignore | 2 ++ examples/BuddyLeNet/CMakeLists.txt | 33 +++++++++++++------ examples/BuddyLeNet/buddy-lenet-import.py | 26 ++++++++++----- frontend/Python/graph/graph_driver.py | 2 +- frontend/Python/ops/func.py | 25 ++++++++++---- .../FuncBufferize/FuncBufferizePass.cpp | 2 +- 6 files changed, 64 insertions(+), 26 deletions(-) diff --git a/examples/BuddyLeNet/.gitignore b/examples/BuddyLeNet/.gitignore index ab3c57f2c7..8ef196d742 100644 --- a/examples/BuddyLeNet/.gitignore +++ b/examples/BuddyLeNet/.gitignore @@ -6,3 +6,5 @@ data __pycache__ *.pth lenet.mlir +forward.mlir +subgraph0.mlir diff --git a/examples/BuddyLeNet/CMakeLists.txt b/examples/BuddyLeNet/CMakeLists.txt index dc7d98bd91..9698f617bc 100644 --- a/examples/BuddyLeNet/CMakeLists.txt +++ b/examples/BuddyLeNet/CMakeLists.txt @@ -1,20 +1,33 @@ add_custom_command( - OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/lenet.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/arg0.data + OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/arg0.data COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/buddy-lenet-import.py - COMMENT "Generating lenet.mlir and parameter files" + COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files" ) add_custom_command( - OUTPUT lenet.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/lenet.mlir - -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + OUTPUT forward.o + COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | ${LLVM_MLIR_BINARY_DIR}/mlir-opt + -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" | + ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llvm-as | + ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/forward.o + DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir + COMMENT "Building forward.o" + VERBATIM) + +add_custom_command( + OUTPUT subgraph0.o + COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | + ${BUDDY_BINARY_DIR}/buddy-opt -eliminate-empty-tensors -convert-tensor-to-linalg -linalg-bufferize -convert-linalg-to-affine-loops -lower-affine - -func-bufferize + -func-bufferize-dynamic-offset -arith-bufferize -tensor-bufferize -buffer-deallocation @@ -31,12 +44,12 @@ add_custom_command( -reconcile-unrealized-casts | ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/lenet.o - DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/lenet.mlir - COMMENT "Building lenet.o" + ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph0.o + DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir + COMMENT "Building subgraph0.o" VERBATIM) -add_library(LENET STATIC lenet.o) +add_library(LENET STATIC subgraph0.o forward.o) SET_TARGET_PROPERTIES(LENET PROPERTIES LINKER_LANGUAGE C) diff --git a/examples/BuddyLeNet/buddy-lenet-import.py b/examples/BuddyLeNet/buddy-lenet-import.py index b577ae8832..95e76de253 100644 --- a/examples/BuddyLeNet/buddy-lenet-import.py +++ b/examples/BuddyLeNet/buddy-lenet-import.py @@ -19,12 +19,15 @@ # ===--------------------------------------------------------------------------- import os +from pathlib import Path -import numpy +import numpy as np import torch from torch._inductor.decomposition import decompositions as inductor_decomp from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.graph import GraphDriver +from buddy.compiler.graph.transform import simply_fuse from buddy.compiler.ops import tosa from model import LeNet @@ -53,14 +56,21 @@ assert len(graphs) == 1 graph = graphs[0] params = dynamo_compiler.imported_params[graph] -graph.lower_to_top_level_ir(do_params_pack=True) +pattern_list = [simply_fuse] +graphs[0].fuse_ops(pattern_list) +driver = GraphDriver(graphs[0]) +driver.subgraphs[0].lower_to_top_level_ir() path_prefix = os.path.dirname(os.path.abspath(__file__)) -# Write the MLIR module to the file. -with open(os.path.join(path_prefix, "lenet.mlir"), "w") as module_file: - print(graph._imported_module, file=module_file) +with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file: + print(driver.subgraphs[0]._imported_module, file=module_file) +with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file: + print(driver.construct_main_graph(True), file=module_file) -# Concatenate all parameters into a single numpy array and write to a file. -all_param = numpy.concatenate( +params = dynamo_compiler.imported_params[graph] +current_path = os.path.dirname(os.path.abspath(__file__)) + +float32_param = np.concatenate( [param.detach().numpy().reshape([-1]) for param in params] ) -all_param.tofile(os.path.join(path_prefix, "arg0.data")) + +float32_param.tofile(Path(current_path) / "arg0.data") diff --git a/frontend/Python/graph/graph_driver.py b/frontend/Python/graph/graph_driver.py index f1f9b5fa87..50a8869d5a 100644 --- a/frontend/Python/graph/graph_driver.py +++ b/frontend/Python/graph/graph_driver.py @@ -1,4 +1,4 @@ -# ===- graph_driver.py ------------------------------------------------------------- +# ===- graph_driver.py --------------------------------------------------------- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index 2d807eb6dc..ad6e512be7 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -37,8 +37,12 @@ def func_op(node: FuncOp, symbol_table: Dict[Tuple[str, int], ir.Operation]): mlir_dtype = mlir_element_type_get(arg.dtype) stride = [] for dim, dim_size in enumerate(shape): - stride.append(functools.reduce(lambda x, y: x * y, shape[dim+1:]+[1])) - memref_attr = ir.Attribute.parse("strided<{}, offset: ?>".format(stride)) + stride.append( + functools.reduce(lambda x, y: x * y, shape[dim + 1 :] + [1]) + ) + memref_attr = ir.Attribute.parse( + "strided<{}, offset: ?>".format(stride) + ) arguments.append(ir.MemRefType.get(shape, mlir_dtype, memref_attr)) results = [] for i, shape in enumerate(node.tensor_meta["shape"]): @@ -61,8 +65,12 @@ def call_op(node: CallOp, symbol_table: Dict[Tuple[str, int], ir.Operation]): stride = [] shape = memref_type.shape for dim, dim_size in enumerate(shape): - stride.append(functools.reduce(lambda x, y: x * y, shape[dim+1:]+[1])) - memref_attr = ir.Attribute.parse("strided<{}, offset: ?>".format(stride)) + stride.append( + functools.reduce(lambda x, y: x * y, shape[dim + 1 :] + [1]) + ) + memref_attr = ir.Attribute.parse( + "strided<{}, offset: ?>".format(stride) + ) dest = ir.MemRefType.get(shape, memref_type.element_type, memref_attr) cast_op = memref.CastOp(dest, input_node) arguments.append(cast_op) @@ -125,7 +133,9 @@ def param_extract( return memref_subview_op stride = [] for dim, dim_size in enumerate(output_shape): - stride.append(functools.reduce(lambda x, y: x * y, output_shape[dim+1:]+[1])) + stride.append( + functools.reduce(lambda x, y: x * y, output_shape[dim + 1 :] + [1]) + ) memref_attr = ir.Attribute.parse( "strided<{}, offset: {}>".format(stride, offset) ) @@ -143,9 +153,12 @@ def param_extract( None, ) axis = ir.ArrayAttr.get([axis], None) - expand_shape_op = memref.ExpandShapeOp(memref_type, memref_subview_op.result, axis) + expand_shape_op = memref.ExpandShapeOp( + memref_type, memref_subview_op.result, axis + ) return expand_shape_op + ops_registry = { "FuncOp": func_op, "CallOp": call_op, diff --git a/midend/lib/Conversion/FuncBufferize/FuncBufferizePass.cpp b/midend/lib/Conversion/FuncBufferize/FuncBufferizePass.cpp index 6609229bfe..4a14de28ae 100644 --- a/midend/lib/Conversion/FuncBufferize/FuncBufferizePass.cpp +++ b/midend/lib/Conversion/FuncBufferize/FuncBufferizePass.cpp @@ -43,7 +43,6 @@ #include "llvm/Support/Debug.h" #include #include -#include #include using namespace mlir; using namespace mlir::func; @@ -53,6 +52,7 @@ class FuncBufferizeDynamicOffsetPass : public PassWrapper> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncBufferizeDynamicOffsetPass) FuncBufferizeDynamicOffsetPass() = default; llvm::StringRef getArgument() const final { return "func-bufferize-dynamic-offset";