From 22bb0fabec0f435e6130cfd1fb3f848d8a7bf536 Mon Sep 17 00:00:00 2001 From: matrix72c <60974665+matrix72c@users.noreply.github.com> Date: Sat, 12 Oct 2024 15:54:59 +0800 Subject: [PATCH] [GPU] Add basic GPU support and example (#381) Co-authored-by: SForeKeeper --- examples/BuddyGPU/.gitignore | 1 + examples/BuddyGPU/README.md | 40 ++ examples/BuddyGPU/makefile | 14 + examples/BuddyGPU/run-module-gpu.py | 147 +++++ examples/BuddyGPU/transform.mlir | 290 +++++++++- midend/include/Dialect/CMakeLists.txt | 1 + midend/include/Dialect/GPU/CMakeLists.txt | 4 + midend/include/Dialect/GPU/TransformOps.h | 74 +++ midend/include/Dialect/GPU/TransformOps.td | 127 +++++ midend/include/Utils/GPUUtils.h | 104 ++++ midend/lib/Conversion/CMakeLists.txt | 1 + midend/lib/Conversion/MLIRGPU/CMakeLists.txt | 28 + .../Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp | 263 +++++++++ .../MLIRGPU/LegalizeShmemOutlining.cpp | 433 ++++++++++++++ midend/lib/Dialect/CMakeLists.txt | 1 + midend/lib/Dialect/GPU/CMakeLists.txt | 42 ++ midend/lib/Dialect/GPU/TransformOps.cpp | 211 +++++++ midend/lib/Utils/CMakeLists.txt | 32 +- midend/lib/Utils/GPUUtils.cpp | 536 ++++++++++++++++++ tests/Conversion/convert-memcpy-to-gpu.mlir | 23 + .../Conversion/legalize-shmem-outlining.mlir | 26 + .../Dialect/BuddyGPU/hoist-static-alloc.mlir | 92 +++ ...transform-dialect-vector-to-nvgpu-mma.mlir | 97 ++++ tools/buddy-opt/CMakeLists.txt | 6 + tools/buddy-opt/buddy-opt.cpp | 13 +- 25 files changed, 2602 insertions(+), 4 deletions(-) create mode 100644 examples/BuddyGPU/README.md create mode 100644 examples/BuddyGPU/run-module-gpu.py create mode 100644 midend/include/Dialect/GPU/CMakeLists.txt create mode 100644 midend/include/Dialect/GPU/TransformOps.h create mode 100644 midend/include/Dialect/GPU/TransformOps.td create mode 100644 midend/include/Utils/GPUUtils.h create mode 100644 midend/lib/Conversion/MLIRGPU/CMakeLists.txt create mode 100644 midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp create mode 100644 midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp create mode 100644 midend/lib/Dialect/GPU/CMakeLists.txt create mode 100644 midend/lib/Dialect/GPU/TransformOps.cpp create mode 100644 midend/lib/Utils/GPUUtils.cpp create mode 100644 tests/Conversion/convert-memcpy-to-gpu.mlir create mode 100644 tests/Conversion/legalize-shmem-outlining.mlir create mode 100644 tests/Dialect/BuddyGPU/hoist-static-alloc.mlir create mode 100644 tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir diff --git a/examples/BuddyGPU/.gitignore b/examples/BuddyGPU/.gitignore index 0194ea7a6..d82aeb33b 100644 --- a/examples/BuddyGPU/.gitignore +++ b/examples/BuddyGPU/.gitignore @@ -1,3 +1,4 @@ log.mlir log.ll log.s +matmul-cubin.mlir diff --git a/examples/BuddyGPU/README.md b/examples/BuddyGPU/README.md new file mode 100644 index 000000000..7c4081e40 --- /dev/null +++ b/examples/BuddyGPU/README.md @@ -0,0 +1,40 @@ +# Buddy GPU Example +This example demonstrates how to use the Buddy GPU to run a simple single-kernel program. + +## Matmul +The example program is a simple matrix multiplication kernel. The linalg definition is in the `matmul.mlir` file. +A transform sequence is in `transform.mlir` to optimize this kernel and prepare it for execution on the GPU. +The `matmul-cubin.mlir` provides a lowered file, in case the pipeline is not working. 
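+
+For reference, the kernel in `matmul.mlir` is a plain `linalg.matmul` fed by a `linalg.fill` of the accumulator. A minimal sketch (assuming the 5376x2048 times 2048x5376 f32 shapes targeted by `transform.mlir`; the actual definition in `matmul.mlir` is authoritative):
+```
+func.func @matmul(%A: tensor<5376x2048xf32>, %B: tensor<2048x5376xf32>) -> tensor<5376x5376xf32> {
+  // Zero-initialize the 5376x5376 accumulator, then multiply.
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<5376x5376xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<5376x5376xf32>) -> tensor<5376x5376xf32>
+  %res = linalg.matmul ins(%A, %B : tensor<5376x2048xf32>, tensor<2048x5376xf32>)
+                       outs(%fill : tensor<5376x5376xf32>) -> tensor<5376x5376xf32>
+  return %res : tensor<5376x5376xf32>
+}
+```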
+ +Run the following command to compile and run the program: +``` + make buddy-gpu-matmul + python run-module-gpu.py --source matmul.mlir --target matmul-cubin.mlir --llvm_dir ../../llvm +``` + +The result should be: +``` +[[502.9141 499.7761 511.35623 ... 500.9083 505.25574 511.03818] + [499.57034 494.8066 506.427 ... 492.7868 497.22513 509.95612] + [511.2017 516.017 513.631 ... 515.5991 515.6389 521.8318 ] + ... + [496.2721 496.3155 506.08054 ... 502.36798 505.94202 516.3577 ] + [512.06866 505.80127 518.81934 ... 510.64966 510.10333 531.85364] + [501.23514 500.17123 505.71808 ... 496.4447 500.5735 514.4204 ]] +[[503.26013 500.11093 511.70193 ... 501.24622 505.60373 511.38376] + [499.89877 495.13043 506.762 ... 493.1151 497.5555 510.29483] + [511.54883 516.35547 513.9717 ... 515.944 515.9865 522.1828 ] + ... + [496.59937 496.63785 506.41483 ... 502.70337 506.27927 516.6994 ] + [512.4154 506.1411 519.17175 ... 510.9929 510.45322 532.2152 ] + [501.57388 500.5093 506.06213 ... 496.7807 500.91638 514.77124]] +MLIR equal to NumPy? True +``` + +As the tensorcore doesn't support fp32 computation, the operands are converted to tf32, hence the result is not exactly the same as the PyTorch result. + +### Profiling +You need to install nsight compute first. +``` +ncu -o profile-result --set full python run-module-gpu.py --source matmul.mlir --target matmul-cubin.mlir --llvm_dir ../../llvm +``` \ No newline at end of file diff --git a/examples/BuddyGPU/makefile b/examples/BuddyGPU/makefile index 677396d1d..5dbd9c25c 100644 --- a/examples/BuddyGPU/makefile +++ b/examples/BuddyGPU/makefile @@ -1,8 +1,22 @@ #!/bin/bash BUDDY_OPT := ../../build/bin/buddy-opt +MLIR_OPT := ../../llvm/build/bin/mlir-opt +MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate +MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner +LLC := ../../llvm/build/bin/llc buddy-gpu-matmul-lower: @${BUDDY_OPT} matmul.mlir \ -transform-preload-library="transform-library-paths=transform.mlir" \ -transform-interpreter="entry-point=codegen" \ -o log.mlir + +buddy-gpu-matmul: + @${BUDDY_OPT} matmul.mlir -transform-preload-library="transform-library-paths=transform.mlir" -transform-interpreter="entry-point=codegen" | \ + ${BUDDY_OPT} --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | \ + ${BUDDY_OPT} -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -linalg-bufferize -convert-linalg-to-affine-loops -affine-loop-fusion -affine-parallelize -lower-affine -canonicalize -func-bufferize -arith-bufferize -tensor-bufferize -buffer-deallocation -finalizing-bufferize -canonicalize | \ + ${BUDDY_OPT} -gpu-launch-sink-index-computations -canonicalize -legalize-shmem-outlining -canonicalize | \ + ${BUDDY_OPT} -convert-memcpy-to-gpu -gpu-async-region -canonicalize | \ + ${BUDDY_OPT} -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm -convert-arith-to-llvm --convert-vector-to-llvm -convert-gpu-to-nvvm='has-redux=1' | \ + ${BUDDY_OPT} -llvm-request-c-wrappers -canonicalize -cse -sccp | \ + ${MLIR_OPT} --test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin" -o matmul-cubin.mlir diff --git a/examples/BuddyGPU/run-module-gpu.py b/examples/BuddyGPU/run-module-gpu.py new file mode 100644 index 000000000..7f3b2c1e7 --- /dev/null +++ b/examples/BuddyGPU/run-module-gpu.py @@ -0,0 +1,147 @@ +# ===- run-module-gpu.py --------------------------------------------------===// +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===----------------------------------------------------------------------===// +# +# This file is a script to test whether the specified MLIR module on the GPU +# calculates the same result as NumPy. +# +# ===----------------------------------------------------------------------===// +import mlir.ir as ir +import mlir.dialects.func as func +import mlir.dialects.memref as memref +from mlir.passmanager import * +from mlir.execution_engine import * +from mlir import runtime as rt +from mlir.ir import * +import numpy as np +import ctypes +import ml_dtypes +import argparse as ap + + +def to_numpy(element_type: str) -> np.dtype: + match element_type: + case "f16": + return np.float16 + case "f32": + return np.float32 + case "f64": + return np.float64 + case "i8": + return np.int8 + case "i16": + return np.int16 + case "i32": + return np.int32 + case "i64": + return np.int64 + case "bf16": + return np.dtype("bfloat16") + case _: + raise ValueError(f"Unsupported type: {element_type}") + + +def new_ranked_memref_descriptor(nparray: np.ndarray): + if nparray.dtype == "bfloat16": + ctp = rt.F16 + else: + ctp = rt.as_ctype(nparray.dtype) + + if nparray.ndim == 0: + x = rt.make_zero_d_memref_descriptor(ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + return x + + x = rt.make_nd_memref_descriptor(nparray.ndim, ctp)() + nbytes = nparray.nbytes + buffer = ctypes.create_string_buffer(nbytes) + ctypes.memmove(buffer, nparray.ctypes.data, nbytes) + x.allocated = ctypes.cast(buffer, ctypes.c_void_p).value + x.aligned = ctypes.cast(buffer, ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. 
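+    # For example, a C-contiguous float32 array of shape (4, 8) has NumPy byte
+    # strides (32, 4); dividing by the 4-byte itemsize yields the element
+    # strides (8, 1) stored in the MLIR memref descriptor below.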
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t( + *[x // nparray.itemsize for x in nparray.strides] + ) + return x + + +def get_memref_descriptors(args: list[Type]): + memref_ptrs = [] + for arg in args: + elem_type = to_numpy(str(arg.element_type)) + np_arg = np.random.rand(*arg.shape).astype(elem_type) + memref_ptrs.append( + ctypes.pointer(ctypes.pointer(new_ranked_memref_descriptor(np_arg))) + ) + return memref_ptrs + + +def test(source, target, llvm_dir): + with Context() as ctx: + file = open(source, "r") + module: Module = Module.parse(file.read()) + funcOp: func.FuncOp = ( + module.operation.regions[0].blocks[0].operations[0] + ) + funcName = str(funcOp.name).replace('"', "") + assert isinstance(funcOp, func.FuncOp) + args_type: list[Type] = [arg.type for arg in funcOp.arguments] + res_type = funcOp.type.results + + file = open(target, "r") + # newModule = lower_to_llvm_cpu(module) + newModule = Module.parse(file.read()) + memref_ptrs = get_memref_descriptors(res_type + args_type) + + engine = ExecutionEngine( + newModule, + shared_libs=[ + "/usr/lib/libomp.so", + llvm_dir + "/build/lib/libmlir_c_runner_utils.so", + llvm_dir + "/build/lib/libmlir_async_runtime.so", + llvm_dir + "/build/lib/libmlir_runner_utils.so", + llvm_dir + "/build/lib/libmlir_cuda_runtime.so", + ], + opt_level=3, + ) + engine.invoke(funcName, *memref_ptrs) + out = rt.ranked_memref_to_numpy(memref_ptrs[0][0]) + if str(res_type[0].element_type) == "bf16": + print("Running on BF16 mode, skipping numpy comparison.") + else: + print(out) + input1 = rt.ranked_memref_to_numpy(memref_ptrs[1][0]) + input2 = rt.ranked_memref_to_numpy(memref_ptrs[2][0]) + numpy_out = np.matmul(input1, input2) + print(numpy_out) + print( + f"MLIR equal to NumPy? {np.allclose(out, numpy_out,rtol=1e-03, atol=1e-03)}" + ) + + +if __name__ == "__main__": + parser = ap.ArgumentParser() + parser.add_argument("--source", type=str, required=True) + parser.add_argument("--target", type=str, required=True) + parser.add_argument("--llvm_dir", type=str, required=True) + args = parser.parse_args() + test(args.source, args.target, args.llvm_dir) diff --git a/examples/BuddyGPU/transform.mlir b/examples/BuddyGPU/transform.mlir index ef2645199..e2a02a9a9 100644 --- a/examples/BuddyGPU/transform.mlir +++ b/examples/BuddyGPU/transform.mlir @@ -9,7 +9,7 @@ module attributes { transform.with_named_sequence } { // Perform tiling for the grid. // For the matrix multiplication of 5376x2048 and 2048x5376, the compilation // strategy sets the tile size for grid-based partitioning to 128x256. - // This means that each 128x256 matmul tile is computed within a GPU block, + // This means that each [128, 2048] @ [2048, 256] matmul tile is computed within a GPU block, // while multiple such blocks are computed in parallel across the grid. // `tile_sizes` specify the dimensions of the tiled matmul result. // `%tiled_op` is the tiled matmul operation within the `scf.forall` loop. @@ -18,6 +18,294 @@ module attributes { transform.with_named_sequence } { tile_sizes [128, 256] (mapping = [#gpu.block, #gpu.block]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Perform canonicalization. 
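+    // (After the grid tiling above, the 5376x5376 result is covered by an
+    // scf.forall over a 42 x 21 grid of block tiles: 5376 / 128 = 42 tile rows
+    // and 5376 / 256 = 21 tile columns.)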
+ %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %1 : !transform.any_op + %all_loops = transform.structured.match interface{LoopLikeInterface} + in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + + // Fuse the fill operation into the scf.all op. + %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %fill into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Further tile the tiled matmul + // Tile the third dimension in matmul. + // [128, 2048] @ [2048, 256] matmul is further tiled into [128, 16] @ [16, 256] matmul. + %tiled_linalg_op, %loops = transform.structured.tile_using_for %tiled_op [0, 0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Create pad op and prepare for mapping to GPU. + // Nothing has changed in the operation. + %padded, %pad, %copy = transform.structured.pad %tiled_linalg_op {copy_back_op = "none", pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + // Rewrite tensor.pad into linalg.copy. + %3 = transform.get_producer_of_operand %padded[0] : (!transform.any_op) -> !transform.any_op + %4 = transform.get_producer_of_operand %padded[1] : (!transform.any_op) -> !transform.any_op + %5 = transform.get_producer_of_operand %padded[2] : (!transform.any_op) -> !transform.any_op + %6 = transform.structured.rewrite_in_destination_passing_style %3 : (!transform.any_op) -> !transform.any_op + %7 = transform.structured.rewrite_in_destination_passing_style %4 : (!transform.any_op) -> !transform.any_op + %8 = transform.structured.rewrite_in_destination_passing_style %5 : (!transform.any_op) -> !transform.any_op + + // Tile the linalg.copy op and map it to GPU thread level, + // such that the tiled matrix are copied to GPU shared memory. + // num_threads is different from tile_sizes used above, + // as it specifies the number of tile instead of the size of the tile. + // The first transform tile the [128, 16] into [4, 4], + // and the second transform tile the [16, 256] into [2, 16]. + %tiled_op_0, %forall_op_1 = transform.structured.tile_using_forall %6 num_threads [32, 4](mapping = [#gpu.thread, #gpu.thread]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %tiled_op_2, %forall_op_3 = transform.structured.tile_using_forall %7 num_threads [8, 16](mapping = [#gpu.thread, #gpu.thread]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile the linalg.matmul op and map it to GPU warp level. + %tiled_op_4, %forall_op_5 = transform.structured.tile_using_forall %padded num_threads [2, 2](mapping = [#gpu.warp, #gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Tile the linalg.fill op and map it to GPU warp level. 
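+    // (A 2x2 warp layout gives 4 warps of 32 threads, i.e. 128 threads per
+    // block, matching the block_dims = [64, 2, 1] used when mapping nested
+    // foralls to threads further below.)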
+ %tiled_op_6, %forall_op_7 = transform.structured.tile_using_forall %fused_op num_threads [2, 2](mapping = [#gpu.warp, #gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Perform canonicalization. + %9 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %9 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %9 : !transform.any_op + %all_loops_2 = transform.structured.match interface{LoopLikeInterface} + in %9 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_2 : !transform.any_op + transform.apply_patterns to %9 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Perform vectorization. + // Vectorize the linalg.copy, linalg.fill, and linalg.matmul operations. + %10 = transform.structured.vectorize_children_and_apply_patterns %9 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %10 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %10 : !transform.any_op + %all_loops_3 = transform.structured.match interface{LoopLikeInterface} + in %10 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_3 : !transform.any_op + transform.apply_patterns to %10 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Match bufferization.alloc_tensors inside the forall op + %scf_forall = transform.structured.match ops{["scf.forall"]} attributes{mapping = [#gpu.block, #gpu.block]} in %arg0 : (!transform.any_op) -> !transform.any_op + %alloc_tensor_ops = transform.structured.match ops{["bufferization.alloc_tensor"]} in %scf_forall : (!transform.any_op) -> !transform.any_op + + // Bufferize the alloc_tensor ops to memref.alloc ops. + // The memory_space attribute for GPU Dialect 0 means global memory, 3 means workgroup memory address, 5 means private memory address. + // According to https://discourse.llvm.org/t/rfc-memref-memory-shape-as-attribute/2229 + %buffer, %new_ops = transform.structured.bufferize_to_allocation %alloc_tensor_ops {memory_space = 3 } : !transform.any_op + + // Eliminate empty tensors and erase unnecessary inputs. + transform.structured.eliminate_empty_tensors %arg0 : !transform.any_op + %func_eras = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_eras { + transform.apply_patterns.linalg.erase_unnecessary_inputs + } : !transform.any_op + + // Bufferize the remaining operations in one time. + %11 = transform.bufferization.one_shot_bufferize %arg0 { bufferize_function_boundaries = true, function_boundary_type_conversion = 1 : i32} : (!transform.any_op) -> !transform.any_op + + // Erase dead alloc and stores. + %12 = transform.structured.match ops{["func.func"]} in %11 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %12 : (!transform.any_op) -> () + + // Generate GPU launch. 
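+    // (With the generate_gpu_launch attribute, the block-mapped scf.forall is
+    // wrapped in a gpu.launch op and its loop indices are rewritten to
+    // gpu.block_id values, so the body can later be outlined into a gpu.module.)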
+ %13 = transform.structured.match ops{["func.func"]} in %11 : (!transform.any_op) -> !transform.any_op + %gpu_launch = transform.gpu.map_forall_to_blocks %13 { generate_gpu_launch } : (!transform.any_op) -> !transform.any_op + + // Rewrite bufferized scf.forall ops to distributed gpu.thread_id attribute. + %mapped = transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [64, 2, 1] warp_size = 32 : (!transform.any_op) -> !transform.any_op + + %15 = transform.structured.match ops{["func.func"]} in %11 : (!transform.any_op) -> !transform.any_op + + // Removes unnecessary GPU barriers from the function. + // %15 = transform.buddy.eliminate_gpu_barriers %14 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %15 : !transform.any_op + %all_loops_4 = transform.structured.match interface{LoopLikeInterface} + in %15 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_4 : !transform.any_op + transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Identify static memory allocations within the given region, + // and move them to a higher level (hoisting). + transform.buddy.hoist_static_alloc %15 : (!transform.any_op) -> () + + // Collects patterns for folding memref aliasing ops (memref.subview) into consumer load/store ops (affine.load, memref.load, nvgpu.ldmatrix, vector.load, vector.transfer_read, affine.store, memref.store, etc.) and other ops (e.g., memref.subview). + transform.apply_patterns to %15 { + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + // Collects patterns for extracting address computations from operations with memory accesses such that these memory accesses use only a base pointer. + transform.apply_patterns to %15 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op + // Perform canonicalization. 
+ transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %15 : !transform.any_op + %all_loops_5 = transform.structured.match interface{LoopLikeInterface} + in %15 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_5 : !transform.any_op + transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Adds patterns that unroll vectors to a native tile size for GPUs with mma operations + transform.apply_patterns to %15 { + transform.apply_patterns.buddy.unroll_vectors_gpu_mma_sync + } : !transform.any_op + + // Insert a gpu.barrier after a given scf.for loop + %16 = transform.structured.match ops{["scf.for"]} in %15 : (!transform.any_op) -> !transform.op<"scf.for"> + // transform.buddy.synchronize_loop %16 : (!transform.op<"scf.for">) -> () + + + transform.apply_patterns to %15 { + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + transform.apply_cse to %15 : !transform.any_op + + // Hoist vector.transfer_read / vector.transfer_write pairs out of immediately enclosing scf::ForOp iteratively + // Warning: Deprecated + %17 = transform.structured.hoist_redundant_vector_transfers %15 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + %all_loops_6 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_6 : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // This converts slices of operations containing vector.contract op into + // mma operations, targetting warp level tensorcore operations. + transform.buddy.vector.vector_to_mma_conversion %17 {use_mma_sync} : (!transform.any_op) -> () + + // %18 = transform.buddy.eliminate_gpu_barriers %17 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + %all_loops_7 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_7 : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + %19 = transform.structured.match ops{["gpu.launch"]} in %17 : (!transform.any_op) -> !transform.any_op + %fwfa = transform.structured.match ops{["memref.alloc"]} in %19 : (!transform.any_op) -> !transform.op<"memref.alloc"> + + // Do multi-buffering/array expansion to remove dependencies on the temporary allocation between consecutive loop iterations. 
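+    // (factor = 3 expands the shared-memory staging buffer into three copies
+    // cycled across loop iterations, so a new tile can be loaded while earlier
+    // tiles are still being read; skip_analysis asserts this is known safe.)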
+ transform.memref.multibuffer %fwfa {factor = 3 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !transform.any_op + + transform.apply_patterns to %17 { + transform.apply_patterns.vector.transfer_to_scf full_unroll = true + } : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + %all_loops_8 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_8 : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Convert sync copies to shared memory to async. + // transform.buddy.create_async_groups %17 {use_mma_sync} : (!transform.any_op) -> () + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + %all_loops_9 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_9 : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + + + %20 = transform.structured.match ops{["nvgpu.mma.sync"]} in %17 : (!transform.any_op) -> !transform.any_op + %21 = transform.get_parent_op %20 {deduplicate, op_name = "scf.for"} : (!transform.any_op) -> !transform.any_op + // This applies software pipelining to a given scf.for loop. + // The pipelining strategy will look for a copy to shared memory and pipeline it to overlap it with the rest of the loop. + // %22 = transform.buddy.pipeline_shared_memory_copies %21 {depth = 3 : i64, use_mma_sync, peel_epilogue} : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. 
+ transform.apply_patterns to %17 { + transform.apply_patterns.vector.lower_masks + } : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.vector.materialize_masks + } : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + + %all_loops_10 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_10 : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + transform.yield } } // module diff --git a/midend/include/Dialect/CMakeLists.txt b/midend/include/Dialect/CMakeLists.txt index 8ab8f29f5..afedee5d6 100644 --- a/midend/include/Dialect/CMakeLists.txt +++ b/midend/include/Dialect/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(RVV) add_subdirectory(VectorExp) add_subdirectory(Gemmini) add_subdirectory(Sche) +add_subdirectory(GPU) diff --git a/midend/include/Dialect/GPU/CMakeLists.txt b/midend/include/Dialect/GPU/CMakeLists.txt new file mode 100644 index 000000000..727895982 --- /dev/null +++ b/midend/include/Dialect/GPU/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS TransformOps.td) +mlir_tablegen(TransformOps.h.inc -gen-op-decls) +mlir_tablegen(TransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(TransformOpsIncGen) diff --git a/midend/include/Dialect/GPU/TransformOps.h b/midend/include/Dialect/GPU/TransformOps.h new file mode 100644 index 000000000..d69c467f5 --- /dev/null +++ b/midend/include/Dialect/GPU/TransformOps.h @@ -0,0 +1,74 @@ +//===- TransformOps.h -----------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file defines GPU transform ops for code generation. 
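+// The op definitions themselves live in TransformOps.td; the generated op
+// class declarations are pulled in at the bottom via "GPU/TransformOps.h.inc".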
+// +//===----------------------------------------------------------------------===// + +#ifndef TRANSFORM_OPS_H +#define TRANSFORM_OPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +namespace mlir { +class DialectRegistry; + +namespace func { +class FuncOp; +} + +namespace scf { +class ForallOp; +class IfOp; +class ForOp; +} // namespace scf + +namespace vector { +class VectorDialect; +class WarpExecuteOnLane0Op; +} // namespace vector + +} // namespace mlir + +namespace mlir { +namespace buddy { +void registerBuddyGPUTransformOps(mlir::DialectRegistry ®istry); + +namespace gpu { + +class TransformExtensions + : public mlir::transform::TransformDialectExtension< + TransformExtensions> { +public: + TransformExtensions(); +}; +} // namespace gpu +} // namespace buddy +} // namespace mlir + +#define GET_OP_CLASSES +#include "GPU/TransformOps.h.inc" + +#endif // TRANSFORM_OPS_H diff --git a/midend/include/Dialect/GPU/TransformOps.td b/midend/include/Dialect/GPU/TransformOps.td new file mode 100644 index 000000000..8eb7fac01 --- /dev/null +++ b/midend/include/Dialect/GPU/TransformOps.td @@ -0,0 +1,127 @@ +//===- TransformOps.td ----------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file defines the transform operations of the gpu dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRANSFORM_OPS_TD +#define TRANSFORM_OPS_TD + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" + +// From IREE Common Extension OPs +def HoistStaticAllocOp : Op, + TransformEachOpTrait, + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { + let summary = "Hoist static allocations"; + let description = [{ + Find static allocations and hoist them to the top level. + + #### Return modes + This transform applies static alloc hoisting the whole region of the operand. 
+ + It does not consume the target handle and always return success. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "mlir::buddy::gpu"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::FuncOp funcOp, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def ApplyUnrollVectorsGpuMmaSyncPatternsOp : Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Populate patterns that unroll vectors. TODO: better documentation. + }]; + + let cppNamespace = "mlir::buddy::gpu"; + let assemblyFormat = "attr-dict"; +} + +def VectorToMMAConversionOp : Op, + TransformEachOpTrait, + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + This converts slices of operations containing vector.contract op into + mma operations, targetting warp level tensorcore operations. If the vector + operations are bigger than the native mma size it will first split up those + vector operations. + + Exactly one of use_wmma or use_mma_sync must be specified. + + #### Return modes + + This transform consumes the target handle and produces a result handle. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$use_mma_sync, + UnitAttr:$use_wmma); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + let cppNamespace = "mlir::buddy::gpu"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // TRANSFORM_OPS_TD diff --git a/midend/include/Utils/GPUUtils.h b/midend/include/Utils/GPUUtils.h new file mode 100644 index 000000000..88605fe1d --- /dev/null +++ b/midend/include/Utils/GPUUtils.h @@ -0,0 +1,104 @@ +//===- GPUUtils.h ---------------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file implements GPU dialect specific utility functions for the buddy +// compiler ecosystem. 
+// +//===----------------------------------------------------------------------===// + +#ifndef INCLUDE_UTILS_GPUUTILS_H +#define INCLUDE_UTILS_GPUUTILS_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/TargetParser/Triple.h" + +namespace mlir{ +namespace buddy::gpu{ +static constexpr int32_t kNumGPUDims = 3; +static constexpr int32_t kWarpSize = 32; + +/// Pick an unrolling order that will allow tensorcore operation to reuse LHS +/// register. This is needed to get good performance on sm_80 target. +std::optional> +gpuMmaUnrollOrder(vector::ContractionOp contract); + +/// Helper function to return native size for MMA.SYNC-based operations. +std::optional> getMmaNativeVectorSize(Operation *op); + +/// Return true if the given memref has workgroup memory space. +bool hasSharedMemoryAddressSpace(MemRefType memrefType); + +/// Packs vector of lower precision into a single 32-bit width element. +/// (i.e <2xf16> -> i32 and <4xi8> -> i32) +Value packVectorToSupportedWidth(Location loc, OpBuilder &builder, Value input); + +/// Unpack single scalar element into a target vector type. +/// (i.e i32 -> vector<4xi8> or f32 -> vector<2xf16>) +Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput, + VectorType targetVecType); + +/// Creates an allocation in the entry block of the function if the size is +/// statically bounded. For a static allocation, it returns an allocation +/// of the same size but in the entry basic block. For dynamic (still bounded) +/// allocations creates an allocation, and inserts a subview to match the +/// dynamic shape of the allocation. Returns std::nullopt if the method +/// couldnt creat an allocation in the entry block. +template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + Location loc, MemRefType allocaType, + ValueRange dynamicSizes, + std::optional alignment); + +/// Hoists `allocaOp` to the entry block of the function if the size is +/// statically bounded. For a static allocation, it returns an allocation +/// of the same size but in the entry basic block. For dynamic (still bounded) +/// allocations creates an allocation, and inserts a subview to match the +/// dynamic shape of the allocation. The method returns a value, but +/// does not replace the uses of the `allocaOp`. +template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + AllocLikeOpType allocaOp); + +/// Traverse funcOp and try to hoist every AllocaOp to the entry block of the +/// function if the size is statically bounded. 
+template +void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp); + +} // namespace buddy::gpu +} // namespace mlir + +#endif // INCLUDE_UTILS_GPUUTILS_H diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt index 99254e410..c372b0d80 100644 --- a/midend/lib/Conversion/CMakeLists.txt +++ b/midend/lib/Conversion/CMakeLists.txt @@ -14,3 +14,4 @@ add_subdirectory(LowerLinalgToGemmini) add_subdirectory(SchedulingOnDevices) add_subdirectory(LowerSche) add_subdirectory(FuncBufferize) +add_subdirectory(MLIRGPU) diff --git a/midend/lib/Conversion/MLIRGPU/CMakeLists.txt b/midend/lib/Conversion/MLIRGPU/CMakeLists.txt new file mode 100644 index 000000000..be7148357 --- /dev/null +++ b/midend/lib/Conversion/MLIRGPU/CMakeLists.txt @@ -0,0 +1,28 @@ +add_mlir_library(MLIRGPUPasses + ConvertMemcpyToGPU.cpp + LegalizeShmemOutlining.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRBufferizationDialect + MLIRControlFlowInterfaces + MLIRFuncDialect + MLIRFunctionInterfaces + MLIRInferTypeOpInterface + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTensorDialect + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRSubsetOpInterface + MLIRTransforms + MLIRViewLikeInterface + MLIRSupport + BuddyUtils + MLIRBufferizationTransforms + MLIRGPUDialect +) diff --git a/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp b/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp new file mode 100644 index 000000000..dd50feccf --- /dev/null +++ b/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp @@ -0,0 +1,263 @@ +//===- ConvertMemcpyToGPU.cpp ---------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the pass that converts memcpy to gpu operations. 
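+// Host-side memref.alloc and memref.copy ops become gpu.alloc and gpu.memcpy,
+// and function arguments, globals, and returned buffers are staged through GPU
+// global memory around the outlined kernels.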
+// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// ConvertMemcpyToGPUPass +//===----------------------------------------------------------------------===// + +namespace { + +class ConvertMemcpyToGPUPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertMemcpyToGPUPass) + StringRef getArgument() const final { return "convert-memcpy-to-gpu"; } + StringRef getDescription() const final { + return "Convert memref opertaions to gpu operations."; + } + ConvertMemcpyToGPUPass() = default; + ConvertMemcpyToGPUPass(const ConvertMemcpyToGPUPass &) {} + + Option processArgs{ + *this, "process-args", + llvm::cl::desc("Whether the pass processes the input args."), + llvm::cl::init(true)}; + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; + +void ConvertMemcpyToGPUPass::runOnOperation() { + auto funcOp = getOperation(); + + // Make sure the gpu function is already outlined. 
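+  // A remaining gpu.launch op means kernel outlining has not been run yet,
+  // which this pass relies on: it only rewrites host-side allocations, copies,
+  // and globals.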
+ funcOp->walk([&](Operation *nestedOp) { + if (auto gpuLaunchOp = dyn_cast(nestedOp)) { + nestedOp->emitOpError("The gpu function should be outlined."); + } + return WalkResult::advance(); + }); + + std::set unDeallocatedOperations; + OpBuilder builder(funcOp->getContext()); + // Copy all function arguments to gpu, needs deallocation + if (processArgs) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + unsigned numArgs = funcOp.getNumArguments(); + for (unsigned i = 0; i < numArgs; ++i) { + BlockArgument arg = funcOp.getArgument(i); + // Create a gpu.alloc op, then copy memory to it + // TODO: Move this out of operation, make the copy process async + auto memrefType = dyn_cast(arg.getType()); + auto gpuAllocOp = builder.create( + builder.getUnknownLoc(), TypeRange({memrefType}), ValueRange({})); + unDeallocatedOperations.insert(&gpuAllocOp); + auto gpuMemcpyOp = builder.create( + gpuAllocOp.getLoc(), TypeRange(), ValueRange(), + gpuAllocOp.getResult(0), arg); + // Replace all users with GPU memory + auto users = arg.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + // Don't replace memcpy's operand + if (isa(user)) + continue; + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == arg) { + user->setOperand(j, gpuAllocOp.getResult(0)); + } + } + } + } + } + + funcOp->walk([&](Operation *nestedOp) { + // Replace all allocations with GPU.alloc + if (auto allocOp = dyn_cast(nestedOp)) { + // Rewrite this allocOp to gpu.alloc, change for all users + builder.setInsertionPointAfter(allocOp); + auto result = allocOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + // Filter operations. + if (memorySpace) { + if (auto intMemorySpace = llvm::dyn_cast(memorySpace)) { + if (intMemorySpace.getInt() != 0) { + return WalkResult::advance(); + } + } else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (gpuMemorySpace.getValue() != gpu::AddressSpace::Global) { + return WalkResult::advance(); + } + } else + return WalkResult::advance(); + } + + auto gpuAllocOp = builder.create( + allocOp->getLoc(), TypeRange({memrefType}), ValueRange({})); + auto users = result.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + for (size_t j = 0; j < user->getNumOperands(); j++) { + // Only the return value will not have dealloc op + if (auto deallocOp = dyn_cast(user)) { + builder.setInsertionPointAfter(deallocOp); + auto gpuDeallocOp = builder.create( + deallocOp->getLoc(), TypeRange(), ValueRange(), + gpuAllocOp.getResult(0)); + deallocOp->erase(); + } else if (user->getOperand(j) == result) { + user->setOperand(j, gpuAllocOp.getResult(0)); + } + } + } + allocOp->erase(); + } + // Replace all memory.copy operations with gpu.memcpy + else if (auto copyOp = dyn_cast(nestedOp)) { + auto src = copyOp.getOperand(0); + auto dst = copyOp.getOperand(1); + // Notice: GPU.memcpy has a different src dst order + builder.setInsertionPointAfter(copyOp); + auto gpuMemcpyOp = builder.create( + copyOp->getLoc(), TypeRange(), ValueRange(), dst, src); + { + auto users = src.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == src) { + user->setOperand(j, gpuMemcpyOp.getOperand(1)); + } + } + } + } + { + auto users = dst.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto 
user : usersVec) { + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == src) { + user->setOperand(j, gpuMemcpyOp.getOperand(0)); + } + } + } + } + copyOp->erase(); + } + // Allocate space on GPU and copy global memrefs to GPU, needs deallocation + else if (auto getGlobalOp = dyn_cast(nestedOp)) { + builder.setInsertionPointAfter(getGlobalOp); + auto result = getGlobalOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto gpuAllocOp = builder.create( + getGlobalOp->getLoc(), TypeRange({memrefType}), ValueRange({})); + unDeallocatedOperations.insert(&gpuAllocOp); + auto src = result; + auto dst = gpuAllocOp->getResult(0); + auto gpuMemcpyOp = builder.create( + gpuAllocOp->getLoc(), TypeRange(), ValueRange(), dst, src); + { + auto users = src.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + if (isa(user)) + continue; + // TODO: replace with src.replaceAllUsesExcept() + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == src) { + user->setOperand(j, dst); + } + } + } + } + } + // Copy data back to CPU, deallocate GPU, then return + else if (auto returnOp = dyn_cast(nestedOp)) { + builder.setInsertionPoint(returnOp); + + for (auto *gpuAllocOp : unDeallocatedOperations) { + auto gpuDeallocOp = builder.create( + builder.getUnknownLoc(), TypeRange(), ValueRange(), + gpuAllocOp->getResult(0)); + } + builder.setInsertionPoint(returnOp); + for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) { + auto val = returnOp->getOperand(i); + auto memRefType = dyn_cast(val.getType()); + auto allocOp = builder.create(builder.getUnknownLoc(), + memRefType); + auto gpuMemcpyOp = builder.create( + allocOp.getLoc(), TypeRange(), ValueRange(), allocOp->getResult(0), + val); + auto gpuDeallocOp = builder.create( + gpuMemcpyOp->getLoc(), TypeRange(), ValueRange(), val); + returnOp->setOperand(i, allocOp->getResult(0)); + } + } + return WalkResult::advance(); + }); +} +} // end anonymous namespace. + +namespace mlir { +namespace buddy { +void registerConvertMemcpyToGPUPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp b/midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp new file mode 100644 index 000000000..79638d460 --- /dev/null +++ b/midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp @@ -0,0 +1,433 @@ +//===- LegalizeShmemOutlining.cpp -----------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the pass that legalizes shared memory operations. 
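+// The outlining logic mirrors the upstream GPU kernel outlining pass, except
+// that workgroup-space allocations are rewritten to private memref.global
+// declarations inside the generated gpu.module instead of being passed to the
+// kernel as extra arguments.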
+// +//===----------------------------------------------------------------------===// + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +using namespace mlir; +using namespace vector; + +//===---------------------------------------------------------------------===// +// From mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +//===---------------------------------------------------------------------===// + +namespace mlir { +#define GEN_PASS_DEF_GPULAUNCHSINKINDEXCOMPUTATIONS +#define GEN_PASS_DEF_GPUKERNELOUTLINING +#include "mlir/Dialect/GPU/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +template +static void createForAllDimensions(OpBuilder &builder, Location loc, + SmallVectorImpl &values) { + for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) + values.push_back(builder.create(loc, builder.getIndexType(), dim)); +} + +/// Adds operations generating block/thread ids and grid/block dimensions at the +/// beginning of the `launchFuncOpBody` region. Add mapping from argument in +/// entry block of `launchOpBody`, to the corresponding result value of the +/// added operations. +static void injectGpuIndexOperations(Location loc, Region &launchFuncOpBody, + Region &launchOpBody, IRMapping &map) { + OpBuilder builder(loc->getContext()); + Block &firstBlock = launchOpBody.front(); + builder.setInsertionPointToStart(&launchFuncOpBody.front()); + SmallVector indexOps; + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + // Replace the leading 12 function args with the respective thread/block index + // operations. Iterate backwards since args are erased and indices change. + for (const auto &indexOp : enumerate(indexOps)) + map.map(firstBlock.getArgument(indexOp.index()), indexOp.value()); +} + +/// Return the provided KernelDim3 as an array of i32 constants if possible. 
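+/// A null attribute is returned if any dimension is not a compile-time
+/// constant or does not fit into 32 bits, in which case no bounds are set.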
+static DenseI32ArrayAttr maybeConstantDimsAttr(gpu::KernelDim3 dims) { + SmallVector constants; + MLIRContext *ctx = dims.x.getContext(); + for (Value v : {dims.x, dims.y, dims.z}) { + APInt constValue; + if (!matchPattern(v, m_ConstantInt(&constValue))) + return nullptr; + // In the event someone called for a too-large block or grid dimension, + // don't set bounds as it is likely to cause more confusing behavior. + if (constValue.ugt(std::numeric_limits::max())) + return nullptr; + constants.push_back( + constValue.getLimitedValue(std::numeric_limits::max())); + } + return DenseI32ArrayAttr::get(ctx, constants); +} + +/// Outline the `gpu.launch` operation body into a kernel function. Replace +/// `gpu.terminator` operations by `gpu.return` in the generated function. +/// Set block and grid size bounds if known. +static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp, + StringRef kernelFnName, + SetVector &operands) { + Location loc = launchOp.getLoc(); + // Create a builder with no insertion point, insertion will happen separately + // due to symbol table manipulation. + OpBuilder builder(launchOp.getContext()); + Region &launchOpBody = launchOp.getBody(); + + // Identify uses from values defined outside of the scope of the launch + // operation. + getUsedValuesDefinedAbove(launchOpBody, operands); + + // Create the gpu.func operation. + SmallVector kernelOperandTypes; + kernelOperandTypes.reserve(operands.size()); + for (Value operand : operands) { + kernelOperandTypes.push_back(operand.getType()); + } + FunctionType type = + FunctionType::get(launchOp.getContext(), kernelOperandTypes, {}); + auto outlinedFunc = builder.create( + loc, kernelFnName, type, + TypeRange(ValueRange(launchOp.getWorkgroupAttributions())), + TypeRange(ValueRange(launchOp.getPrivateAttributions()))); + outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + + // If we can infer bounds on the grid and/or block sizes from the arguments + // to the launch op, propagate them to the generated kernel. This is safe + // because multiple launches with the same body are not deduplicated. + if (auto blockBounds = + maybeConstantDimsAttr(launchOp.getBlockSizeOperandValues())) + outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName(), + blockBounds); + if (auto gridBounds = + maybeConstantDimsAttr(launchOp.getGridSizeOperandValues())) + outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownGridSizeAttrName(), + gridBounds); + + IRMapping map; + + // Map the arguments corresponding to the launch parameters like blockIdx, + // threadIdx, etc. + Region &outlinedFuncBody = outlinedFunc.getBody(); + injectGpuIndexOperations(loc, outlinedFuncBody, launchOpBody, map); + + // Map memory attributions from the LaunOp op to the GPUFuncOp attributions. + for (const auto &[launchArg, funcArg] : + llvm::zip(launchOp.getWorkgroupAttributions(), + outlinedFunc.getWorkgroupAttributions())) + map.map(launchArg, funcArg); + for (const auto &[launchArg, funcArg] : + llvm::zip(launchOp.getPrivateAttributions(), + outlinedFunc.getPrivateAttributions())) + map.map(launchArg, funcArg); + + // Map arguments from gpu.launch region to the arguments of the gpu.func + // operation. + Block &entryBlock = outlinedFuncBody.front(); + for (const auto &operand : enumerate(operands)) + map.map(operand.value(), entryBlock.getArgument(operand.index())); + + // Clone the region of the gpu.launch operation into the gpu.func operation. 
+ // TODO: If cloneInto can be modified such that if a mapping for + // a block exists, that block will be used to clone operations into (at the + // end of the block), instead of creating a new block, this would be much + // cleaner. + launchOpBody.cloneInto(&outlinedFuncBody, map); + + // Branch from entry of the gpu.func operation to the block that is cloned + // from the entry block of the gpu.launch operation. + Block &launchOpEntry = launchOpBody.front(); + Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry); + builder.setInsertionPointToEnd(&entryBlock); + builder.create(loc, clonedLaunchOpEntry); + + outlinedFunc.walk([](gpu::TerminatorOp op) { + OpBuilder replacer(op); + replacer.create(op.getLoc()); + op.erase(); + }); + return outlinedFunc; +} + +/// Replace `gpu.launch` operations with an `gpu.launch_func` operation +/// launching `kernelFunc`. The kernel func contains the body of the +/// `gpu.launch` with constant region arguments inlined. +static void convertToLaunchFuncOp(gpu::LaunchOp launchOp, + gpu::GPUFuncOp kernelFunc, + ValueRange operands) { + OpBuilder builder(launchOp); + // The launch op has an optional dynamic shared memory size. If it doesn't + // exist, we use zero. + Value asyncToken = launchOp.getAsyncToken(); + auto launchFunc = builder.create( + launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), + launchOp.getBlockSizeOperandValues(), + launchOp.getDynamicSharedMemorySize(), operands, + asyncToken ? asyncToken.getType() : nullptr, + launchOp.getAsyncDependencies()); + launchOp.replaceAllUsesWith(launchFunc); + launchOp.erase(); +} + +/// Pass that moves the kernel of each LaunchOp into its separate nested module. +/// +/// This pass moves the kernel code of each LaunchOp into a function created +/// inside a nested module. It also creates an external function of the same +/// name in the parent module. +/// +/// The gpu.modules are intended to be compiled to a cubin blob independently in +/// a separate pass. The external functions can then be annotated with the +/// symbol of the cubin accessor function. + +namespace { +class LegalizeShmemOutliningPass + : public PassWrapper> { +public: + std::vector shmemAllocations; + std::map shmemGlobalPairs; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeShmemOutliningPass) + StringRef getArgument() const final { return "legalize-shmem-outlining"; } + StringRef getDescription() const final { + return "Convert shared memory outlining to global memref declaration."; + } + + void runOnOperation() override { + SymbolTable symbolTable(getOperation()); + + bool modified = false; + for (auto func : getOperation().getOps()) { + // Insert just after the function. + Block::iterator insertPt(func->getNextNode()); + + // Collects all allocations for shared memory outside the kernel. + // The collection must happen before the kernel outlining. + // It moves back all shared allocations back into their GPU body + // Allowing the functions to create kernels without shared memory + // as parameters. 
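+      //
+      // Hypothetical before/after sketch of this step (buffer name invented
+      // for illustration):
+      //
+      //   %smem = memref.alloc() : memref<16x32xf32, 3>
+      //   gpu.launch blocks(...) threads(...) {
+      //     ... uses of %smem ...
+      //     gpu.terminator
+      //   }
+      //   memref.dealloc %smem : memref<16x32xf32, 3>
+      //
+      // becomes
+      //
+      //   gpu.launch blocks(...) threads(...) {
+      //     %smem = memref.alloc() : memref<16x32xf32, 3>
+      //     ... uses of %smem ...
+      //     gpu.terminator
+      //   }
+      //
+      // so the outlined kernel does not need to take the shared buffer as an
+      // argument (see tests/Conversion/legalize-shmem-outlining.mlir).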
+ func.walk([&](memref::AllocOp allocOp) { + auto result = allocOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return WalkResult::advance(); + else { + if (auto intMemorySpace = llvm::dyn_cast(memorySpace)) { + if (intMemorySpace.getInt() != 3) { + return WalkResult::advance(); + } + } else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (gpuMemorySpace.getValue() != gpu::AddressSpace::Workgroup) { + return WalkResult::advance(); + } + } else + return WalkResult::advance(); + } + auto users = allocOp->getUsers(); + for (auto user : users) { + if (isa(user)) { + user->erase(); + continue; + } + // Locates the gpu kernel wrapper + auto launchOp = user->getParentOfType(); + OpBuilder builder(launchOp); + builder.setInsertionPointToStart( + &launchOp.getBody().getBlocks().front()); + auto newAllocOp = + builder.create(launchOp.getLoc(), memrefType); + allocOp->replaceAllUsesWith(newAllocOp); + allocOp->erase(); + break; + } + return WalkResult::advance(); + }); + + auto funcWalkResult = func.walk([&](gpu::LaunchOp op) { + SetVector operands; + std::string kernelFnName = + Twine(op->getParentOfType().getName(), "_kernel") + .str(); + + gpu::GPUFuncOp outlinedFunc = + outlineKernelFuncImpl(op, kernelFnName, operands); + + // Create nested module and insert outlinedFunc. The module will + // originally get the same name as the function, but may be renamed on + // insertion into the parent module. + auto kernelModule = createKernelModule(outlinedFunc, symbolTable); + symbolTable.insert(kernelModule, insertPt); + + size_t counter = 0; + // Walk the funcop and replace all shmem allocations with global memref + outlinedFunc->walk([&](memref::AllocOp allocOp) { + auto result = allocOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + allocOp->emitOpError() + << "Found non-shared memory inside a kernel function"; + else { + if (auto intMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (intMemorySpace.getInt() != 3) { + return WalkResult::advance(); + } + } else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (gpuMemorySpace.getValue() != gpu::AddressSpace::Workgroup) { + return WalkResult::advance(); + } + } else + return WalkResult::advance(); + } + + OpBuilder builder(outlinedFunc); + + auto name = Twine("shmem_", std::to_string(counter++)).str(); + + auto globalOp = builder.create( + kernelModule->getLoc(), + /*sym_name=*/name, + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/memrefType, + /*initial_value=*/ElementsAttr(), + /*constant=*/false, + /*alignment=*/builder.getI64IntegerAttr(64)); + // symbolTable.insert(globalOp); + builder.setInsertionPointAfter(allocOp); + Value getGlobalOp = builder.create( + allocOp->getLoc(), globalOp.getType(), name); + allocOp.replaceAllUsesWith(getGlobalOp); + allocOp->erase(); + return WalkResult::advance(); + }); + + // Potentially changes signature, pulling in constants. + convertToLaunchFuncOp(op, outlinedFunc, operands.getArrayRef()); + modified = true; + return WalkResult::advance(); + }); + if (funcWalkResult.wasInterrupted()) + return signalPassFailure(); + } + + // If any new module was inserted in this module, annotate this module as + // a container module. 
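+    //
+    // Sketch (based on tests/Conversion/legalize-shmem-outlining.mlir, not
+    // emitted verbatim here) of the module layout this pass produces:
+    //
+    //   module attributes {gpu.container_module} {
+    //     func.func @matmul(...) -> memref<32x32xf32> {
+    //       ...
+    //       gpu.launch_func @matmul_kernel::@matmul_kernel
+    //           blocks in (%c1, %c1, %c1) threads in (%c64, %c2, %c1)
+    //       return %alloc : memref<32x32xf32>
+    //     }
+    //     gpu.module @matmul_kernel {
+    //       // Workgroup allocations become module-level globals, roughly:
+    //       //   memref.global "private" @shmem_0 : memref<16x32xf32, 3>
+    //       gpu.func @matmul_kernel() kernel attributes {...} {
+    //         ...
+    //         gpu.return
+    //       }
+    //     }
+    //   }
+    //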
+ if (modified) + getOperation()->setAttr(gpu::GPUDialect::getContainerModuleAttrName(), + UnitAttr::get(&getContext())); + } + +private: + /// Returns a gpu.module containing kernelFunc and all callees (recursive). + gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc, + const SymbolTable &parentSymbolTable) { + // TODO: This code cannot use an OpBuilder because it must be inserted into + // a SymbolTable by the caller. SymbolTable needs to be refactored to + // prevent manual building of Ops with symbols in code using SymbolTables + // and then this needs to use the OpBuilder. + auto *context = getOperation().getContext(); + OpBuilder builder(context); + auto kernelModule = builder.create(kernelFunc.getLoc(), + kernelFunc.getName()); + + SymbolTable symbolTable(kernelModule); + symbolTable.insert(kernelFunc); + + SmallVector symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (std::optional symbolUses = + SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = + cast(symbolUse.getSymbolRef()).getValue(); + if (symbolTable.lookup(symbolName)) + continue; + + Operation *symbolDefClone = + parentSymbolTable.lookup(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + symbolTable.insert(symbolDefClone); + } + } + } + + return kernelModule; + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// LegalizeShmemOutliningPass +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace buddy { +void registerLegalizeShmemOutliningPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Dialect/CMakeLists.txt b/midend/lib/Dialect/CMakeLists.txt index 8ab8f29f5..afedee5d6 100644 --- a/midend/lib/Dialect/CMakeLists.txt +++ b/midend/lib/Dialect/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(RVV) add_subdirectory(VectorExp) add_subdirectory(Gemmini) add_subdirectory(Sche) +add_subdirectory(GPU) diff --git a/midend/lib/Dialect/GPU/CMakeLists.txt b/midend/lib/Dialect/GPU/CMakeLists.txt new file mode 100644 index 000000000..b575a44e2 --- /dev/null +++ b/midend/lib/Dialect/GPU/CMakeLists.txt @@ -0,0 +1,42 @@ +add_mlir_library(BuddyGPUTransformOPs + TransformOps.cpp + + DEPENDS + TransformOpsIncGen + + LINK_LIBS PUBLIC + LLVMSupport + BuddyGPUUtils + MLIRAffineDialect + MLIRArithDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRBytecodeWriter + MLIRFuncDialect + MLIRFunctionInterfaces + MLIRGPUDialect + MLIRGPUTransformOps + MLIRNVGPUDialect + MLIRIndexDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgUtils + MLIRMemRefDialect + MLIRNVGPUDialect + MLIRNVGPUTransforms + MLIRParser + MLIRPDLDialect + MLIRPass + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRTensorTransformOps + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRTransforms + MLIRVectorDialect + MLIRVectorToGPU + MLIRVectorTransforms + MLIRViewLikeInterface + MLIRGPUPasses + ) diff --git a/midend/lib/Dialect/GPU/TransformOps.cpp b/midend/lib/Dialect/GPU/TransformOps.cpp new file mode 100644 index 000000000..3e689fc93 --- /dev/null +++ b/midend/lib/Dialect/GPU/TransformOps.cpp @@ -0,0 +1,211 @@ +//===- TransformOps.cpp ---------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the 
License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file implements transform ops for GPU targets. +// +//===----------------------------------------------------------------------===// + +#include "GPU/TransformOps.h" + +#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include +#include + +#include "Utils/GPUUtils.h" + +using namespace mlir; +using namespace mlir::buddy; + +using llvm::dbgs; + +#define DEBUG_TYPE "transform-llvmgpu-extensions" +#define DEBUG_TYPE_ALIAS "transform-llvmgpu-extensions-alias" +#define DEBUG_VECTOR_TO_MMA "transform-llvmgpu-extensions-vector-to-mma" + +#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(dbgs() << '[' << DEBUG_TYPE << "] " << X) +#define DBGS_ALIAS() (dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") +#define DBGS_VECTOR_TO_MMA() (dbgs() << '[' << DEBUG_VECTOR_TO_MMA << "] ") + +buddy::gpu::TransformExtensions::TransformExtensions() { + // CreateAsyncGroupsOp depends on the following two dialects. 
+ declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include "GPU/TransformOps.cpp.inc" + >(); +} + +void buddy::registerBuddyGPUTransformOps(DialectRegistry ®istry) { + registry.addExtensions(); +} + +//===----------------------------------------------------------------------===// +// HoistStaticAllocOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure buddy::gpu::HoistStaticAllocOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, func::FuncOp target, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + hoistStaticallyBoundAllocationsInFunc(rewriter, target); + return DiagnosedSilenceableFailure::success(); +} + +void buddy::gpu::HoistStaticAllocOp::getEffects( + SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getTarget(), effects); + mlir::transform::modifiesPayload(effects); +} + +//===---------------------------------------------------------------------===// +// ApplyUnrollVectorsGpuMmaSyncPatternsOp +//===---------------------------------------------------------------------===// + +static std::optional> +getGPUTensorCoreNativeMmaSyncVectorSize(Operation *op) { + return buddy::gpu::getMmaNativeVectorSize(op); +} + +void buddy::gpu::ApplyUnrollVectorsGpuMmaSyncPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + auto unrollOrder = [](Operation *op) -> std::optional> { + auto contract = dyn_cast(op); + if (!contract) + return std::nullopt; + return gpuMmaUnrollOrder(contract); + }; + vector::populateVectorUnrollPatterns( + patterns, vector::UnrollVectorOptions() + .setNativeShapeFn(getGPUTensorCoreNativeMmaSyncVectorSize) + .setUnrollTraversalOrderFn(unrollOrder)); +} + +//===---------------------------------------------------------------------===// +// VectorToMMAConversionOp +//===---------------------------------------------------------------------===// + +void buddy::gpu::VectorToMMAConversionOp::getEffects( + SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getTarget(), effects); + mlir::transform::modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +buddy::gpu::VectorToMMAConversionOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, Operation *target, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + if (!target->hasTrait()) { + // target->emitOpError( + // "applies only to isolated-from-above targets because it " + // "needs to apply " + // "patterns greedily"); + // return emitDefaultDefiniteFailure(target); + } + + auto funcOp = dyn_cast(target); + if (!funcOp) { + target->emitOpError("Must apply to a func op"); + return emitDefaultDefiniteFailure(target); + } + + if (!(getUseMmaSync() ^ getUseWmma())) { + target->emitOpError( + "Exactly one of use_mma_sync or use_wmma must be specified"); + return emitDefaultDefiniteFailure(target); + } + + MLIRContext *ctx = target->getContext(); + mlir::transform::ErrorCheckingTrackingListener listener(state, *this); + GreedyRewriteConfig config; + config.listener = &listener; + + // Unrolling to native vector size must have previously occurred. + // TODO: Add pattern to propagate the extract through the scf.for + // ops. Convert slice of contract operations to mma_sync/wmma ops. 
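+  //
+  // Rough illustration (assumed tf32 16x8x8 tiles; not emitted verbatim): a
+  // warp-level contraction such as
+  //
+  //   %d = vector.contract {...} %a, %b, %c
+  //       : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32>
+  //
+  // is rewritten to gpu.subgroup_mma_compute when use_wmma is set, or to
+  // nvgpu.mma.sync (with tf32Enabled for f32 inputs) when use_mma_sync is
+  // set; see tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir.
+  //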
+ RewritePatternSet patterns(ctx); + mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + populatePrepareVectorToMMAPatterns(patterns, getUseMmaSync()); + if (failed( + applyPatternsAndFoldGreedily(target, std::move(patterns), config))) { + target->emitOpError("vector to mma preparation patterns failed to apply"); + return emitDefaultDefiniteFailure(target); + } + + auto diag = DiagnosedSilenceableFailure::success(); + if (getUseWmma()) { + if (failed(convertVectorToMMAOps(rewriter, target))) + return mlir::emitDefiniteFailure( + target, "vector to wmma patterns failed to apply"); + return listener.checkAndResetError(); + } + + if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) + return mlir::emitDefiniteFailure(target, + "vector to mma patterns failed to apply"); + + // Using TF32 for Float. + RewritePatternSet f32ToTF32patterns(funcOp.getContext()); + nvgpu::populateMmaSyncF32ToTF32Patterns(f32ToTF32patterns, + nvgpu::MmaSyncF32Lowering::TF32); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(f32ToTF32patterns), + config))) + return mlir::emitDefiniteFailure( + target, "vector to mma F32ToTF32 patterns failed to apply"); + + return listener.checkAndResetError(); +} + +#define GET_OP_CLASSES +#include "GPU/TransformOps.cpp.inc" diff --git a/midend/lib/Utils/CMakeLists.txt b/midend/lib/Utils/CMakeLists.txt index ff9aa6e38..9cedf4f6a 100644 --- a/midend/lib/Utils/CMakeLists.txt +++ b/midend/lib/Utils/CMakeLists.txt @@ -2,12 +2,13 @@ add_mlir_library(BuddyUtils Utils.cpp DIPUtils.cpp DAPUtils.cpp + GPUUtils.cpp AffineTransformUtils.cpp ) add_mlir_library(BuddyDIPUtils DIPUtils.cpp - + LINK_LIBS PUBLIC BuddyUtils ) @@ -18,3 +19,32 @@ add_mlir_library(BuddyDAPUtils LINK_LIBS PUBLIC BuddyUtils ) + +add_mlir_library(BuddyGPUUtils + GPUUtils.cpp + + LINK_LIBS PUBLIC + LLVMSupport + LLVMTargetParser + MLIRAffineDialect + MLIRAffineUtils + MLIRAnalysis + MLIRArithDialect + MLIRArithUtils + MLIRFuncDialect + MLIRGPUDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgUtils + MLIRMemRefDialect + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRSupport + MLIRTensorDialect + MLIRTilingInterface + MLIRTransformUtils + MLIRVectorDialect + MLIRViewLikeInterface + MLIRGPUPasses +) diff --git a/midend/lib/Utils/GPUUtils.cpp b/midend/lib/Utils/GPUUtils.cpp new file mode 100644 index 000000000..82058c881 --- /dev/null +++ b/midend/lib/Utils/GPUUtils.cpp @@ -0,0 +1,536 @@ +//====- GPUUtils.cpp ------------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. 
+// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file implements GPU dialect specific utility functions for the buddy +// compiler ecosystem. +// +//===----------------------------------------------------------------------===// + +#ifndef UTILS_GPUUTILS_DEF +#define UTILS_GPUUTILS_DEF + +#include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Transforms/TopologicalSortUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" + +#include "Utils/GPUUtils.h" + +#include + +#define DEBUG_TYPE "buddy-codegen-gpu-utils" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DBGSNL() (llvm::dbgs() << "\n") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; + +static constexpr unsigned kShuffleBitWidth = 32; + +namespace mlir::buddy { +namespace gpu { + +/// Pick an unrolling order that will allow tensorcore operation to reuse LHS +/// register. This is needed to get good performance on sm_80 target. +std::optional> +gpuMmaUnrollOrder(vector::ContractionOp contract) { + SmallVector order; + // First make reduction the outer dimensions. + for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isReductionIterator(iter)) { + order.push_back(index); + } + } + + llvm::SmallDenseSet dims; + for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) { + dims.insert(expr.cast().getPosition()); + } + // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. + for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isParallelIterator(iter) && dims.count(index)) { + order.push_back(index); + } + } + // Then the remaining parallel loops. + for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isParallelIterator(iter) && !dims.count(index)) { + order.push_back(index); + } + } + return order; +} + +//===----------------------------------------------------------------------===// +// Reduction utils +//===----------------------------------------------------------------------===// + +/// Packs scalar element to it's vector equivalent. 
+/// (i.e f16 -> vector<1xf16> and f32 -> vector<1xf32>) +static Value promoteElementToVector(Location loc, OpBuilder &builder, + Value input) { + VectorType vectorTypeBroadcast = VectorType::get({1}, input.getType()); + Value vectorInput = + builder.create(loc, vectorTypeBroadcast, input); + return vectorInput; +} + +Value packVectorToSupportedWidth(Location loc, OpBuilder &builder, + Value input) { + LLVM_DEBUG({ + auto vecType = input.getType().cast(); + Type elementType = vecType.getElementType(); + assert(vecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() == + kShuffleBitWidth && + "vecSize * vecBitWidth needs to packable into 32-bitwidth."); + assert(elementType.isIntOrFloat() && + "Only int and float packing is supported."); + }); + VectorType packed32Type = VectorType::get({1}, builder.getI32Type()); + Value packedInputVec = + builder.create(loc, packed32Type, input); + Value packedInput = builder.create(loc, packedInputVec, 0); + return packedInput; +} + +Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput, + VectorType targetVecType) { + LLVM_DEBUG({ + Type packedType = packedInput.getType(); + assert(packedType.isIntOrFloat() && "Only ints and floats are unpackable."); + Type elementType = targetVecType.getElementType(); + assert(targetVecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() == + packedType.getIntOrFloatBitWidth() && + "packed width needs to be unpackable to vecSize * vecBitWidth."); + }); + Value packedVector = promoteElementToVector(loc, builder, packedInput); + Value unpackedVector = + builder.create(loc, targetVecType, packedVector); + return unpackedVector; +} + +//===----------------------------------------------------------------------===// +// getMmaNativeVectorSize +//===----------------------------------------------------------------------===// +/// Returns vector::ContractionOp operand's index where the result is used. +static std::optional +getVectorContractOpOperandId(vector::ContractionOp contractOp, + OpResult result) { + if (contractOp.getLhs() == result) + return 0; + if (contractOp.getRhs() == result) + return 1; + if (contractOp.getAcc() == result) + return 2; + return std::nullopt; +} + +/// Returns vector::ContractionOp operand's index where the +/// vector::TransferReadOp is consumed either consumed directly or via +/// vector::ExtractStridedSliceOp. +static std::optional +getVectorContractOpOperandIdForVectorReadOp(Operation *op) { + vector::ContractionOp contractOp; + + // Check if the vector::TransferReadOp is consumed directly by + // vector::ContractionOp. + if (op->use_empty()) + return std::nullopt; + Operation *firstLevelUser = *((op->getUsers()).begin()); + if (!firstLevelUser) + return std::nullopt; + if (auto contractOp = dyn_cast(firstLevelUser)) + return getVectorContractOpOperandId(contractOp, op->getResult(0)); + + // Check if the vector::TransferReadOp is consumed indirectly by + // vector::ContractionOp. Only check until the second level of use-def chain. + if (firstLevelUser->use_empty()) + return std::nullopt; + Operation *secondLevelUser = *((firstLevelUser->getUsers()).begin()); + if (!secondLevelUser) + return std::nullopt; + if (auto contractOp = dyn_cast(secondLevelUser)) + return getVectorContractOpOperandId(contractOp, + firstLevelUser->getResult(0)); + return std::nullopt; +} + +/// Helper function to return native size for MMA.SYNC-based operations. +std::optional> getMmaNativeVectorSize(Operation *op) { + // Shape of native Tensor Core GPU mma.sync operations. 
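+  // For reference, the m-n-k tiles chosen below match the sm_80 mma.sync
+  // shapes: i4 -> 16x8x64, i8 -> 16x8x32, f16/bf16 -> 16x8x16, and
+  // f32 (computed as tf32) -> 16x8x8.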
+ int64_t mmaShapeM = 16; + int64_t mmaShapeN = 8; + int64_t mmaShapeK; + + // Shape the mma.sync warp-level operation. + if (auto contract = dyn_cast(op)) { + Type sourceType = contract.getLhsType().getElementType(); + + // Set mmaShapeK based on sourceType. + if (sourceType.isInteger(4)) + mmaShapeK = 64; + else if (sourceType.isInteger(8)) + mmaShapeK = 32; + else if (sourceType.isF16() || sourceType.isBF16()) + mmaShapeK = 16; + else if (sourceType.isF32()) + mmaShapeK = 8; + else { + LDBG("unsupported shape for vector.contract: "); + return std::nullopt; + } + + // Initialize/set the starting dims of the ranked shape, such as batch, + // to 1. + SmallVector mmaShape(contract.getIteratorTypes().size() - 3, 1); + mmaShape.append({mmaShapeM, mmaShapeN, mmaShapeK}); + LLVM_DEBUG({ + llvm::interleaveComma(mmaShape, DBGS() << "shape for vector.contract: "); + llvm::dbgs() << "\n"; + }); + return mmaShape; + } + + // Shape of warp-level vector write operation. + if (auto writeOp = dyn_cast(op)) { + if (writeOp.getVectorType().getRank() < 2) + return std::nullopt; + SmallVector outputShape(writeOp.getVectorType().getRank() - 2, 1); + outputShape.append({mmaShapeM, mmaShapeN}); + LLVM_DEBUG({ + llvm::interleaveComma(outputShape, + DBGS() << "shape for vector.xfer_write: "); + llvm::dbgs() << "\n"; + }); + return outputShape; + } + + // Shape of warp-level vector read (load) operation. + if (auto readOp = dyn_cast(op)) { + auto resultVectorType = + llvm::cast(readOp.getVector().getType()); + Type resultElementType = resultVectorType.getElementType(); + + std::optional operandId = + getVectorContractOpOperandIdForVectorReadOp(op); + if (!operandId) { + LLVM_DEBUG({ + DBGS() << "Failed to get operandId for vector::xfer_read: " << *op + << "\n"; + }); + return std::nullopt; + } + + // Loading F16 values from Shared Memory to Registers. + if (resultElementType.isF16() || resultElementType.isBF16()) { + // For matrixC. + if (*operandId == 2) { + SmallVector readShape; + readShape.append({mmaShapeM, mmaShapeN}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + + // For matrixA and matrixB. + if (*operandId == 0 || *operandId == 1) { + // MmaSyncOp input operands: matrixA and matrixB. + // LDSMx1, x2, x4: + // - LDSMx1 loads a 1 tile of 8x8. + // - LDSMx2 loads a 2 tiles of 8x8. + // - LDSMx4 loads a 4 tiles of 8x8. (in use) + // here uses the largest tiled load, i.e., LDSMx4. + + // MmaSyncOp source operand: matrixC. + // matrixC is also read/written in tiled block of 16x16. In the pass + // OptimizeVectorTransfer, matrixC reads are moved above the mainloop + // and writes are moved below the mainloop. Thus, mma.sync read/write + // accumulator inplace. + SmallVector readShape; + readShape.append({16, 16}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + } + + // Loading F32 values from Shared Memory to Registers. + if (resultElementType.isF32()) { + // Set mmaShapeK for F32 datatype mma.sync.f32.tf32.m16n8k8. + mmaShapeK = 8; + + // For matrixC. + if (*operandId == 2) { + SmallVector readShape; + readShape.append({mmaShapeM, mmaShapeN}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + // For matrixA. 
+ if (*operandId == 0) { + SmallVector readShape; + readShape.append({mmaShapeM, mmaShapeK}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + // For matrixB. + if (*operandId == 1) { + // Do not use ldmatrix for matrixB. + // Transfer read ops may need different shapes based on how they are + // being used. For simplicity just match the shape used by the extract + // strided op. + VectorType sliceType; + for (Operation *users : op->getUsers()) { + auto extract = dyn_cast(users); + if (!extract) + return std::nullopt; + auto vecType = llvm::cast(extract.getResult().getType()); + if (sliceType && sliceType != vecType) + return std::nullopt; + sliceType = vecType; + } + LLVM_DEBUG({ + llvm::interleaveComma(sliceType.getShape(), + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return llvm::to_vector(sliceType.getShape()); + } + } + } + LDBG("unsupported shape for " << op->getName().getStringRef()); + return std::nullopt; +} + +bool hasSharedMemoryAddressSpace(MemRefType memrefType) { + auto addrSpace = llvm::dyn_cast_if_present( + memrefType.getMemorySpace()); + return addrSpace && addrSpace.getValue() == + mlir::gpu::GPUDialect::getWorkgroupAddressSpace(); +} + +template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + Location loc, MemRefType allocLikeType, + ValueRange dynamicSizes, + std::optional alignment) { + IntegerAttr alignmentAttr = + alignment ? builder.getI64IntegerAttr(alignment.value()) : nullptr; + // For static case just create a new allocation in the entry block of the same + // size. No need to insert a subview. + if (dynamicSizes.empty()) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(&funcOp.getBody().front()); + Value allocation = + builder.create(loc, allocLikeType, alignmentAttr); + if (std::is_same::value) { + builder.setInsertionPoint(funcOp.getBody().front().getTerminator()); + builder.create(loc, allocation); + } + return allocation; + } + + /// For the dynamic but bounded case, insert an allocation of the shape of the + /// bounds, and a subview of the required size to be used as a replacement. 
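+  //
+  // Hypothetical sketch (shapes invented for illustration): with an upper
+  // bound of 16 computed for %size,
+  //
+  //   %buf = memref.alloc(%size, %size) : memref<?x?xi32>
+  //
+  // is replaced by a static allocation hoisted to the entry block plus an
+  // in-place subview of the requested size:
+  //
+  //   %storage = memref.alloc() : memref<16x16xi32>
+  //   ...
+  //   %buf = memref.subview %storage[0, 0] [%size, %size] [1, 1]
+  //
+  // as exercised by tests/Dialect/BuddyGPU/hoist-static-alloc.mlir.
+  //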
+ SmallVector staticShape; + SmallVector subviewSizes; + staticShape.reserve(allocLikeType.getRank()); + subviewSizes.reserve(allocLikeType.getRank()); + + int index = 0; + for (auto dimSize : allocLikeType.getShape()) { + if (!ShapedType::isDynamic(dimSize)) { + staticShape.push_back(dimSize); + subviewSizes.push_back(builder.getIndexAttr(dimSize)); + continue; + } + Value dynamicSize = dynamicSizes[index++]; + auto ub = ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, dynamicSize, /*dim=*/std::nullopt, + /*stopCondition=*/nullptr, /*closedUB=*/true); + if (failed(ub)) { + return std::nullopt; + } + staticShape.push_back(ub.value()); + subviewSizes.push_back(dynamicSize); + } + SmallVector offsets(allocLikeType.getRank(), + builder.getIndexAttr(0)); + SmallVector strides(allocLikeType.getRank(), + builder.getIndexAttr(1)); + + Value allocation; + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(&funcOp.getBody().front()); + auto allocationType = + MemRefType::get(staticShape, allocLikeType.getElementType()); + allocation = + builder.create(loc, allocationType, alignmentAttr); + } + + Value subviewOp = builder.create(loc, allocation, offsets, + subviewSizes, strides); + + if (std::is_same::value) { + builder.setInsertionPoint(funcOp.getBody().front().getTerminator()); + builder.create(loc, allocation); + } + return subviewOp; +} + +template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + AllocLikeOpType allocLikeOp) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(allocLikeOp); + return hoistOneStaticallyBoundAllocation( + funcOp, builder, allocLikeOp.getLoc(), allocLikeOp.getType(), + allocLikeOp.getDynamicSizes(), allocLikeOp.getAlignment()); +} + +/// Some uses of a AllocLike can be replaced with a `memref.subview` +/// easily. Other uses (like a use in a `scf.yield` or `func.return`) are +/// non-trivial because of compatibility between types of different SSA values. +static bool isUseReplaceableWithSubview(OpOperand &use) { + Operation *user = use.getOwner(); + return isa(user); +} + +template +void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp) { + SmallVector allocLikeOps; + + // Collect all allocLikes that are hoistable. + funcOp.walk([&](AllocLikeOpType allocLikeOp) { + if (allocLikeOp->getBlock() == &funcOp.getBody().front()) + return; + if (allocLikeOp.getDynamicSizes().empty()) { + allocLikeOps.push_back(allocLikeOp); + return; + } + if (llvm::all_of(allocLikeOp->getUses(), [](OpOperand &use) { + return isUseReplaceableWithSubview(use); + })) { + allocLikeOps.push_back(allocLikeOp); + return; + } + }); + + // Hoist the allocLikes and replace all uses. + for (auto allocLikeOp : allocLikeOps) { + // Record potential memref::DeallocOps to clean up after hoisting occurs. 
+ SmallVector deallocOps; + for (Operation *user : allocLikeOp->getUsers()) { + auto dealloc = dyn_cast(user); + if (dealloc) + deallocOps.push_back(dealloc); + } + + LLVM_DEBUG({ + llvm::dbgs() << "Alloca Op : "; + allocLikeOp->dump(); + int numUses = std::distance(allocLikeOp.getResult().use_begin(), + allocLikeOp.getResult().use_end()); + llvm::dbgs() << " num Uses : " << numUses; + }); + std::optional replacement = + hoistOneStaticallyBoundAllocation(funcOp, rewriter, allocLikeOp); + if (!replacement) + continue; + LLVM_DEBUG({ + llvm::dbgs() << "Replacement : "; + replacement->dump(); + }); + Value replacementVal = replacement.value(); + rewriter.replaceOp(allocLikeOp, replacementVal); + + for (memref::DeallocOp deallocOp : deallocOps) + rewriter.eraseOp(deallocOp); + } +} + +/// Explicit instantiations for `hoistStaticallyBoundAllocationsInFunc` and +/// dependent functions. +template std::optional +hoistOneStaticallyBoundAllocation( + func::FuncOp funcOp, OpBuilder &builder, Location loc, + MemRefType allocLikeType, ValueRange dynamicSizes, + std::optional alignment); +template std::optional +hoistOneStaticallyBoundAllocation( + func::FuncOp funcOp, OpBuilder &builder, Location loc, + MemRefType allocLikeType, ValueRange dynamicSizes, + std::optional alignment); +template std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, + OpBuilder &builder, + memref::AllocOp allocLikeOp); +template std::optional +hoistOneStaticallyBoundAllocation( + func::FuncOp funcOp, OpBuilder &builder, memref::AllocaOp allocLikeOp); +template void +hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp); +template void +hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp); + +} // namespace gpu +} // namespace mlir::buddy +#endif // UTILS_GPUUTILS_DEF diff --git a/tests/Conversion/convert-memcpy-to-gpu.mlir b/tests/Conversion/convert-memcpy-to-gpu.mlir new file mode 100644 index 000000000..63edfd8d0 --- /dev/null +++ b/tests/Conversion/convert-memcpy-to-gpu.mlir @@ -0,0 +1,23 @@ +// RUN: buddy-opt -convert-memcpy-to-gpu -canonicalize %s | FileCheck %s + +// CHECK: %memref = gpu.alloc () : memref<32x32xf32> +// CHECK: %memref_0 = gpu.alloc () : memref<32x32xf32> +// CHECK: gpu.dealloc %memref : memref<32x32xf32> +// CHECK: %alloc = memref.alloc() : memref<32x32xf32> +// CHECK: gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> +// CHECK: gpu.dealloc %memref_0 : memref<32x32xf32> +module attributes {gpu.container_module} { + func.func @matmul(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> { + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + gpu.launch_func @matmul_kernel::@matmul_kernel blocks in (%c1, %c1, %c1) threads in (%c64, %c2, %c1) + return %alloc : memref<32x32xf32> + } + gpu.module @matmul_kernel { + gpu.func @matmul_kernel() kernel attributes {gpu.known_block_size = array, gpu.known_grid_size = array} { + gpu.return + } + } +} diff --git a/tests/Conversion/legalize-shmem-outlining.mlir b/tests/Conversion/legalize-shmem-outlining.mlir new file mode 100644 index 000000000..f80c9b761 --- /dev/null +++ b/tests/Conversion/legalize-shmem-outlining.mlir @@ -0,0 +1,26 @@ +// RUN: buddy-opt -legalize-shmem-outlining -canonicalize %s | FileCheck %s + +// CHECK: module attributes {gpu.container_module} +// CHECK: gpu.launch_func @matmul_kernel::@matmul_kernel blocks in 
(%c1, %c1, %c1) threads in (%c64, %c2, %c1) +// CHECK: return %alloc : memref<32x32xf32> +// CHECK: gpu.module @matmul_kernel { +// CHECK-NEXT: gpu.func @matmul_kernel() kernel attributes {gpu.known_block_size = array, gpu.known_grid_size = array} { +// CHECK-NEXT: gpu.return +// CHECK-NEXT: } +// CHECK-NEXT: } +module { + func.func @matmul(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> { + %alloc = memref.alloc() : memref<16x32xf32, 3> + %alloc_2 = memref.alloc() : memref<32x16xf32, 3> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + gpu.launch blocks(%arg2, %arg3, %arg4) in (%arg8 = %c1, %arg9 = %c1, %arg10 = %c1) threads(%arg5, %arg6, %arg7) in (%arg11 = %c64, %arg12 = %c2, %arg13 = %c1) { + gpu.terminator + } + memref.dealloc %alloc_2 : memref<32x16xf32, 3> + memref.dealloc %alloc : memref<16x32xf32, 3> + return %alloc_3 : memref<32x32xf32> + } +} diff --git a/tests/Dialect/BuddyGPU/hoist-static-alloc.mlir b/tests/Dialect/BuddyGPU/hoist-static-alloc.mlir new file mode 100644 index 000000000..0578b2e00 --- /dev/null +++ b/tests/Dialect/BuddyGPU/hoist-static-alloc.mlir @@ -0,0 +1,92 @@ +// RUN: buddy-opt --split-input-file --transform-interpreter %s | FileCheck %s + +func.func @non_entry_bb_allocs() { + cf.br ^bb1 + ^bb1() : + %0 = memref.alloc() : memref<16xi32> + memref.dealloc %0 : memref<16xi32> + return +} +// CHECK-LABEL: func @non_entry_bb_allocs() +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32> +// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16xi32> +// CHECK-NEXT: cf.br ^bb1 +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: return + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.op<"func.func"> + transform.buddy.hoist_static_alloc %func : (!transform.op<"func.func">) -> () + transform.yield + } // @__transform_main +} // module + +// ----- + +#map = affine_map<(d0) -> (d0, 16)> +func.func @nested_op_alloc_subview_use_static(%arg0 : index, %o0 : index, %o1 : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c42 = arith.constant 42 : i32 + scf.for %iv = %c0 to %arg0 step %c1 { + %0 = affine.min #map(%iv) + %1 = memref.alloc() : memref<16x16xi32> + %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref<16x16xi32> to memref> + memref.dealloc %1 : memref<16x16xi32> + scf.yield + } + return +} +// CHECK-LABEL: func @nested_op_alloc_subview_use_static( +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32> +// CHECK: scf.for +// CHECK: %[[SIZE:.+]] = affine.min +// CHECK: memref.subview %[[ALLOC]] +// CHECK-NEXT: } +// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16x16xi32> + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.op<"func.func"> + transform.buddy.hoist_static_alloc %func : (!transform.op<"func.func">) -> () + transform.yield + } // @__transform_main +} // module + +// ----- + +#map = affine_map<(d0) -> (d0, 16)> +func.func @nested_op_alloc_subview_use_dynamic(%arg0 : index, %o0 : index, %o1 : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + 
%c42 = arith.constant 42 : i32 + scf.for %iv = %c0 to %arg0 step %c1 { + %0 = affine.min #map(%iv) + %1 = memref.alloc(%0, %0) : memref + %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref to memref> + memref.dealloc %1 : memref + scf.yield + } + return +} + +// CHECK-LABEL: func @nested_op_alloc_subview_use_dynamic( +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32> +// CHECK: scf.for +// CHECK: %[[SIZE:.+]] = affine.min +// CHECK: %subview = memref.subview %[[ALLOC]][0, 0] [%[[SIZE]], %[[SIZE]]] [1, 1] +// CHECK: %subview_0 = memref.subview %subview[%arg1, %arg2] [%c1, %0] [1, 1] +// CHECK-NEXT: } +// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16x16xi32> + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.op<"func.func"> + transform.buddy.hoist_static_alloc %func : (!transform.op<"func.func">) -> () + transform.yield + } // @__transform_main +} // module diff --git a/tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir b/tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir new file mode 100644 index 000000000..3aee301ce --- /dev/null +++ b/tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir @@ -0,0 +1,97 @@ +// RUN: buddy-opt --split-input-file --transform-interpreter %s | FileCheck %s + + +#matmat_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] + +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func.func @wmma(%a: memref<16x16xf32>, %b: memref<16x16xf32>, %c: memref<16x16xf32>) { + %c0 = arith.constant 0: index + %cst = arith.constant 0.0: f32 + %va = vector.transfer_read %a[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vb = vector.transfer_read %b[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vc = vector.transfer_read %c[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + + // CHECK-NOT: vector.contract + // CHECK: gpu.subgroup_mma_compute + %vres = vector.contract #matmat_trait %va, %vb, %vc + : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32> + vector.transfer_write %vres, %c[%c0, %c0]: vector<16x16xf32>, memref<16x16xf32> + return +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main( + %module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.buddy.unroll_vectors_gpu_mma_sync + } : !transform.any_op + transform.buddy.vector.vector_to_mma_conversion %func { use_wmma } : (!transform.any_op) -> () + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. 
+ transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } +} + +// ----- + +#matmat_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func.func @mma_sync(%a: memref<16x16xf32>, %b: memref<16x16xf32>, %c: memref<16x16xf32>) { + %c0 = arith.constant 0: index + %cst = arith.constant 0.0: f32 + %va = vector.transfer_read %a[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vb = vector.transfer_read %b[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vc = vector.transfer_read %c[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + + // CHECK-NOT: vector.contract + // CHECK: nvgpu.mma.sync{{.*}} tf32Enabled} + %vres = vector.contract #matmat_trait %va, %vb, %vc + : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32> + vector.transfer_write %vres, %c[%c0, %c0]: vector<16x16xf32>, memref<16x16xf32> + return +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main( + %module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.buddy.unroll_vectors_gpu_mma_sync + } : !transform.any_op + transform.buddy.vector.vector_to_mma_conversion %func { use_mma_sync } : (!transform.any_op) -> () + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } +} diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index 24bcde935..4b976dda0 100644 --- a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -37,4 +37,10 @@ target_link_libraries(buddy-opt SchedulingOnDevices LowerSche FuncBufferizeDynamicOffset + MLIRGPUPasses + BuddyGPUTransformOPs + MLIRTestTransforms + MLIRTestTransformDialect + MLIRTransforms + MLIRTransformUtils ) diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index bea9513b5..8b0919c1e 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -47,6 +47,7 @@ #include "Gemmini/GemminiOps.h" #include "Sche/ScheDialect.h" #include "Sche/ScheOps.h" +#include "GPU/TransformOps.h" namespace mlir { namespace buddy { @@ -71,6 +72,8 @@ void registerLowerLinalgToGemminiPass(); void registerDeviceSchedulePass(); void registerLowerSchePass(); void registerFuncBufferizeDynamicOffsetPass(); +void registerConvertMemcpyToGPUPass(); +void registerLegalizeShmemOutliningPass(); } // namespace buddy } // namespace mlir @@ -104,6 +107,10 @@ int main(int argc, char **argv) { mlir::buddy::registerLowerSchePass(); mlir::buddy::registerFuncBufferizeDynamicOffsetPass(); + // Register gpu passes + mlir::buddy::registerConvertMemcpyToGPUPass(); + mlir::buddy::registerLegalizeShmemOutliningPass(); + mlir::DialectRegistry registry; // Register all MLIR core dialects. 
   registerAllDialects(registry);
@@ -119,6 +126,8 @@ int main(int argc, char **argv) {
                   buddy::sche::ScheDialect>();
   // clang-format on
 
-  return mlir::failed(
-      mlir::MlirOptMain(argc, argv, "buddy-mlir optimizer driver", registry));
+  mlir::buddy::registerBuddyGPUTransformOps(registry);
+
+  return mlir::failed(mlir::MlirOptMain(
+      argc, argv, "buddy-mlir optimizer driver", registry));
 }