From a23e3d4f6eacb0b815193be04b9eef552d870d16 Mon Sep 17 00:00:00 2001 From: Chen Weiwei <100988241+FloatingcloudKnight@users.noreply.github.com> Date: Fri, 27 Dec 2024 17:17:24 +0800 Subject: [PATCH 1/8] [midend] Add poolingnhwcmax vectorization pass, examples, and tests. (#430) --- examples/BuddyNext/makefile | 13 + examples/BuddyNext/pooling-nhwc-max-vec.mlir | 132 +++++++ .../MLIRLinalg/linalg-pooling-nhwc-max.mlir | 97 +++++ examples/MLIRLinalg/makefile | 36 ++ .../ConvVectorization/CMakeLists.txt | 1 + .../PoolingNhwcMaxVectorization.cpp | 368 ++++++++++++++++++ .../pooling-nhwc-max-vectorization.mlir | 32 ++ tools/buddy-opt/buddy-opt.cpp | 3 + 8 files changed, 682 insertions(+) create mode 100644 examples/BuddyNext/pooling-nhwc-max-vec.mlir create mode 100644 examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir create mode 100644 midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp create mode 100644 tests/Conversion/pooling-nhwc-max-vectorization.mlir diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index 83b2364d03..3ee282499c 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -701,3 +701,16 @@ next-ffn-parallel-vec-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +pooling-nhwc-max-vec-run: + @${BUDDY_OPT} ./pooling-nhwc-max-vec.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/BuddyNext/pooling-nhwc-max-vec.mlir b/examples/BuddyNext/pooling-nhwc-max-vec.mlir new file mode 100644 index 0000000000..d6d8a35d14 --- /dev/null +++ b/examples/BuddyNext/pooling-nhwc-max-vec.mlir @@ -0,0 +1,132 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops \ +// RUN: -convert-vector-to-scf \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0, d1) -> (d0 + d1)> + +module { + func.func private @rtclock() -> f64 + func.func private @printMemrefF32(memref<*xf32>) + func.func @pooling_nhwc_max(%arg0: memref, %arg1: memref, %arg2: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %vl_step = arith.constant 32 : index + %c0_f32 = arith.constant 0.000000e+00 : f32 + %0 = vector.splat %c0_f32 : vector<32xf32> + %dim = memref.dim %arg1, %c0 : memref + %dim_0 = memref.dim %arg1, %c1 : memref + %dim_1 = memref.dim %arg2, %c0 : memref + %dim_2 = memref.dim %arg2, %c1 : memref + %dim_3 = memref.dim %arg2, %c2 : memref + %dim_4 = memref.dim %arg2, %c3 : memref + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. 
+ // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + %dim_4_upbound_tmp = arith.subi %dim_4, %vl_step : index + %dim_4_upbound = arith.addi %dim_4_upbound_tmp, %c1 : index + + %t_start = call @rtclock() : () -> f64 + affine.for %arg3 = #map(%c0) to #map(%dim_1) { + affine.for %arg4 = #map(%c0) to #map(%dim_2) { + affine.for %arg5 = #map(%c0) to #map(%dim_3) { + // Perform the vectorization body. + %iter_idx = scf.for %arg6 = %c0 to %dim_4_upbound + step %vl_step iter_args(%iter_init = %c0) -> (index) { // N + %4 = vector.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<32xf32> + %5 = affine.for %arg7 = #map(%c0) to #map(%dim) iter_args(%arg8 = %4) -> (vector<32xf32>) { + %6 = affine.for %arg9 = #map(%c0) to #map(%dim_0) iter_args(%arg10 = %arg8) -> (vector<32xf32>) { + %in_iter_h = affine.apply #map1 (%arg7, %arg4) + %in_iter_w = affine.apply #map1 (%arg9, %arg5) + %7 = vector.load %arg0[%arg3, %in_iter_h, %in_iter_w, %arg6] : memref, vector<32xf32> + %8 = arith.maximumf %7, %arg10 : vector<32xf32> + affine.yield %8 : vector<32xf32> + } + affine.yield %6 : vector<32xf32> + } + vector.store %5, %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<32xf32> + %dim_4_next = arith.addi %dim_4, %vl_step : index + scf.yield %dim_4_next : index + } + // Compute the tail size and Process the remaining elements + // using masked vector operations. + %tail_size = arith.subi %dim_4, %iter_idx : index + %3 = arith.cmpi sgt, %tail_size, %c0 : index + scf.if %3 { + %mask = vector.create_mask %tail_size : vector<32xi1> + %5 = vector.maskedload %arg2[%arg3, %arg4, %arg5, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %6 = affine.for %arg7 = #map(%c0) to #map(%dim) iter_args(%arg8 = %5) -> (vector<32xf32>) { + %8 = arith.addi %arg4, %arg7 : index + %7 = affine.for %arg9 = #map(%c0) to #map(%dim_0) iter_args(%arg10 = %arg8) -> (vector<32xf32>) { + %9 = arith.addi %arg9, %arg5 : index + %10 = vector.maskedload %arg0[%arg3, %8, %9, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %11 = arith.maximumf %10, %arg10 : vector<32xf32> + affine.yield %11 : vector<32xf32> + } + affine.yield %7 : vector<32xf32> + } + vector.maskedstore %arg2[%arg3, %arg4, %arg5, %iter_idx], %mask, %6 : memref, vector<32xi1>, vector<32xf32> + } + } + } + } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. + vector.print %time : f64 + + return + } + + func.func @main(){ + // Set up dims. + %c1 = arith.constant 1 : index + %cInput = arith.constant 24 : index + %cKernel = arith.constant 2 : index + %cOutput = arith.constant 12 : index + %c6 = arith.constant 6 : index + + // Set Init Value. + %cf1_32 = arith.constant 1.0 : f32 + + %a = memref.alloc(%c1, %cInput, %cInput, %c6) : memref + %b = memref.alloc(%cKernel, %cKernel) : memref + %c = memref.alloc(%c1, %cOutput, %cOutput, %c6) : memref + + linalg.fill ins(%cf1_32 : f32) outs(%a : memref) + linalg.fill ins(%cf1_32 : f32) outs(%b : memref) + linalg.fill ins(%cf1_32 : f32) outs(%c : memref) + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. 
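+    // Every element of the input, kernel, and output buffers is filled
+    // with 1.0 below, so the max over each 2x2 pooling window is 1.0.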
+ // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [1{{(, 1)*}}], + call @pooling_nhwc_max(%a, %b, %c) : (memref, memref, memref) -> () + + memref.dealloc %c : memref + memref.dealloc %b : memref + memref.dealloc %a : memref + + return + } +} diff --git a/examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir b/examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir new file mode 100644 index 0000000000..3577c09272 --- /dev/null +++ b/examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir @@ -0,0 +1,97 @@ +// RUN: buddy-opt %s \ +// RUN: -pooling-nhwc-max-vectorization \ +// RUN: -convert-linalg-to-loops \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module{ + func.func private @rtclock() -> f64 + func.func private @printMemrefF32(memref<*xf32>) + + func.func @pooling_nhwc_max(%a : memref, %b : memref, %c : memref) { + %t_start = call @rtclock() : () -> f64 + + linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} + ins(%a, %b : memref, memref) + outs(%c : memref) + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + %printed_output = memref.cast %c : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. + vector.print %time : f64 + + return + } + + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + scf.for %idx3 = %c0 to %arg3 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref + } + } + } + } + return %0 : memref + } + + func.func @alloc2_f32(%arg0: index, %arg1: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + memref.store %arg4, %0[%idx0, %idx1] : memref + } + } + return %0 : memref + } + + func.func @main(){ + // Set up dims. + %c1 = arith.constant 1 : index + %c24 = arith.constant 24 : index + %c2 = arith.constant 2 : index + %c12 = arith.constant 12 : index + %c6 = arith.constant 6 : index + + // Set Init Value. + %f0 = arith.constant 0.000000e+00 : f32 + %f1 = arith.constant 1.000000e+00 : f32 + + %v0 = call @alloc_f32(%c1, %c24, %c24, %c6, %f1) : (index, index, index, index, f32) -> memref + %v1 = call @alloc2_f32(%c2, %c2, %f0) : (index, index, f32) -> memref + %v2 = call @alloc_f32(%c1, %c12, %c12, %c6, %f0) : (index, index, index, index, f32) -> memref + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. 
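+    // With a 24x24 input, a 2x2 kernel, and stride 2, each spatial output
+    // dim is (24 - 2) / 2 + 1 = 12; an all-ones input therefore yields an
+    // all-ones output.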
+ // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [1{{(, 1)*}}], + call @pooling_nhwc_max(%v0, %v1, %v2) : (memref, memref, memref) -> () + + memref.dealloc %v0 : memref + memref.dealloc %v1 : memref + memref.dealloc %v2 : memref + + return + } +} diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index d9a37926f4..865a0b162c 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -497,3 +497,39 @@ linalg-matmul-vectorization-lower: @${BUDDY_OPT} linalg-matmul.mlir \ -matmul-vectorization \ -o log.mlir + +linalg-pooling-nhwc-max-run: + @${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main \ + -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-pooling-nhwc-max-vectorization-lower: + @${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \ + -pooling-nhwc-max-vectorization \ + -o log.mlir + +linalg-pooling-nhwc-max-vectorization-run: + @${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \ + -pooling-nhwc-max-vectorization \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main \ + -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt index fce89520b6..d4cc3ec987 100644 --- a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(CBConvVectorization CBConvVectorization.cpp GEMMPointwiseConv2DNhwcHwcf.cpp PoolingVectorization.cpp + PoolingNhwcMaxVectorization.cpp LINK_LIBS PUBLIC BuddyUtils diff --git a/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp b/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp new file mode 100644 index 0000000000..280d11b226 --- /dev/null +++ b/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp @@ -0,0 +1,368 @@ +//===--------PoolingNhwcMaxVectorization.cpp-------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Pooling Nhwc Max Vectorization. 
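+// The rewrite pattern strip-mines the innermost (channel) dimension into
+// vector-width strips, turns the pooling-window reduction into vector max
+// operations, and handles the remaining tail elements with masked loads
+// and stores.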
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Utils/Utils.h" + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { +public: + explicit PoolingNhwcMaxVectorizationPattern(MLIRContext *context, + int64_t stripParam) + : ConversionPattern(linalg::PoolingNhwcMaxOp::getOperationName(), 1, + context) { + strip = stripParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({strip}, i1); + + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + // Get strides. + SmallVector strides = {1, 1}; + if (op->hasAttr("strides")) { + if (auto attr = + op->getAttrOfType("strides")) { + strides.clear(); + for (auto value : attr.getValues()) { + strides.push_back(value); + } + } + } + bool stride1 = strides[0] != 1; + bool stride2 = strides[1] != 1; + Value strHeight = rewriter.create(loc, strides[0]); + Value strWidth = rewriter.create(loc, strides[1]); + + // // Get dilations. + SmallVector dilations = {1, 1}; + if (op->hasAttr("dilations")) { + if (auto attr = + op->getAttrOfType("dilations")) { + dilations.clear(); + for (auto value : attr.getValues()) { + dilations.push_back(value); + } + } + } + bool dilated1 = dilations[0] != 1; + bool dilated2 = dilations[1] != 1; + Value dilHeight = + rewriter.create(loc, dilations[0]); + Value dilWidth = rewriter.create(loc, dilations[1]); + + // Get ElementType of input. + Type elementTy = input.getType().cast().getElementType(); + VectorType vectorTy = mlir::VectorType::get({strip}, elementTy); + + // Get Constants. + const Value c0 = rewriter.create(loc, 0); + const Value c1 = rewriter.create(loc, 1); + const Value c2 = rewriter.create(loc, 2); + const Value c3 = rewriter.create(loc, 3); + const Value vlStep = rewriter.create(loc, strip); + const Value zero = + buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy); + + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); + + // Get Dimensions of Kernel. + Value kernelHeight = rewriter.create(loc, kernel, c0); + Value kernelWidth = rewriter.create(loc, kernel, c1); + + // Get Dimensions of Outputs. 
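+    // (NHWC layout: dim 0 = batch, dim 1 = height, dim 2 = width,
+    // dim 3 = channels.)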
+ Value batch = rewriter.create(loc, output, c0); + Value height = rewriter.create(loc, output, c1); + Value width = rewriter.create(loc, output, c2); + Value channels = rewriter.create(loc, output, c3); + + // Calculate the upper bound for vectorized processing + // - Subtract `vlStep` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + Value upperBoundTmp = rewriter.create(loc, channels, vlStep); + Value upperBound = rewriter.create(loc, upperBoundTmp, c1); + + SmallVector lowerBounds(3, c0); + SmallVector uperBounds{batch, height, width}; + SmallVector steps(3, /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create strides variables. + Value tmpIvs1 = ivs[1]; + if (stride1) { + tmpIvs1 = builder.create(loc, ivs[1], strHeight); + } + Value tmpIvs2 = ivs[2]; + if (stride2) { + tmpIvs2 = builder.create(loc, ivs[2], strWidth); + } + // Create strip mining loop. + auto iterIdx = builder.create( + loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + Value outputVector = nestedBuilder.create( + loc, vectorTy, output, + ValueRange{ivs[0], ivs[1], ivs[2], iv}); + + auto tmp0 = nestedBuilder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelHeight}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{outputVector}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + // Create dilated[0] variables. + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelWidth}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + // Create dilated[1] variables. + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); + } + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); + Value inputVector = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + iv}); + // Max + Value resultVector; + if (auto ty = + llvm::dyn_cast(elementTy)) { + resultVector = builder.create( + loc, inputVector, itrArgs1[0]); + } else { + resultVector = builder.create( + loc, inputVector, itrArgs1[0]); + } + builder.create(loc, + resultVector); + }); + nestedBuilder.create( + loc, tmp1.getResult(0)); + }); + builder.create( + loc, tmp0.getResult(0), output, + ValueRange{ivs[0], ivs[1], ivs[2], iv}); + Value idx = + builder.create(loc, itrArgs[0], vlStep); + builder.create(loc, idx); + }); + // Compute the tail size and Process the remaining elements + // using masked vector operations. + Value idx = iterIdx.getResult(0); + Value tailSize = builder.create(loc, channels, idx); + Value tailCond = rewriter.create( + loc, arith::CmpIPredicate::sgt, tailSize, c0); + // If the current column does not reach the tail. + builder.create< + scf::IfOp>(loc, tailCond, [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value tailMask = + builder.create(loc, vectorMaskTy, tailSize); + // Masked load output. 
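+          // Lanes beyond the tail take values from the zero pass-through
+          // vector rather than reading past the end of the buffer.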
+ Value maskedOutputVec = builder.create( + loc, vectorTy, output, ValueRange{ivs[0], ivs[1], ivs[2], idx}, + tailMask, passThroughVec); + auto tmp0 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelHeight}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{maskedOutputVec}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + // Create dilated[0] variables. + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelWidth}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + // Calculate the index of the input and + // output. + // Create dilated[1] variables. + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = + builder.create(loc, iv1, dilWidth); + } + Value inputWidth = + builder.create(loc, iv1, tmpIvs2); + // Masked load input and output. + Value maskedInputVec = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, idx}, + tailMask, passThroughVec); + // Max + Value resultVec; + if (auto ty = llvm::dyn_cast(elementTy)) { + resultVec = builder.create( + loc, maskedInputVec, itrArgs1[0]); + } else { + resultVec = builder.create( + loc, maskedInputVec, itrArgs1[0]); + } + builder.create(loc, resultVec); + }); + builder.create(loc, tmp1.getResult(0)); + }); + // Masked store the result to output. + builder.create( + loc, output, ValueRange{ivs[0], ivs[1], ivs[2], idx}, tailMask, + tmp0.getResult(0)); + builder.create(loc); + }); + }); + // Remove the origin convolution operation. + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t strip; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// PoolingNhwcMaxVectorizationPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling max operations to mixture of +/// Arith + Vector operations. +namespace { +class PoolingNhwcMaxVectorizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PoolingNhwcMaxVectorizationPass) + StringRef getArgument() const final { + return "pooling-nhwc-max-vectorization"; + } + StringRef getDescription() const final { + return "Pooling_Nhwc_Max vectorization."; + } + PoolingNhwcMaxVectorizationPass() = default; + PoolingNhwcMaxVectorizationPass(const PoolingNhwcMaxVectorizationPass &) {} + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + Option strip{*this, "vector-size", + llvm::cl::desc("Specify vector type size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. 
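+// Example invocation (illustrative; the vector width defaults to 16 and is
+// configurable through the "vector-size" pass option):
+//
+//   buddy-opt input.mlir -pooling-nhwc-max-vectorization="vector-size=32"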
+ +void PoolingNhwcMaxVectorizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, strip); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerPoolingNhwcMaxVectorizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tests/Conversion/pooling-nhwc-max-vectorization.mlir b/tests/Conversion/pooling-nhwc-max-vectorization.mlir new file mode 100644 index 0000000000..0fa188eeb6 --- /dev/null +++ b/tests/Conversion/pooling-nhwc-max-vectorization.mlir @@ -0,0 +1,32 @@ +// RUN: buddy-opt -pooling-nhwc-max-vectorization %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim_5) { +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_6) { +// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_7) { +// CHECK-NEXT: %3 = arith.muli %arg4, %c2 : index +// CHECK-NEXT: %4 = arith.muli %arg5, %c2_0 : index +// CHECK-NEXT: %5 = scf.for %arg6 = %c0 to %2 step %c16 iter_args(%arg7 = %c0) -> (index) { +// CHECK-NEXT: %8 = vector.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<16xf32> +// CHECK-NEXT: %9 = affine.for %arg8 = #map(%c0) to #map(%dim) iter_args(%arg9 = %8) -> (vector<16xf32>) { +// CHECK-NEXT: %11 = arith.addi %3, %arg8 : index +// CHECK-NEXT: %12 = affine.for %arg10 = #map(%c0) to #map(%dim_4) iter_args(%arg11 = %arg9) -> (vector<16xf32>) { +// CHECK-NEXT: %13 = arith.addi %4, %arg10 : index +// CHECK-NEXT: %14 = vector.load %arg0[%arg3, %11, %13, %arg6] : memref, vector<16xf32> +// CHECK-NEXT: %15 = arith.maximumf %14, %arg11 : vector<16xf32> +// CHECK-NEXT: affine.yield %15 : vector<16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %12 : vector<16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: vector.store %9, %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<16xf32> +// CHECK-NEXT: %10 = arith.addi %arg7, %c16 : index +// CHECK-NEXT: scf.yield %10 : index +// CHECK-NEXT: } + +func.func @pooling_nhwc_max(%a : memref, %b : memref, %c : memref) { + linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} + ins(%a, %b : memref, memref) + outs(%c : memref) + return +} diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 08e172f8bc..2c7de877c6 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -55,6 +55,7 @@ void registerConvVectorizationPass(); void registerPointwiseConvToGemmPass(); void registerPointwiseConvToGemmForNhwcFhwcPass(); void registerPoolingVectorizationPass(); +void registerPoolingNhwcMaxVectorizationPass(); void registerLowerBudPass(); void registerLowerDIPPass(); void registerBatchMatMulOptimizePass(); @@ -94,6 +95,8 @@ int main(int argc, char **argv) { mlir::buddy::registerConvVectorizationPass(); // Register Vectorization of Pooling. mlir::buddy::registerPoolingVectorizationPass(); + // Register Vectorization of Pooling Nhwc Max. 
+ mlir::buddy::registerPoolingNhwcMaxVectorizationPass(); mlir::buddy::registerLowerBudPass(); mlir::buddy::registerLowerDIPPass(); mlir::buddy::registerLowerDAPPass(); From 8541a5ea49febb4a3d26a6a1689de53c75cb3608 Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Thu, 2 Jan 2025 15:17:08 +0000 Subject: [PATCH 2/8] [examples] Add transpose vectorization example. --- examples/BuddyNext/.gitignore | 4 +- examples/BuddyNext/makefile | 68 ++++++++++++++++++ .../BuddyNext/next-transpose-vec-manual.mlir | 68 ++++++++++++++++++ examples/BuddyNext/next-transpose.mlir | 70 +++++++++++++++++++ 4 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 examples/BuddyNext/next-transpose-vec-manual.mlir create mode 100644 examples/BuddyNext/next-transpose.mlir diff --git a/examples/BuddyNext/.gitignore b/examples/BuddyNext/.gitignore index 0194ea7a68..80a243fa81 100644 --- a/examples/BuddyNext/.gitignore +++ b/examples/BuddyNext/.gitignore @@ -1,3 +1 @@ -log.mlir -log.ll -log.s +log.* diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index 3ee282499c..ce30c81b2b 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -9,6 +9,7 @@ OPT_FLAG := -O0 ifeq ($(shell uname),Linux) MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.so +LIB_OMP := ../../llvm/build/lib/libomp.so MTRIPLE := x86_64-unknown-linux-gnu else ifeq ($(shell uname),Darwin) MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.dylib @@ -313,6 +314,73 @@ next-sgemm-run: ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} +next-transpose-lower: + @${MLIR_OPT} ./next-transpose.mlir \ + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ + ${MLIR_OPT} \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize \ + -func-bufferize \ + -arith-bufferize \ + -o log.mlir + +next-transpose-run: + @${MLIR_OPT} ./next-transpose.mlir \ + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ + ${MLIR_OPT} \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize \ + -func-bufferize \ + -arith-bufferize \ + -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ + -lower-affine \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ + -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} + +next-transpose-vec-manual-run: + @${MLIR_OPT} ./next-transpose-vec-manual.mlir \ + -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ + -lower-affine \ + -convert-scf-to-openmp \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ + -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + 
-convert-math-to-libm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} + next-embedding-lower: @${MLIR_OPT} ./next-embedding.mlir \ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ diff --git a/examples/BuddyNext/next-transpose-vec-manual.mlir b/examples/BuddyNext/next-transpose-vec-manual.mlir new file mode 100644 index 0000000000..ccf5c7b7e4 --- /dev/null +++ b/examples/BuddyNext/next-transpose-vec-manual.mlir @@ -0,0 +1,68 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -affine-loop-fusion \ +// RUN: -lower-affine \ +// RUN: -func-bufferize \ +// RUN: -arith-bufferize \ +// RUN: -tensor-bufferize \ +// RUN: -buffer-deallocation \ +// RUN: -finalizing-bufferize \ +// RUN: -convert-vector-to-scf \ +// RUN: -expand-strided-metadata \ +// RUN: -convert-vector-to-llvm \ +// RUN: -memref-expand \ +// RUN: -arith-expand \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-openmp-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + memref.global "private" constant @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32> = dense<3.000000e+00> {alignment = 64 : i64} + func.func private @rtclock() -> f64 + func.func private @printMemrefF32(memref<*xf32>) + func.func @kernel(%arg0: memref<1x32x40x128xf32>) { + %0 = call @rtclock() : () -> f64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x40x32x128xf32> + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 40 { + affine.for %arg3 = 0 to 32 { + affine.for %arg4 = 0 to 128 step 64 { + %3 = vector.load %arg0[%arg1, %arg3, %arg2, %arg4] : memref<1x32x40x128xf32>, vector<64xf32> + vector.store %3, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x40x32x128xf32>, vector<64xf32> + } + } + } + } + %1 = call @rtclock() : () -> f64 + %2 = arith.subf %1, %0 : f64 + %cast = memref.cast %alloc : memref<1x40x32x128xf32> to memref<*xf32> + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. 
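+    // The kernel copies 64-element vectors from the all-3.0 input of shape
+    // 1x32x40x128 into the transposed 1x40x32x128 buffer, so every printed
+    // element should be 3.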
+ // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 40, 32, 128] strides = [163840, 4096, 128, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [ + // CHECK-SAME: [3{{(, 3)*}}], + + call @printMemrefF32(%cast) : (memref<*xf32>) -> () + vector.print %2 : f64 + return + } + func.func @main() { + %0 = memref.get_global @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32> + call @kernel(%0) : (memref<1x32x40x128xf32>) -> () + return + } +} + diff --git a/examples/BuddyNext/next-transpose.mlir b/examples/BuddyNext/next-transpose.mlir new file mode 100644 index 0000000000..1b2bd93d62 --- /dev/null +++ b/examples/BuddyNext/next-transpose.mlir @@ -0,0 +1,70 @@ +// RUN: buddy-opt %s \ +// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \ +// RUN: | buddy-opt \ +// RUN: -arith-expand \ +// RUN: -eliminate-empty-tensors \ +// RUN: -empty-tensor-to-alloc-tensor \ +// RUN: -one-shot-bufferize \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -affine-loop-fusion \ +// RUN: -lower-affine \ +// RUN: -func-bufferize \ +// RUN: -arith-bufferize \ +// RUN: -tensor-bufferize \ +// RUN: -buffer-deallocation \ +// RUN: -finalizing-bufferize \ +// RUN: -convert-vector-to-scf \ +// RUN: -expand-strided-metadata \ +// RUN: -convert-vector-to-llvm \ +// RUN: -memref-expand \ +// RUN: -arith-expand \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-openmp-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func private @rtclock() -> f64 +func.func private @printMemrefF32(%ptr : tensor<*xf32>) + +func.func @kernel(%t0 : tensor<1x32x40x128xf32>) { + %t_start = call @rtclock() : () -> f64 + + %idx = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + %t1 = tosa.transpose %t0, %idx : (tensor<1x32x40x128xf32>, tensor<4xi32>) -> tensor<1x40x32x128xf32> + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %tensor_unranked = tensor.cast %t1 : tensor<1x40x32x128xf32> to tensor<*xf32> + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 40, 32, 128] strides = [163840, 4096, 128, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [ + // CHECK-SAME: [3{{(, 3)*}}], + + // Print results. + call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> () + // Print timings. + vector.print %time : f64 + + return +} + +func.func @main() { + %c0 = arith.constant dense<3.0> : tensor<1x32x40x128xf32> + call @kernel(%c0) : (tensor<1x32x40x128xf32>) -> () + + return +} From 3e7ba5039d01af9e7abce5889126c7bc5abc2153 Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Thu, 2 Jan 2025 16:00:36 +0000 Subject: [PATCH 3/8] [examples] Reorder llama pass pipeline. 
--- examples/BuddyLlama/CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/BuddyLlama/CMakeLists.txt b/examples/BuddyLlama/CMakeLists.txt index 6953b7de7d..b7d720685c 100644 --- a/examples/BuddyLlama/CMakeLists.txt +++ b/examples/BuddyLlama/CMakeLists.txt @@ -58,6 +58,11 @@ add_custom_command( -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize + -func-bufferize-dynamic-offset + -tensor-bufferize + -arith-bufferize + -buffer-deallocation + -finalizing-bufferize -matmul-parallel-vectorization-optimize -batchmatmul-optimize -convert-linalg-to-affine-loops @@ -65,11 +70,6 @@ add_custom_command( -affine-parallelize -lower-affine -convert-scf-to-openmp - -func-bufferize-dynamic-offset - -tensor-bufferize - -arith-bufferize - -buffer-deallocation - -finalizing-bufferize -convert-vector-to-scf -expand-strided-metadata -cse From 2776c10484b79323e4eeb6390c7533d32edbda8d Mon Sep 17 00:00:00 2001 From: Wu Xintong <13683168028@163.com> Date: Fri, 3 Jan 2025 16:27:21 +0800 Subject: [PATCH 4/8] [frontend] Update graph for op fusion (#445) --- Co-authored-by: zhxzh-2001 <70198007+zhxzh-2001@users.noreply.github.com> --- examples/BuddyLeNet/buddy-lenet-import.py | 2 +- examples/BuddyLlama/import-llama2.py | 2 +- frontend/Python/frontend.py | 45 +++++---- frontend/Python/graph/graph.py | 104 +++++++++++++++++--- frontend/Python/graph/operation.py | 6 ++ frontend/Python/graph/transform/__init__.py | 2 +- frontend/Python/graph/transform/fuse_ops.py | 89 ++++++++++++++++- frontend/Python/ops/linalg.py | 21 ++++ tests/Python/test_permute_matmul_fusion.py | 40 ++++++++ 9 files changed, 270 insertions(+), 41 deletions(-) create mode 100644 tests/Python/test_permute_matmul_fusion.py diff --git a/examples/BuddyLeNet/buddy-lenet-import.py b/examples/BuddyLeNet/buddy-lenet-import.py index c787061a55..e4f85f905f 100644 --- a/examples/BuddyLeNet/buddy-lenet-import.py +++ b/examples/BuddyLeNet/buddy-lenet-import.py @@ -26,7 +26,7 @@ from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.graph import GraphDriver -from buddy.compiler.graph.transform import simply_fuse +from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion from buddy.compiler.ops import tosa from model import LeNet diff --git a/examples/BuddyLlama/import-llama2.py b/examples/BuddyLlama/import-llama2.py index d893ee87f6..af89329e62 100644 --- a/examples/BuddyLlama/import-llama2.py +++ b/examples/BuddyLlama/import-llama2.py @@ -28,7 +28,7 @@ from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.ops import tosa from buddy.compiler.graph import GraphDriver -from buddy.compiler.graph.transform import simply_fuse +from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion # Retrieve the LLaMA model path from environment variables. 
model_path = os.environ.get("LLAMA_MODEL_PATH") diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index f5a17a1c31..c11843eab7 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -165,9 +165,9 @@ def __init__( "cos.default": CosOp, "sin.default": SinOp, "argmax.default": ArgMaxOp, - "split.Tensor":SplitOp, - "max.default":MaxOp, - "gt.Scalar":GtOp, + "split.Tensor": SplitOp, + "max.default": MaxOp, + "gt.Scalar": GtOp, "_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp, "ge.Scalar": GeOp, "gt.Tensor": GreaterThanOp, @@ -237,7 +237,9 @@ def _create_node( buddy_node.add_argument(str(input_arg)) buddy_node.add_parent(str(input_arg)) elif isinstance(input_arg, torch.dtype): - buddy_node.add_argument(self._torch_dtype_translate(str(input_arg))) + buddy_node.add_argument( + self._torch_dtype_translate(str(input_arg)) + ) else: buddy_node.add_argument(input_arg) for user in node_users: @@ -294,7 +296,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): nonlocal params_flat func_inputs = [] for i in inputs_pos: - # for inp in _inputs[len(params_flat) :]: + # for inp in _inputs[len(params_flat) :]: inp = _inputs[i] inp_shape = inp.shape inp_dtype = self._torch_dtype_translate(str(inp.dtype)) @@ -308,7 +310,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): fake_params, self._ops_registry, self._func_name, - self._verbose + self._verbose, ) param_nodes = [] buffers_nodes = [] @@ -344,10 +346,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): elif gm_node.op == "output": buddy_node = self._create_node( - gm_node.op, - gm_node.name, - gm_node.args, - node_users + gm_node.op, gm_node.name, gm_node.args, node_users ) elif gm_node.target is operator.getitem: @@ -367,7 +366,11 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): tensor_meta = gm_node.meta.get("tensor_meta") val = gm_node.meta.get("val") # num_returns = len(gm_node.target._schema.returns) - num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) + num_returns = ( + len(val) + if isinstance(val, list) + else len(gm_node.target._schema.returns) + ) if num_returns == 1: node_dtype = self._torch_dtype_translate( str(tensor_meta.dtype) @@ -477,7 +480,7 @@ def get_lib_extension(): def cast_c_ptr(outdata_ptr, memref_ptr): """ - Casts a C pointer (`outdata_ptr`) to the type of another C pointer + Casts a C pointer (`outdata_ptr`) to the type of another C pointer (`memref_ptr`). Args: @@ -488,14 +491,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr): Returns: ctypes.POINTER - A new C pointer with the type of `memref_ptr`, representing the + A new C pointer with the type of `memref_ptr`, representing the same memory location as `outdata_ptr`. Example: outdata = ctypes.pointer(ctypes.c_int()) memref = ctypes.pointer(ctypes.c_float()) casted_ptr = cast_c_ptr(outdata, memref) - # Now `casted_ptr` points to the same memory location as `outdata`, + # Now `casted_ptr` points to the same memory location as `outdata`, but with the type of `memref`. 
""" outdata_addr = ctypes.addressof(outdata_ptr.contents) @@ -504,15 +507,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr): def move_c_ptr(outdata_ptr, memref_ptr): """ - Moves a C pointer (`outdata_ptr`) to the next element in memory, - based on the size of the referenced type in another C pointer + Moves a C pointer (`outdata_ptr`) to the next element in memory, + based on the size of the referenced type in another C pointer (`memref_ptr`). Args: outdata_ptr: ctypes.POINTER The C pointer whose position needs to be moved. memref_ptr: ctypes.POINTER - The reference C pointer whose type determines the size of each + The reference C pointer whose type determines the size of each element for the move. Returns: @@ -535,7 +538,7 @@ def exec_buddy_graph(*args): Returns: List[torch.Tensor] - The result of executing the graph, represented as a list of + The result of executing the graph, represented as a list of output tensors. """ # A list of ctypes pointers representing memory references for input @@ -548,13 +551,13 @@ def exec_buddy_graph(*args): ) for tensor in args ] - # A list of ctypes pointers representing memory references for + # A list of ctypes pointers representing memory references for # output tensors. output_memref = [ ctypes.pointer(ctypes.pointer(graph._output_descriptor())) ] args_memref = output_memref + input_memref - # Invoke the graph's function using the provided execution engine + # Invoke the graph's function using the provided execution engine # and memory references ee.invoke(graph._func_name, *args_memref) @@ -571,7 +574,7 @@ def exec_buddy_graph(*args): # Move to the next element in memory based on the size of the # current output type outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0]) - # Convert each NumPy array to a PyTorch tensor and return the list + # Convert each NumPy array to a PyTorch tensor and return the list # of tensors return [torch.from_numpy(tensor) for tensor in output_tensor] diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index 88c6a85df6..751ddb0066 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -105,7 +105,7 @@ def __init__( fake_params: List[TensorMeta], ops_registry: dict, func_name: str, - verbose=False + verbose=False, ) -> None: """ Initializes the Graph. @@ -164,6 +164,78 @@ def add_node(self, node: Op): self._body.append(node) self.node_table[node.name] = node + def check_delete_node(self, node: Op) -> bool: + """ + Determines if a node exists in the graph and has no child nodes. + + Args: + node (Op): The operation node to check for deletion eligibility. + + Returns: + bool: True if the node exists in the graph and has no children. + """ + if not (node.name in self.node_table): + raise KeyError("node{0} not in graph".format(node.name)) + + if len(node._children) == 0: + return True + return False + + def delete_node(self, node: Op, parents: List[Op]): + """ + Removes a node from the graph and updates its parent nodes accordingly. + + Args: + node (Op): The operation node to be deleted from the graph. + parents (List[Op]): A list of parent operation nodes that reference the node to be deleted. + + Returns: + None + """ + for i in parents: + i._children.remove(node.name) + node.args.clear() + node.kwargs.clear() + node._children.clear() + self._body.remove(node) + self.node_table.pop(node.name) + + def displace_node(self, node: Op, newnode: Op): + """ + Replaces an existing node with a new node in the graph. + + Args: + node (Op): The operation node to be replaced. 
+ newnode (Op): The new operation node that will replace the existing node. + + Returns: + None + """ + newnode._arguments = node.args + newnode._keyword_arguments = node.kwargs + newnode._tensor_meta = node.tensor_meta + newnode._op_type = node._op_type + + for i in node._children: + newnode.add_children(i) + users = [self.node_table[i] for i in node._children] + for user in users: + if node.name in user._parents: + user._parents[user._parents.index(node.name)] = newnode.name + user.args[user.args.index(node.name)] = newnode.name + node._children.clear() + # deal with parents+args + for i in node._parents: + newnode.add_parent(i) + parents = [self.node_table[i] for i in node._parents] + for parent in parents: + parent._children[parent._children.index(node.name)] = newnode.name + node._parents.clear() + # update node table + self._body[self._body.index(node)] = newnode + self.node_table.pop(node.name) + self.node_table[newnode.name] = newnode + def init_op_group(self): """ Initializes operation groups within the graph. @@ -239,7 +311,7 @@ def lower_to_top_level_ir(self): self._inputs, self._func_name, self._ops_registry, - verbose=self._verbose + verbose=self._verbose, ) self._imported_module = fx_importer.import_graph() outputs = fx_importer.get_output_nodes() @@ -352,7 +424,7 @@ def __init__( func_name: str, ops_registry: dict, do_param_pack: bool = False, - verbose=False + verbose=False, ): """ Initializes the buddy Graph importer. @@ -475,27 +547,27 @@ def generated_func(*args): elif isinstance(node, PlaceholderOp): self._import_placeholder(node, args_list) elif isinstance(node, GetItemOp): - self._symbol_table[ - (str(node.name), 0) - ] = self._symbol_table[ - (str(node.args[0]), node.args[1]) - ] + self._symbol_table[(str(node.name), 0)] = ( + self._symbol_table[ + (str(node.args[0]), node.args[1]) + ] + ) else: self._import_op(node) new_ops = [op for op in func_op.body.blocks[0].operations] if self._verbose: - print('='*20 + "Graph Node" + "="*20) + print("=" * 20 + "Graph Node" + "=" * 20) print("Node: " + node.name) print("Type: " + str(node._op_type)) print("Arguments: " + str(node.args)) print("Parents: " + str(node._parents)) print("Children: " + str(node._children)) - print('-'*20 + "MLIR OPS" + '-'*20) + print("-" * 20 + "MLIR OPS" + "-" * 20) for op in new_ops: if op not in old_ops: print(op) print("") - + return self._symbol_table.get(("output", 0)) return self._module @@ -544,11 +616,11 @@ def generated_func(*args): elif isinstance(node, PlaceholderOp): self._import_placeholder(node, args_list) elif isinstance(node, GetItemOp): - self._symbol_table[ - (str(node.name), 0) - ] = self._symbol_table[ - (str(node.args[0]), node.args[1]) - ] + self._symbol_table[(str(node.name), 0)] = ( + self._symbol_table[ + (str(node.args[0]), node.args[1]) + ] + ) else: self._import_op(node) diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index c1a7b09746..218752abc0 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -154,6 +154,12 @@ def __init__(self) -> None: self._op_type = OpType.ReduceType +class TransposeMatmulFusedOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ReduceType + + class GetItemOp(Op): def __init__(self) -> None: super().__init__() diff --git a/frontend/Python/graph/transform/__init__.py b/frontend/Python/graph/transform/__init__.py index d91e0d06b2..a1e294f8cd 100644 --- a/frontend/Python/graph/transform/__init__.py +++ 
b/frontend/Python/graph/transform/__init__.py @@ -18,5 +18,5 @@ # # ===--------------------------------------------------------------------------- -from .fuse_ops import simply_fuse +from .fuse_ops import simply_fuse, apply_classic_fusion from .useless_op_eliminate import maxpool2d_simplify diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index ac7d34c99c..992168aecc 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -21,12 +21,99 @@ from .. import Graph from ..operation import * from .. import DeviceType +from torch.fx.immutable_collections import immutable_list + +classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp} # TODO: classify op type for op fusion # OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType] # OP_TYPE_UNFUSABLE = [OpType.Unfusable, OpType.ConcatType] # OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = [] -# ANCHOR_OP_TYPE = [] +# ANCHOR_OP_TYPE = [] + + +def classic_fuse_check(graph: Graph): + """ + Function to identifies and fuses PermuteOp operations with preceding + MatmulOp operations in a computation graph to optimize performance. + + Args: + graph (Graph): The computation graph to analyze and optimize. + + Returns: + None + """ + for op in graph.body: + pattern = None + if isinstance(op, MatmulOp): + parentop = [graph.node_table[str(i)] for i in op._parents] + for target in parentop: + if isinstance(target, PermuteOp) and target.args[ + 1 + ] == immutable_list([1, 0]): + pattern = target, parentop, "transpose_matmul_fusion" + if pattern: + transpose_matmul_fusion( + graph, op, pattern[0], pattern[1], pattern[2] + ) + + +def transpose_matmul_fusion( + graph: Graph, node, target: Op, parents: List[Op], pattern: str +): + """ + Function to fuse some typical operations into one operation. + Such as transpose + matmul + Args: + - graph (Graph): The input graph to be simplified. + - node (Op): The operation to be fused. + - target (Op): The target operation to be fused. + - parents (List[Op]): The parents of the node to be fused. + - pattern (str): The pattern of the fusion. + Returns: + - None: Modifies the input graph in place. + """ + fused_op = classicfuse_register.get(pattern)() + # matmulop -> fusedmatmulopnode + fused_op.name = "fused" + node.name + graph.displace_node(node, fused_op) + fused_op.args.pop(fused_op.args.index(target.name)) + fused_op._parents.pop(fused_op._parents.index(target.name)) + fused_op.args.extend(target.args) + + fused_op._parents.extend(target._parents) + targets_parent = [graph.node_table[i] for i in target._parents] + for i in targets_parent: + i.add_children(fused_op.name) + target._children.pop(target._children.index(fused_op.name)) + + if graph.check_delete_node(target): + graph.delete_node(target, targets_parent) + + +def apply_classic_fusion(graph: Graph): + """ + Function to fuse some typical operations into one operation and fuse + all operations into one graph. + + Args: + - graph (Graph): The input graph to be simplified. + + Returns: + - None: Modifies the input graph in place. 
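+
+    Example (mirroring tests/Python/test_permute_matmul_fusion.py):
+        pattern_list = [apply_classic_fusion]
+        graph.fuse_ops(pattern_list)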
+ """ + new_op_group = [] + device = DeviceType.UNKNOW + # Run the first round of op fusion + classic_fuse_check(graph) + for op in graph.body: + if isinstance(op, PlaceholderOp): + continue + new_op_group.append(op) + graph.op_groups = {} + graph.op_groups["subgraph0"] = new_op_group + graph.group_map_device = {"subgraph0": device} + def simply_fuse(graph: Graph): """ diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index ec6c827e6c..6bd3a2f318 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1171,6 +1171,26 @@ def matmul_op( return op +def matmul_transpose_b_op( + node: TransposeMatmulFusedOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + input1 = symbol_table.get((str(node.args[0]), 0)) + input2 = symbol_table.get((str(node.args[1]), 0)) + + if input1 is None or input2 is None: + return + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + element = mlir_element_attr_get(dtype, 0.0) + attr = ir.DenseElementsAttr.get_splat(tensor_type, element) + result_buffer = arith.ConstantOp(tensor_type, attr).result + op = linalg.matmul_transpose_b(input1, input2, outs=[result_buffer]) + return op + + def transpose_op( node: TransposeOp, symbol_table: Dict[Tuple[str, int], ir.Operation], @@ -2344,6 +2364,7 @@ def unsafe_index_op( ops_registry = { "MatmulOp": matmul_op, + "TransposeMatmulFusedOp": matmul_transpose_b_op, "ArangeOp": arange_op, "UnsqueezeOp": unsqueeze_op, "ViewOp": view_op, diff --git a/tests/Python/test_permute_matmul_fusion.py b/tests/Python/test_permute_matmul_fusion.py new file mode 100644 index 0000000000..70f120e5a4 --- /dev/null +++ b/tests/Python/test_permute_matmul_fusion.py @@ -0,0 +1,40 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp +from torch._functorch.aot_autograd import aot_autograd_decompositions + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import linalg +from buddy.compiler.graph.transform import simply_fuse, apply_classic_fusion + +def foo(m1, m2,map): + tmp = torch.ops.aten.permute(m2,map) + return torch.matmul(m1,tmp) + +m1 = torch.ones([3, 4], dtype=torch.float32) +m2 = torch.ones([3, 4], dtype=torch.float32) +map = (1,0) +# Initialize the dynamo compiler. 
+dynamo_compiler = DynamoCompiler(
+    primary_registry=linalg.ops_registry,
+    aot_autograd_decomposition=aot_autograd_decompositions,
+)
+
+graphs = dynamo_compiler.importer(foo, m1, m2, map)
+assert len(graphs) == 1
+graph = graphs[0]
+pattern_list = [apply_classic_fusion]
+graphs[0].fuse_ops(pattern_list)
+
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK: module {
+# CHECK-LABEL: func.func @forward
+# CHECK: %{{.*}} = arith.constant
+# CHECK: %{{.*}} = linalg.matmul_transpose_b
+# CHECK: return %{{.*}}
+# CHECK: }
+# CHECK: }

From 75021013b408099a2ad643ad40098e061c406815 Mon Sep 17 00:00:00 2001
From: Junyi Mei
Date: Fri, 3 Jan 2025 17:22:00 +0800
Subject: [PATCH 5/8] [Python] Fix python bindings for out-of-tree dialects
 and passes (#440)

* [Python] support gemmini
* [Python] support python APIs for buddy-mlir
* [Python] fix a grammar bug
* [Python] Format & use Lit and FileCheck to test the code
* [Pybind] register everything once
* [Pybind] update submodule llvm
* [Python] Fix bug in python binding test

Current implementation of python binding will place the buddy_mlir
package under 'BUDDY_BUILD_DIR/midend/python_packages', which is
inconsistent with the frontend and leads to test failures. This commit
changes the output directory of the buddy_mlir module. The 'buddy_mlir'
package will be placed under 'BUDDY_BUILD_DIR/python_packages' together
with the 'buddy' module.

Signed-off-by: Junyi Mei

* [NFC] Format python bindings

Signed-off-by: Junyi Mei

* [NFC] Format cpp files

* [NFC] Format & add missing newlines

Signed-off-by: Junyi Mei

---------

Signed-off-by: Junyi Mei
Co-authored-by: qzylalala <304228244@qq.com>
---
 midend/CMakeLists.txt                                 |  12 +-
 midend/include/Dialect/Bud/BudOps.td                  |   2 +-
 midend/include/Dialect/DAP/DAPOps.td                  |   2 +-
 midend/include/Dialect/DIP/DIPOps.td                  |   2 +-
 .../include/Dialect/Gemmini/GemminiDialect.h          |   3 +-
 midend/include/Dialect/Sche/ScheOps.td                |   2 +-
 .../include/Dialect/VectorExp/VectorExpOps.td         |   2 +-
 midend/include/buddy-mlir-c/Dialects.h                |  38 +++++
 midend/include/buddy-mlir-c/InitAll.h                 |  31 ++++
 .../include/buddy-mlir-c/RegisterEverything.h         |  40 +++++
 .../Bindings/Python/RegisterEverything.cpp            |  31 ++++
 midend/lib/CAPI/CMakeLists.txt                        |  33 +++++
 midend/lib/CAPI/Dialects.cpp                          |  43 ++++++
 midend/lib/CAPI/RegisterEverything.cpp                |  70 +++++++++
 midend/lib/CMakeLists.txt                             |  38 +++++
 midend/lib/InitAll.cpp                                |  83 +++++++++++
 midend/python/CMakeLists.txt                          | 138 ++++++++++++++++++
 .../python/buddy_mlir/dialects/BudBinding.td          |  23 +++
 .../python/buddy_mlir/dialects/DAPBinding.td          |  23 +++
 .../python/buddy_mlir/dialects/DIPBinding.td          |  23 +++
 .../buddy_mlir/dialects/GemminiBinding.td             |  23 +++
 .../python/buddy_mlir/dialects/RVVBinding.td          |  23 +++
 .../python/buddy_mlir/dialects/ScheBinding.td         |  23 +++
 .../buddy_mlir/dialects/VectorExpBinding.td           |  23 +++
 midend/python/buddy_mlir/dialects/bud.py              |  17 +++
 midend/python/buddy_mlir/dialects/dap.py              |  17 +++
 midend/python/buddy_mlir/dialects/dip.py              |  17 +++
 midend/python/buddy_mlir/dialects/gemmini.py          |  17 +++
 midend/python/buddy_mlir/dialects/rvv.py              |  17 +++
 midend/python/buddy_mlir/dialects/sche.py             |  17 +++
 .../python/buddy_mlir/dialects/vector_exp.py          |  17 +++
 tests/Python/test_python.py                           |  49 +++++++
 32 files changed, 892 insertions(+), 7 deletions(-)
 create mode 100644 midend/include/buddy-mlir-c/Dialects.h
 create mode 100644 midend/include/buddy-mlir-c/InitAll.h
 create mode 100644 midend/include/buddy-mlir-c/RegisterEverything.h
 create mode 100644 midend/lib/Bindings/Python/RegisterEverything.cpp
 create mode 100644 midend/lib/CAPI/CMakeLists.txt
 create mode 100644 midend/lib/CAPI/Dialects.cpp
 create mode 100644 midend/lib/CAPI/RegisterEverything.cpp
 create mode 100644 midend/lib/InitAll.cpp
 create mode 100644 midend/python/CMakeLists.txt
 create mode 100644 midend/python/buddy_mlir/dialects/BudBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/DAPBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/DIPBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/GemminiBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/RVVBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/ScheBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/VectorExpBinding.td
 create mode 100644 midend/python/buddy_mlir/dialects/bud.py
 create mode 100644 midend/python/buddy_mlir/dialects/dap.py
 create mode 100644 midend/python/buddy_mlir/dialects/dip.py
 create mode 100644 midend/python/buddy_mlir/dialects/gemmini.py
 create mode 100644 midend/python/buddy_mlir/dialects/rvv.py
 create mode 100644 midend/python/buddy_mlir/dialects/sche.py
 create mode 100644 midend/python/buddy_mlir/dialects/vector_exp.py
 create mode 100644 tests/Python/test_python.py
diff --git a/midend/CMakeLists.txt b/midend/CMakeLists.txt
index c847a9fa37..37eb612e29 100644
--- a/midend/CMakeLists.txt
+++ b/midend/CMakeLists.txt
@@ -1,2 +1,12 @@
 add_subdirectory(include)
-add_subdirectory(lib)
\ No newline at end of file
+add_subdirectory(lib)
+
+if(MLIR_ENABLE_BINDINGS_PYTHON)
+  include(MLIRDetectPythonEnv)
+  mlir_detect_pybind11_install()
+  find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION}
+    COMPONENTS Interpreter Development NumPy REQUIRED)
+  find_package(pybind11 2.6 CONFIG REQUIRED)
+  message(STATUS "Enabling Python API")
+  add_subdirectory(python)
+endif()
diff --git a/midend/include/Dialect/Bud/BudOps.td b/midend/include/Dialect/Bud/BudOps.td
index abcf953975..1a18455dee 100644
--- a/midend/include/Dialect/Bud/BudOps.td
+++ b/midend/include/Dialect/Bud/BudOps.td
@@ -21,7 +21,7 @@
 #ifndef BUD_BUDOPS_TD
 #define BUD_BUDOPS_TD
 
-include "BudDialect.td"
+include "Bud/BudDialect.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/EnumAttr.td"
diff --git a/midend/include/Dialect/DAP/DAPOps.td b/midend/include/Dialect/DAP/DAPOps.td
index d14ca5cfcd..54dc44632f 100644
--- a/midend/include/Dialect/DAP/DAPOps.td
+++ b/midend/include/Dialect/DAP/DAPOps.td
@@ -21,7 +21,7 @@
 #ifndef DAP_DAPOPS_TD
 #define DAP_DAPOPS_TD
 
-include "DAPDialect.td"
+include "DAP/DAPDialect.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/EnumAttr.td"
diff --git a/midend/include/Dialect/DIP/DIPOps.td b/midend/include/Dialect/DIP/DIPOps.td
index 179e66359d..b5f928b888 100644
--- a/midend/include/Dialect/DIP/DIPOps.td
+++ b/midend/include/Dialect/DIP/DIPOps.td
@@ -21,7 +21,7 @@
 #ifndef DIP_DIPOPS_TD
 #define DIP_DIPOPS_TD
 
-include "DIPDialect.td"
+include "DIP/DIPDialect.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/EnumAttr.td"
diff --git a/midend/include/Dialect/Gemmini/GemminiDialect.h b/midend/include/Dialect/Gemmini/GemminiDialect.h
index 04e2a9dec7..9f992aab6c 100644
--- a/midend/include/Dialect/Gemmini/GemminiDialect.h
+++ b/midend/include/Dialect/Gemmini/GemminiDialect.h
@@ -17,7 +17,8 @@
 #ifndef GEMMINI_GEMMINIOPS_H
 #define GEMMINI_GEMMINIOPS_H
 
-#include "Gemmini/GemminiDialect.h.inc"
 #include "mlir/IR/Dialect.h"
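+// The TableGen-generated declaration below depends on mlir/IR/Dialect.h,
+// so it must be included after that header.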
+#include "Gemmini/GemminiDialect.h.inc" + #endif diff --git a/midend/include/Dialect/Sche/ScheOps.td b/midend/include/Dialect/Sche/ScheOps.td index 4db3c3815b..c728c12fdb 100644 --- a/midend/include/Dialect/Sche/ScheOps.td +++ b/midend/include/Dialect/Sche/ScheOps.td @@ -21,7 +21,7 @@ #ifndef SCHE_SCHEOPS_TD #define SCHE_SCHEOPS_TD -include "ScheDialect.td" +include "Sche/ScheDialect.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" diff --git a/midend/include/Dialect/VectorExp/VectorExpOps.td b/midend/include/Dialect/VectorExp/VectorExpOps.td index aeacba34d0..2b3d177790 100644 --- a/midend/include/Dialect/VectorExp/VectorExpOps.td +++ b/midend/include/Dialect/VectorExp/VectorExpOps.td @@ -21,7 +21,7 @@ #ifndef VECTOREXP_VECTOREXPOPS_TD #define VECTOREXP_VECTOREXPOPS_TD -include "VectorExpDialect.td" +include "VectorExp/VectorExpDialect.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/midend/include/buddy-mlir-c/Dialects.h b/midend/include/buddy-mlir-c/Dialects.h new file mode 100644 index 0000000000..8129154be2 --- /dev/null +++ b/midend/include/buddy-mlir-c/Dialects.h @@ -0,0 +1,38 @@ +//===----------------- Dialects.h - CAPI for dialects ---------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BUDDYMLIR_C_DIALECTS_H +#define BUDDYMLIR_C_DIALECTS_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Bud, bud); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(DAP, dap); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(DIP, dip); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Gemmini, gemmini); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(RVV, rvv); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Sche, sche); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(VectorExp, vector_exp); + +#ifdef __cplusplus +} +#endif + +#endif // BUDDYMLIR_C_DIALECTS_H diff --git a/midend/include/buddy-mlir-c/InitAll.h b/midend/include/buddy-mlir-c/InitAll.h new file mode 100644 index 0000000000..087d24a9f7 --- /dev/null +++ b/midend/include/buddy-mlir-c/InitAll.h @@ -0,0 +1,31 @@ +//===------------- InitAll.h - Register all dialects and passes -----------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUDDY_MLIR_INITALL_H
+#define BUDDY_MLIR_INITALL_H
+
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+namespace buddy {
+
+void registerAllDialects(mlir::DialectRegistry &registry);
+void registerAllPasses();
+
+} // namespace buddy
+} // namespace mlir
+
+#endif // BUDDY_MLIR_INITALL_H
diff --git a/midend/include/buddy-mlir-c/RegisterEverything.h b/midend/include/buddy-mlir-c/RegisterEverything.h
new file mode 100644
index 0000000000..982b55c7e9
--- /dev/null
+++ b/midend/include/buddy-mlir-c/RegisterEverything.h
@@ -0,0 +1,40 @@
+//===------- RegisterEverything.h - Register all dialects and passes ------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUDDYMLIR_C_REGISTEREVERYTHING_H
+#define BUDDYMLIR_C_REGISTEREVERYTHING_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Registers all dialects with the given registry.
+MLIR_CAPI_EXPORTED void buddyRegisterAllDialects(MlirDialectRegistry registry);
+
+// Register all translations to LLVM IR for dialects that can support it.
+MLIR_CAPI_EXPORTED void buddyRegisterAllTranslations(MlirContext context);
+
+// Registers all passes for symbolic access with the global registry.
+MLIR_CAPI_EXPORTED void buddyRegisterAllPasses();
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // BUDDYMLIR_C_REGISTEREVERYTHING_H
diff --git a/midend/lib/Bindings/Python/RegisterEverything.cpp b/midend/lib/Bindings/Python/RegisterEverything.cpp
new file mode 100644
index 0000000000..95ee65a34b
--- /dev/null
+++ b/midend/lib/Bindings/Python/RegisterEverything.cpp
@@ -0,0 +1,31 @@
+//===- RegisterEverything.cpp - API to register all dialects/passes -------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#include "buddy-mlir-c/RegisterEverything.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+PYBIND11_MODULE(_mlirRegisterEverything, m) {
+  m.doc() = "Buddy MLIR All Dialects, Translations and Passes Registration";
+
+  m.def("register_dialects", [](MlirDialectRegistry registry) {
+    buddyRegisterAllDialects(registry);
+  });
+  m.def("register_llvm_translations",
+        [](MlirContext context) { buddyRegisterAllTranslations(context); });
+
+  // Register all passes on load.
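+  // Doing this at import time makes every upstream and Buddy pass
+  // resolvable by name from the Python PassManager, with no extra
+  // registration call needed on the Python side.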
+ buddyRegisterAllPasses(); +} diff --git a/midend/lib/CAPI/CMakeLists.txt b/midend/lib/CAPI/CMakeLists.txt new file mode 100644 index 0000000000..fc4cb6cc0e --- /dev/null +++ b/midend/lib/CAPI/CMakeLists.txt @@ -0,0 +1,33 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +add_mlir_public_c_api_library(BuddyMLIRCAPI + Dialects.cpp + RegisterEverything.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/midend/include/buddy-mlir-c + + LINK_LIBS PUBLIC + ${dialect_libs} + ${translation_libs} + ${conversion_libs} + ${extension_libs} + MLIRBuiltinToLLVMIRTranslation + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRCAPITransforms + BuddyBud + BuddyDAP + BuddyDIP + BuddyGemmini + BuddyGemminiTransforms + BuddyRVV + BuddyRVVTransforms + BuddySche + VectorExp + BuddyMLIRInitAll + BuddyToLLVMIRTranslationRegistration +) diff --git a/midend/lib/CAPI/Dialects.cpp b/midend/lib/CAPI/Dialects.cpp new file mode 100644 index 0000000000..745e1d5bbd --- /dev/null +++ b/midend/lib/CAPI/Dialects.cpp @@ -0,0 +1,43 @@ +//===------------ Dialects.cpp - C Interface for Dialects -----------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "buddy-mlir-c/Dialects.h" + +#include "Dialect/Bud/BudDialect.h" +#include "Dialect/Bud/BudOps.h" +#include "Dialect/DAP/DAPDialect.h" +#include "Dialect/DAP/DAPOps.h" +#include "Dialect/DIP/DIPDialect.h" +#include "Dialect/DIP/DIPOps.h" +#include "Dialect/Gemmini/GemminiDialect.h" +#include "Dialect/Gemmini/GemminiOps.h" +#include "Dialect/RVV/RVVDialect.h" +#include "Dialect/Sche/ScheDialect.h" +#include "Dialect/Sche/ScheOps.h" +#include "Dialect/VectorExp/VectorExpDialect.h" +#include "Dialect/VectorExp/VectorExpOps.h" + +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Bud, bud, buddy::bud::BudDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DAP, dap, buddy::dap::DAPDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DIP, dip, buddy::dip::DIPDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Gemmini, gemmini, + buddy::gemmini::GemminiDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(RVV, rvv, buddy::rvv::RVVDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Sche, sche, buddy::sche::ScheDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(VectorExp, vector_exp, + buddy::vector_exp::VectorExpDialect) diff --git a/midend/lib/CAPI/RegisterEverything.cpp b/midend/lib/CAPI/RegisterEverything.cpp new file mode 100644 index 0000000000..0001f63747 --- /dev/null +++ b/midend/lib/CAPI/RegisterEverything.cpp @@ -0,0 +1,70 @@ +//===------- RegisterEverything.cpp - Register all MLIR entities ----------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#include "buddy-mlir-c/RegisterEverything.h"
+#include "buddy-mlir-c/InitAll.h"
+#include "mlir-c/RegisterEverything.h"
+
+#include "Target/LLVMIR/Dialect/Gemmini/GemminiToLLVMIRTranslation.h"
+#include "Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllExtensions.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Target/LLVMIR/Dialect/All.h"
+#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+
+using namespace buddy;
+using namespace mlir;
+
+namespace buddy {
+void registerBuddyToLLVMIRTranslation();
+}
+
+void buddyRegisterAllDialects(MlirDialectRegistry registry) {
+  // Register all Dialects from UPSTREAM MLIR
+  mlir::registerAllDialects(*unwrap(registry));
+  mlir::registerAllExtensions(*unwrap(registry));
+
+  // Register all Dialects from BUDDY MLIR
+  mlir::buddy::registerAllDialects(*unwrap(registry));
+}
+
+void buddyRegisterAllTranslations(MlirContext context) {
+  auto &ctx = *unwrap(context);
+  mlir::DialectRegistry registry;
+  // Register all Translations from UPSTREAM MLIR
+  registry.insert<DLTIDialect, func::FuncDialect>();
+  mlir::registerAllToLLVMIRTranslations(registry);
+
+  // Register all Translations from BUDDY MLIR
+  registerRVVDialectTranslation(registry);
+  registerGemminiDialectTranslation(registry);
+
+  ctx.appendDialectRegistry(registry);
+}
+
+void buddyRegisterAllPasses() {
+  // Register all Passes from UPSTREAM MLIR
+  mlir::registerAllPasses();
+
+  // Register all Passes from BUDDY MLIR
+  mlir::buddy::registerAllPasses();
+}
diff --git a/midend/lib/CMakeLists.txt b/midend/lib/CMakeLists.txt
index 19b254cf38..cae54478c3 100644
--- a/midend/lib/CMakeLists.txt
+++ b/midend/lib/CMakeLists.txt
@@ -1,8 +1,46 @@
+add_subdirectory(CAPI)
 add_subdirectory(Dialect)
 add_subdirectory(Conversion)
 add_subdirectory(Target)
 add_subdirectory(Utils)
+
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+set(LinkedLibs
+  MLIRFuncDialect
+  MLIRIR
+  MLIRSupport
+  ${extension_libs}
+
+  ConvOptimization
+  CBConvVectorization
+  LowerBudPass
+  LowerDAPPass
+  LowerDIPPass
+  LowerGemminiPass
+  LowerLinalgToGemminiPass
+  LowerRVVPass
+  LowerSche
+  LowerVectorExpPass
+  MatMulOptimization
+  BatchMatMulOptimization
+  MatMulParallelVectorization
+  SchedulingOnDevices
+  TransposeOptimization
+)
+
+
+add_mlir_library(BuddyMLIRInitAll
+  InitAll.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  ${LinkedLibs}
+)
+
+
 # Build static library for async runtime.
 add_mlir_library(static_mlir_async_runtime
   STATIC
diff --git a/midend/lib/InitAll.cpp b/midend/lib/InitAll.cpp
new file mode 100644
index 0000000000..d6cad2bc1e
--- /dev/null
+++ b/midend/lib/InitAll.cpp
@@ -0,0 +1,83 @@
+//===----------- InitAll.cpp - Register all dialects and passes -----------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#include "buddy-mlir-c/InitAll.h"
+
+#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Dialect.h"
+
+#include "Dialect/Bud/BudDialect.h"
+#include "Dialect/DAP/DAPDialect.h"
+#include "Dialect/DIP/DIPDialect.h"
+#include "Dialect/Gemmini/GemminiDialect.h"
+#include "Dialect/RVV/RVVDialect.h"
+#include "Dialect/Sche/ScheDialect.h"
+#include "Dialect/VectorExp/VectorExpDialect.h"
+
+namespace mlir {
+namespace buddy {
+void registerConvOptimizePass();
+void registerConvVectorizationPass();
+void registerPointwiseConvToGemmPass();
+void registerPoolingVectorizationPass();
+void registerLowerBudPass();
+void registerLowerDAPPass();
+void registerLowerDIPPass();
+void registerLowerGemminiPass();
+void registerLowerLinalgToGemminiPass();
+void registerLowerRVVPass();
+void registerLowerSchePass();
+void registerLowerVectorExpPass();
+void registerBatchMatMulOptimizePass();
+void registerMatMulOptimizePass();
+void registerMatMulParallelVectorizationPass();
+void registerMatMulVectorizationPass();
+void registerDeviceSchedulePass();
+void registerTransposeOptimizationPass();
+} // namespace buddy
+} // namespace mlir
+
+void mlir::buddy::registerAllDialects(mlir::DialectRegistry &registry) {
+  registry.insert<::buddy::bud::BudDialect>();
+  registry.insert<::buddy::dap::DAPDialect>();
+  registry.insert<::buddy::dip::DIPDialect>();
+  registry.insert<::buddy::gemmini::GemminiDialect>();
+  registry.insert<::buddy::rvv::RVVDialect>();
+  registry.insert<::buddy::sche::ScheDialect>();
+  registry.insert<::buddy::vector_exp::VectorExpDialect>();
+}
+
+void mlir::buddy::registerAllPasses() {
+  mlir::buddy::registerConvOptimizePass();
+  mlir::buddy::registerConvVectorizationPass();
+  mlir::buddy::registerPointwiseConvToGemmPass();
+  mlir::buddy::registerPoolingVectorizationPass();
+  mlir::buddy::registerLowerBudPass();
+  mlir::buddy::registerLowerDAPPass();
+  mlir::buddy::registerLowerDIPPass();
+  mlir::buddy::registerLowerGemminiPass();
+  mlir::buddy::registerLowerLinalgToGemminiPass();
+  mlir::buddy::registerLowerRVVPass();
+  mlir::buddy::registerLowerSchePass();
+  mlir::buddy::registerLowerVectorExpPass();
+  mlir::buddy::registerBatchMatMulOptimizePass();
+  mlir::buddy::registerMatMulOptimizePass();
+  mlir::buddy::registerMatMulParallelVectorizationPass();
+  mlir::buddy::registerMatMulVectorizationPass();
+  mlir::buddy::registerDeviceSchedulePass();
+  mlir::buddy::registerTransposeOptimizationPass();
+}
diff --git a/midend/python/CMakeLists.txt b/midend/python/CMakeLists.txt
new file mode 100644
index 0000000000..0afc9694a7
--- /dev/null
+++ b/midend/python/CMakeLists.txt
@@ -0,0 +1,138 @@
+include(AddMLIRPython)
+
+# Specifies that all MLIR packages are co-located under the `buddy_mlir`
+# top level package (the API has been embedded in a relocatable way).
+# TODO: Add an upstream cmake param for this vs having a global here.
+add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=buddy_mlir.")
+
+################################################################################
+# Structural groupings.
+################################################################################
+
+declare_mlir_python_sources(BuddyMLIRPythonSources)
+declare_mlir_python_sources(BuddyMLIRPythonSources.Dialects
+  ADD_TO_PARENT BuddyMLIRPythonSources)
+
+################################################################################
+# Dialect bindings
+################################################################################
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/BudBinding.td
+  SOURCES
+    dialects/bud.py
+  DIALECT_NAME bud)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/DAPBinding.td
+  SOURCES
+    dialects/dap.py
+  DIALECT_NAME dap)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/DIPBinding.td
+  SOURCES
+    dialects/dip.py
+  DIALECT_NAME dip)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/GemminiBinding.td
+  SOURCES
+    dialects/gemmini.py
+  DIALECT_NAME gemmini)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/RVVBinding.td
+  SOURCES
+    dialects/rvv.py
+  DIALECT_NAME rvv)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/ScheBinding.td
+  SOURCES
+    dialects/sche.py
+  DIALECT_NAME sche)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT BuddyMLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/buddy_mlir"
+  TD_FILE dialects/VectorExpBinding.td
+  SOURCES
+    dialects/vector_exp.py
+  DIALECT_NAME vector_exp)
+
+################################################################################
+# Python extensions.
+# The sources for these are all in lib/Bindings/Python, but since they have to
+# be rebuilt for each package and integrate with the source setup here, we
+# just reference them here instead of having ordered, cross package target
+# dependencies.
+################################################################################
+
+set(PYTHON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../lib/Bindings/Python")
+
+declare_mlir_python_extension(BuddyMLIRPythonSources.Extension
+  MODULE_NAME _mlirRegisterEverything
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  ADD_TO_PARENT BuddyMLIRPythonSources
+  SOURCES
+    RegisterEverything.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    BuddyMLIRCAPI
+    MLIRCAPIConversion
+    MLIRCAPITransforms
+)
+
+################################################################################
+# Common CAPI dependency DSO.
+# All python extensions must link through one DSO which exports the CAPI, and +# this must have a globally unique name amongst all embeddors of the python +# library since it will effectively have global scope. +# +# The presence of this aggregate library is part of the long term plan, but its +# use needs to be made more flexible. +# +# TODO: Upgrade to the aggregate utility in https://reviews.llvm.org/D106419 +# once ready. +################################################################################ + +add_mlir_python_common_capi_library(BuddyMLIRPythonCAPI + INSTALL_COMPONENT BuddyMLIRPythonModules + INSTALL_DESTINATION python_packages/buddy_mlir/_mlir_libs + OUTPUT_DIRECTORY "${BUDDY_BUILD_DIR}/python_packages/buddy_mlir/_mlir_libs" + RELATIVE_INSTALL_ROOT "../../.." + DECLARED_SOURCES + BuddyMLIRPythonSources + MLIRPythonSources + MLIRPythonExtension.Core +) + +################################################################################ +# Instantiation of all Python modules +################################################################################ + +add_mlir_python_modules(BuddyMLIRPythonModules + ROOT_PREFIX "${BUDDY_BUILD_DIR}/python_packages/buddy_mlir" + INSTALL_PREFIX "python_packages/buddy_mlir" + DECLARED_SOURCES + BuddyMLIRPythonSources + MLIRPythonSources + MLIRPythonExtension.Core + COMMON_CAPI_LINK_LIBS + BuddyMLIRPythonCAPI + MLIRPythonCAPI + ) diff --git a/midend/python/buddy_mlir/dialects/BudBinding.td b/midend/python/buddy_mlir/dialects/BudBinding.td new file mode 100644 index 0000000000..77fb86b121 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/BudBinding.td @@ -0,0 +1,23 @@ +//===-------- BudOps.td - Python bindings for Bud --*- tablegen -*--------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_BUD_OPS +#define PYTHON_BINDINGS_BUD_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "Bud/BudOps.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/DAPBinding.td b/midend/python/buddy_mlir/dialects/DAPBinding.td new file mode 100644 index 0000000000..2e0cdaca44 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/DAPBinding.td @@ -0,0 +1,23 @@ +//===-------- DAPOps.td - Python bindings for DAP --*- tablegen -*--------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_DAP_OPS +#define PYTHON_BINDINGS_DAP_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "DAP/DAPOps.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/DIPBinding.td b/midend/python/buddy_mlir/dialects/DIPBinding.td new file mode 100644 index 0000000000..8bd6e035bf --- /dev/null +++ b/midend/python/buddy_mlir/dialects/DIPBinding.td @@ -0,0 +1,23 @@ +//===-------- DIPOps.td - Python bindings for DIP --*- tablegen -*--------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_DIP_OPS +#define PYTHON_BINDINGS_DIP_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "DIP/DIPOps.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/GemminiBinding.td b/midend/python/buddy_mlir/dialects/GemminiBinding.td new file mode 100644 index 0000000000..346cdfc1c2 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/GemminiBinding.td @@ -0,0 +1,23 @@ +//===---- GemminiOps.td - Python bindings for Gemmini --*- tablegen -*----===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_Gemmini_OPS +#define PYTHON_BINDINGS_Gemmini_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "Gemmini/Gemmini.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/RVVBinding.td b/midend/python/buddy_mlir/dialects/RVVBinding.td new file mode 100644 index 0000000000..f82bcb7710 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/RVVBinding.td @@ -0,0 +1,23 @@ +//===-------- RVVOps.td - Python bindings for RVV --*- tablegen -*--------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_RVV_OPS +#define PYTHON_BINDINGS_RVV_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "RVV/RVV.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/ScheBinding.td b/midend/python/buddy_mlir/dialects/ScheBinding.td new file mode 100644 index 0000000000..8db80b9075 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/ScheBinding.td @@ -0,0 +1,23 @@ +//===------- ScheOps.td - Python bindings for Sche --*- tablegen -*-------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SCHE_OPS +#define PYTHON_BINDINGS_SCHE_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "Sche/ScheOps.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/VectorExpBinding.td b/midend/python/buddy_mlir/dialects/VectorExpBinding.td new file mode 100644 index 0000000000..204ecb2991 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/VectorExpBinding.td @@ -0,0 +1,23 @@ +//===-- VectorExpOps.td - Python bindings for VectorExp --*- tablegen -*--===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOREXP_OPS +#define PYTHON_BINDINGS_VECTOREXP_OPS + +// include "mlir/Bindings/Python/Attributes.td" +include "VectorExp/VectorExpOps.td" + +#endif diff --git a/midend/python/buddy_mlir/dialects/bud.py b/midend/python/buddy_mlir/dialects/bud.py new file mode 100644 index 0000000000..7cc060fdde --- /dev/null +++ b/midend/python/buddy_mlir/dialects/bud.py @@ -0,0 +1,17 @@ +# ===------------------------ bud.py ------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ===--------------------------------------------------------------------------- + +from ._bud_ops_gen import * diff --git a/midend/python/buddy_mlir/dialects/dap.py b/midend/python/buddy_mlir/dialects/dap.py new file mode 100644 index 0000000000..969d967ab4 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/dap.py @@ -0,0 +1,17 @@ +# ===------------------------ dap.py ------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- + +from ._dap_ops_gen import * diff --git a/midend/python/buddy_mlir/dialects/dip.py b/midend/python/buddy_mlir/dialects/dip.py new file mode 100644 index 0000000000..60ee9262d9 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/dip.py @@ -0,0 +1,17 @@ +# ===------------------------ dip.py ------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- + +from ._dip_ops_gen import * diff --git a/midend/python/buddy_mlir/dialects/gemmini.py b/midend/python/buddy_mlir/dialects/gemmini.py new file mode 100644 index 0000000000..a91c517714 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/gemmini.py @@ -0,0 +1,17 @@ +# ===------------------------ gemmini.py --------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- + +from ._gemmini_ops_gen import * diff --git a/midend/python/buddy_mlir/dialects/rvv.py b/midend/python/buddy_mlir/dialects/rvv.py new file mode 100644 index 0000000000..8eb86f6e60 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/rvv.py @@ -0,0 +1,17 @@ +# ===------------------------ rvv.py ------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- + +from ._rvv_ops_gen import * diff --git a/midend/python/buddy_mlir/dialects/sche.py b/midend/python/buddy_mlir/dialects/sche.py new file mode 100644 index 0000000000..a050685674 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/sche.py @@ -0,0 +1,17 @@ +# ===------------------------ sche.py ------------------------------------------ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- + +from ._sche_ops_gen import * diff --git a/midend/python/buddy_mlir/dialects/vector_exp.py b/midend/python/buddy_mlir/dialects/vector_exp.py new file mode 100644 index 0000000000..d448526a58 --- /dev/null +++ b/midend/python/buddy_mlir/dialects/vector_exp.py @@ -0,0 +1,17 @@ +# ===------------------------ vector_exp.py ------------------------------------ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+# ===---------------------------------------------------------------------------
+
+from ._vector_exp_ops_gen import *
diff --git a/tests/Python/test_python.py b/tests/Python/test_python.py
new file mode 100644
index 0000000000..00a7eebbb4
--- /dev/null
+++ b/tests/Python/test_python.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from buddy_mlir.ir import Context, Module
+from buddy_mlir.passmanager import PassManager
+
+
+with Context():
+    mod = Module.parse(
+        """
+        %0 = arith.constant 0 : i8
+        %1 = arith.constant 1 : i8
+        %2 = arith.constant 2 : i8
+        %mem0 = memref.alloc() : memref<8x8xi8>
+        %mem1 = memref.alloc() : memref<8x8xi8>
+        %mem2 = memref.alloc() : memref<8x8xi8>
+        linalg.fill
+            ins(%2 : i8)
+            outs(%mem0 : memref<8x8xi8>)
+        linalg.fill
+            ins(%1 : i8)
+            outs(%mem1 : memref<8x8xi8>)
+        linalg.matmul
+            ins(%mem0, %mem1 : memref<8x8xi8>, memref<8x8xi8>)
+            outs(%mem2 : memref<8x8xi8>)
+        gemmini.print %mem2 : memref<8x8xi8>
+        """
+    )
+
+    pm = PassManager("builtin.module")
+    pm.add("convert-linalg-to-gemmini")
+    pm.run(mod.operation)
+
+    # CHECK: module {
+    # CHECK: %c0_i8 = arith.constant 0 : i8
+    # CHECK: %c1_i8 = arith.constant 1 : i8
+    # CHECK: %c2_i8 = arith.constant 2 : i8
+    # CHECK: %alloc = memref.alloc() : memref<8x8xi8>
+    # CHECK: %alloc_0 = memref.alloc() : memref<8x8xi8>
+    # CHECK: %alloc_1 = memref.alloc() : memref<8x8xi8>
+    # CHECK: linalg.fill ins(%c2_i8 : i8) outs(%alloc : memref<8x8xi8>)
+    # CHECK: linalg.fill ins(%c1_i8 : i8) outs(%alloc_0 : memref<8x8xi8>)
+    # CHECK: %alloc_2 = memref.alloc() : memref<8x8xi32>
+    # CHECK: %c0_i32 = arith.constant 0 : i32
+    # CHECK: linalg.fill ins(%c0_i32 : i32) outs(%alloc_2 : memref<8x8xi32>)
+    # CHECK: gemmini.tile_matmul %alloc %alloc_0 %alloc_1 %alloc_2 : memref<8x8xi8> memref<8x8xi8> memref<8x8xi8> memref<8x8xi32>
+    # CHECK: memref.dealloc %alloc_2 : memref<8x8xi32>
+    # CHECK: gemmini.print %alloc_1 : memref<8x8xi8>
+    # CHECK: }
+    print(str(mod))

From 61a4c5bd76e81b1bd1ce61940d6bbda6a7720f37 Mon Sep 17 00:00:00 2001
From: Yuliang Li <40186387+xTayEx@users.noreply.github.com>
Date: Mon, 6 Jan 2025 10:58:50 +0800
Subject: [PATCH 6/8] [examples] Add several examples of `vector` dialect on
 GPU (#442)

* feat: Add store, load, and bitcast
* feat: Add constant_mask
* feat: Add constant-mask, contract, create-mask, extract, fma, gather, splat
* feat: Add insert, outerproduct and transpose
* feat: Add reduction
* feat: Add shape-cast and type-cast
* feat: refine makefile.
* feat: exclude MLIRVectorGPU from lit configuration
* fix: 1. add missing newlines at the end of multiple MLIR files.
  2. refine several binaries' paths in the makefile.
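
For reference, each new example comes with -lower, -translate, and -run make
targets; a typical invocation (assuming an NVIDIA GPU is visible to
nvidia-smi, whose first reported compute capability is used unless
overridden) looks like:

    cd examples/MLIRVectorGPU
    make vector-load-store-run
    make CUDA_COMPUTE_CAPACITY=sm_80 vector-fma-run   # force a target, e.g. sm_80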
--- examples/MLIRVectorGPU/makefile | 1024 +++++++++++++++++ examples/MLIRVectorGPU/vector-bitcast.mlir | 45 + .../MLIRVectorGPU/vector-compressstore.mlir | 50 + .../MLIRVectorGPU/vector-constant-mask.mlir | 31 + examples/MLIRVectorGPU/vector-contract.mlir | 34 + .../MLIRVectorGPU/vector-create-mask.mlir | 23 + examples/MLIRVectorGPU/vector-extract.mlir | 26 + examples/MLIRVectorGPU/vector-fma.mlir | 32 + examples/MLIRVectorGPU/vector-gather.mlir | 99 ++ examples/MLIRVectorGPU/vector-insert.mlir | 72 ++ examples/MLIRVectorGPU/vector-load-store.mlir | 36 + .../MLIRVectorGPU/vector-outerproduct.mlir | 41 + examples/MLIRVectorGPU/vector-reduction.mlir | 35 + examples/MLIRVectorGPU/vector-shape-cast.mlir | 26 + examples/MLIRVectorGPU/vector-splat.mlir | 23 + examples/MLIRVectorGPU/vector-transpose.mlir | 61 + examples/MLIRVectorGPU/vector-type-cast.mlir | 32 + examples/lit.cfg.py | 1 + 18 files changed, 1691 insertions(+) create mode 100644 examples/MLIRVectorGPU/makefile create mode 100644 examples/MLIRVectorGPU/vector-bitcast.mlir create mode 100644 examples/MLIRVectorGPU/vector-compressstore.mlir create mode 100644 examples/MLIRVectorGPU/vector-constant-mask.mlir create mode 100644 examples/MLIRVectorGPU/vector-contract.mlir create mode 100644 examples/MLIRVectorGPU/vector-create-mask.mlir create mode 100644 examples/MLIRVectorGPU/vector-extract.mlir create mode 100644 examples/MLIRVectorGPU/vector-fma.mlir create mode 100644 examples/MLIRVectorGPU/vector-gather.mlir create mode 100644 examples/MLIRVectorGPU/vector-insert.mlir create mode 100644 examples/MLIRVectorGPU/vector-load-store.mlir create mode 100644 examples/MLIRVectorGPU/vector-outerproduct.mlir create mode 100644 examples/MLIRVectorGPU/vector-reduction.mlir create mode 100644 examples/MLIRVectorGPU/vector-shape-cast.mlir create mode 100644 examples/MLIRVectorGPU/vector-splat.mlir create mode 100644 examples/MLIRVectorGPU/vector-transpose.mlir create mode 100644 examples/MLIRVectorGPU/vector-type-cast.mlir diff --git a/examples/MLIRVectorGPU/makefile b/examples/MLIRVectorGPU/makefile new file mode 100644 index 0000000000..7b8a0b2879 --- /dev/null +++ b/examples/MLIRVectorGPU/makefile @@ -0,0 +1,1024 @@ +#!/bin/bash +BUDDY_BUILD_DIR := ../../build/ +LLVM_BUILD_DIR := ../../llvm/build/ +BUDDY_OPT := ${BUDDY_BUILD_DIR}/bin/buddy-opt +MLIR_OPT := ${LLVM_BUILD_DIR}/bin/mlir-opt +MLIR_TRANSLATE := ${LLVM_BUILD_DIR}/bin/mlir-translate +MLIR_CPU_RUNNER := ${LLVM_BUILD_DIR}/bin/mlir-cpu-runner +LLC := ${LLVM_BUILD_DIR}/bin/llc +OPT_FLAG := -O0 +CUDA_COMPUTE_CAPACITY ?= $(shell nvidia-smi --query-gpu=compute_cap --format=csv,noheader | awk 'NR==1{printf "sm_%.0f", $$0*10}') + +ifeq ($(shell uname),Linux) +MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so +MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.so +MLIR_CUDA_RUNTIME := ../../llvm/build/lib/libmlir_cuda_runtime.so +MTRIPLE := x86_64-unknown-linux-gnu +else ifeq ($(shell uname),Darwin) +MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.dylib +MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.dylib +MTRIPLE := x86_64-apple-darwin +endif + +.SECONDEXPANSION: +all-run: $$(run-targets) + +# vector-load-lower: +vector-load-store-lower: + @${MLIR_OPT} ./vector-load-store.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + arith-bufferize, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + 
gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-load-store-translate: + @${MLIR_OPT} ./vector-load-store.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + arith-bufferize, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-load-store-run +vector-load-store-run: + @${MLIR_OPT} ./vector-load-store.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + arith-bufferize, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +run-targets += vector-load-store-run + +vector-bitcast-lower: + @${MLIR_OPT} ./vector-bitcast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=sm_86 O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-bitcast-translate: + @${MLIR_OPT} ./vector-bitcast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=sm_86 O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-bitcast-run +vector-bitcast-run: + @${MLIR_OPT} ./vector-bitcast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=sm_86 O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + + +vector-compressstore-lower: + @${MLIR_OPT} ./vector-compressstore.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + 
nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-compressstore-translate: + @${MLIR_OPT} ./vector-compressstore.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-compressstore-run +vector-compressstore-run: + @${MLIR_OPT} ./vector-compressstore.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +# The order of convert-arith-to-llvm and convert-vector-to-llvm matters. +vector-constant-mask-lower: + @${MLIR_OPT} ./vector-constant-mask.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-arith-to-llvm, \ + convert-vector-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-constant-mask-translate: + @${MLIR_OPT} ./vector-constant-mask.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-arith-to-llvm, \ + convert-vector-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-constant-mask-run +vector-constant-mask-run: + @${MLIR_OPT} ./vector-constant-mask.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-arith-to-llvm, \ + convert-vector-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-contract-lower: + @${MLIR_OPT} ./vector-contract.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + 
convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-contract-translate: + @${MLIR_OPT} ./vector-contract.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-contract-run +vector-contract-run: + @${MLIR_OPT} ./vector-contract.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-create-mask-lower: + @${MLIR_OPT} ./vector-create-mask.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-create-mask-translate: + @${MLIR_OPT} ./vector-create-mask.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-create-mask-run +vector-create-mask-run: + @${MLIR_OPT} ./vector-create-mask.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-extract-lower: + @${MLIR_OPT} ./vector-extract.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + 
convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-extract-translate: + @${MLIR_OPT} ./vector-extract.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-extract-run +vector-extract-run: + @${MLIR_OPT} ./vector-extract.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-fma-lower: + @${MLIR_OPT} ./vector-fma.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-fma-translate: + @${MLIR_OPT} ./vector-fma.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-fma-run +vector-fma-run: + @${MLIR_OPT} ./vector-fma.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-splat-lower: + @${MLIR_OPT} ./vector-splat.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + 
nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-splat-translate: + @${MLIR_OPT} ./vector-splat.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-splat-run +vector-splat-run: + @${MLIR_OPT} ./vector-splat.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-gather-lower: + @${MLIR_OPT} ./vector-gather.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-gather-translate: + @${MLIR_OPT} ./vector-gather.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-gather-run +vector-gather-run: + @${MLIR_OPT} ./vector-gather.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-insert-lower: + @${MLIR_OPT} ./vector-insert.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + 
nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-insert-translate: + @${MLIR_OPT} ./vector-insert.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-insert-run +vector-insert-run: + @${MLIR_OPT} ./vector-insert.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-transpose-lower: + @${MLIR_OPT} ./vector-transpose.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-transpose-translate: + @${MLIR_OPT} ./vector-transpose.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-transpose-run +vector-transpose-run: + @${MLIR_OPT} ./vector-transpose.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-outerproduct-lower: + @${MLIR_OPT} ./vector-outerproduct.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + 
nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-outerproduct-translate: + @${MLIR_OPT} ./vector-outerproduct.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-outerproduct-run +vector-outerproduct-run: + @${MLIR_OPT} ./vector-outerproduct.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-reduction-lower: + @${MLIR_OPT} ./vector-reduction.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-reduction-translate: + @${MLIR_OPT} ./vector-reduction.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-reduction-run +vector-reduction-run: + @${MLIR_OPT} ./vector-reduction.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-type-cast-lower: + @${MLIR_OPT} ./vector-type-cast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + 
nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-type-cast-translate: + @${MLIR_OPT} ./vector-type-cast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-type-cast-run +vector-type-cast-run: + @${MLIR_OPT} ./vector-type-cast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} + +vector-shape-cast-lower: + @${MLIR_OPT} ./vector-shape-cast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" \ + -o log.mlir + +vector-shape-cast-translate: + @${MLIR_OPT} ./vector-shape-cast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +run-targets += vector-shape-cast-run +vector-shape-cast-run: + @${MLIR_OPT} ./vector-shape-cast.mlir \ + --pass-pipeline="builtin.module( \ + convert-linalg-to-loops, \ + convert-vector-to-scf, \ + lower-affine, \ + convert-scf-to-cf, \ + convert-vector-to-llvm, \ + convert-arith-to-llvm, \ + convert-func-to-llvm, \ + gpu-kernel-outlining, \ + nvvm-attach-target{chip=${CUDA_COMPUTE_CAPACITY} O=3}, \ + strip-debuginfo, \ + gpu.module(convert-gpu-to-nvvm), \ + gpu-to-llvm, \ + reconcile-unrealized-casts, \ + gpu-module-to-binary \ + )" | \ + ${MLIR_CPU_RUNNER} -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME} diff --git a/examples/MLIRVectorGPU/vector-bitcast.mlir b/examples/MLIRVectorGPU/vector-bitcast.mlir new file mode 100644 index 0000000000..f7f6c72c7b --- /dev/null +++ b/examples/MLIRVectorGPU/vector-bitcast.mlir @@ -0,0 +1,45 @@ +module attributes {gpu.container_module} { + gpu.module @kernels { + 
gpu.func @vector_bitcast(%ret0: memref<3xi64>, %ret1: memref<6xf32>, %ret2: memref<6xi32>) kernel {
+      %c0 = arith.constant 0 : index
+      %v0 = arith.constant dense<[10, 20, 56, 90, 12, 90]> : vector<6xi32>
+      %v1 = vector.bitcast %v0 : vector<6xi32> to vector<3xi64>
+      vector.store %v1, %ret0[%c0] : memref<3xi64>, vector<3xi64>
+
+      %v2 = vector.bitcast %v0 : vector<6xi32> to vector<6xf32>
+      vector.store %v2, %ret1[%c0] : memref<6xf32>, vector<6xf32>
+
+      %v3 = vector.bitcast %v2 : vector<6xf32> to vector<6xi32>
+      vector.store %v3, %ret2[%c0] : memref<6xi32>, vector<6xi32>
+
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    %kernel_ret0 = memref.alloc() : memref<3xi64>
+    %kernel_ret0_cast = memref.cast %kernel_ret0 : memref<3xi64> to memref<*xi64>
+
+    %kernel_ret1 = memref.alloc() : memref<6xf32>
+    %kernel_ret1_cast = memref.cast %kernel_ret1 : memref<6xf32> to memref<*xf32>
+
+    %kernel_ret2 = memref.alloc() : memref<6xi32>
+    %kernel_ret2_cast = memref.cast %kernel_ret2 : memref<6xi32> to memref<*xi32>
+
+    gpu.host_register %kernel_ret0_cast : memref<*xi64>
+    gpu.host_register %kernel_ret1_cast : memref<*xf32>
+    gpu.host_register %kernel_ret2_cast : memref<*xi32>
+    gpu.launch_func @kernels::@vector_bitcast blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%kernel_ret0 : memref<3xi64>, %kernel_ret1 : memref<6xf32>, %kernel_ret2 : memref<6xi32>)
+
+    call @printMemrefI64(%kernel_ret0_cast) : (memref<*xi64>) -> ()
+    call @printMemrefF32(%kernel_ret1_cast) : (memref<*xf32>) -> ()
+    call @printMemrefI32(%kernel_ret2_cast) : (memref<*xi32>) -> ()
+
+    func.return
+  }
+  func.func private @printMemrefI64(%ptr : memref<*xi64>)
+  func.func private @printMemrefF32(%ptr : memref<*xf32>)
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+
+}
diff --git a/examples/MLIRVectorGPU/vector-compressstore.mlir b/examples/MLIRVectorGPU/vector-compressstore.mlir
new file mode 100644
index 0000000000..598938a8b2
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-compressstore.mlir
@@ -0,0 +1,50 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_compressstore(%base0 : memref<8xi32>, %base1 : memref<4x4xi32>) kernel {
+
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c3 = arith.constant 3 : index
+
+      // case 0
+      %mask0 = arith.constant dense<[1, 0, 1]> : vector<3xi1>
+      %value0 = arith.constant dense<[100, 101, 102]> : vector<3xi32>
+
+      vector.compressstore %base0[%c0], %mask0, %value0 : memref<8xi32>, vector<3xi1>, vector<3xi32>
+
+      // case 1
+      %base1_casted = memref.cast %base1 : memref<4x4xi32> to memref<?x?xi32>
+      %mask1 = arith.constant dense<[1, 0, 1, 1, 1, 1, 0, 0]> : vector<8xi1>
+      %value1 = arith.constant dense<[500, 501, 502, 503, 504, 505, 506, 507]> : vector<8xi32>
+
+      vector.compressstore %base1_casted[%c3, %c1], %mask1, %value1
+        : memref<?x?xi32>, vector<8xi1>, vector<8xi32>
+
+      gpu.return
+    }
+  }
+
+  memref.global "private" @gv0 : memref<8xi32> = dense<[0, 1, 2, 3, 4, 5, 6, 7]>
+
+  memref.global "private" @gv1 : memref<4x4xi32> = dense<[[0, 1, 2, 3],
+                                                          [4, 5, 6, 7],
+                                                          [8, 9, 10, 11],
+                                                          [12, 13, 14, 15]]>
+
+  func.func @main() {
+    %A = memref.get_global @gv0 : memref<8xi32>
+    %B = memref.get_global @gv1 : memref<4x4xi32>
+    %A_cast = memref.cast %A : memref<8xi32> to memref<*xi32>
+    %B_cast = memref.cast %B : memref<4x4xi32> to memref<*xi32>
+    %c1 = arith.constant 1 : index
+    gpu.host_register %A_cast : memref<*xi32>
+    gpu.host_register %B_cast : memref<*xi32>
+    gpu.launch_func @kernels::@vector_compressstore blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A : memref<8xi32>, %B : memref<4x4xi32>)
+
+    call @printMemrefI32(%A_cast) : (memref<*xi32>) -> ()
+    call @printMemrefI32(%B_cast) : (memref<*xi32>) -> ()
+
+    func.return
+  }
+  func.func private @printMemrefI32(%ptr : memref<*xi32>)
+}
diff --git a/examples/MLIRVectorGPU/vector-constant-mask.mlir b/examples/MLIRVectorGPU/vector-constant-mask.mlir
new file mode 100644
index 0000000000..a4225e5677
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-constant-mask.mlir
@@ -0,0 +1,31 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_constant_mask(%result: memref<12xi1>) kernel {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c3 = arith.constant 3 : index
+      %c4 = arith.constant 4 : index
+      %mask0_vec = vector.constant_mask [3, 2] : vector<4x3xi1>
+
+      %mask0_shape_casted = vector.shape_cast %mask0_vec : vector<4x3xi1> to vector<12xi1>
+
+      vector.store %mask0_shape_casted, %result[%c0] : memref<12xi1>, vector<12xi1>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %mask_created = memref.alloc() : memref<12xi1>
+    %mask_created_cast = memref.cast %mask_created : memref<12xi1> to memref<*xi1>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    gpu.host_register %mask_created_cast : memref<*xi1>
+    gpu.launch_func @kernels::@vector_constant_mask blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%mask_created : memref<12xi1>)
+    %mask_created_vec = vector.load %mask_created[%c0] : memref<12xi1>, vector<12xi1>
+    %mask_created_vec_reshape = vector.shape_cast %mask_created_vec : vector<12xi1> to vector<4x3xi1>
+    vector.print %mask_created_vec_reshape : vector<4x3xi1>
+
+    func.return
+  }
+  func.func private @printMemrefI32(%ptr : memref<*xi1>)
+}
diff --git a/examples/MLIRVectorGPU/vector-contract.mlir b/examples/MLIRVectorGPU/vector-contract.mlir
new file mode 100644
index 0000000000..03f4c86645
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-contract.mlir
@@ -0,0 +1,34 @@
+#map0 = affine_map<(i, j, k) -> (i, j)>
+#map1 = affine_map<(i, j, k) -> (j, k)>
+#map2 = affine_map<(i, j, k) -> (i, k)>
+
+module attributes {gpu.container_module} {
+  gpu.module @kernel {
+    gpu.func @vector_contract() kernel {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c3 = arith.constant 3 : index
+      %v0 = arith.constant dense<[[1., 2., 3., 4.],
+                                  [5., 6., 7., 8.],
+                                  [9., 10., 11., 12.]]> : vector<3x4xf32>
+      %v1 = arith.constant dense<[[1., 2., 3.],
+                                  [4., 5., 6.],
+                                  [7., 8., 9.],
+                                  [10., 11., 12.]]> : vector<4x3xf32>
+      %v2 = arith.constant dense<[[0., 0., 0.],
+                                  [0., 0., 0.],
+                                  [0., 0., 0.]]> : vector<3x3xf32>
+      %v3 = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]}
+            %v0, %v1, %v2 : vector<3x4xf32>, vector<4x3xf32> into vector<3x3xf32>
+      // vector.store %v3, %result[] : memref<vector<3x3xf32>>, vector<3x3xf32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    gpu.launch_func @kernel::@vector_contract blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args()
+    func.return
+  }
+  func.func private @printMemrefF32(%ptr : memref<*xvector<3x3xf32>>)
+}
diff --git a/examples/MLIRVectorGPU/vector-create-mask.mlir b/examples/MLIRVectorGPU/vector-create-mask.mlir
new file mode 100644
index 0000000000..ac9885579a
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-create-mask.mlir
@@ -0,0 +1,23 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_create_mask(%result: memref<3xi1>)
kernel { + %c0 = arith.constant 0 : index + %cons2 = arith.constant 2 : index + %mask0 = vector.create_mask %cons2 : vector<3xi1> + vector.store %mask0, %result[%c0] : memref<3xi1>, vector<3xi1> + gpu.return + } + } + func.func @main() { + %result = memref.alloc() : memref<3xi1> + %result_cast = memref.cast %result : memref<3xi1> to memref<*xi1> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + gpu.host_register %result_cast : memref<*xi1> + gpu.launch_func @kernels::@vector_create_mask blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result : memref<3xi1>) + %result_v = vector.load %result[%c0] : memref<3xi1>, vector<3xi1> + vector.print %result_v : vector<3xi1> + + func.return + } +} diff --git a/examples/MLIRVectorGPU/vector-extract.mlir b/examples/MLIRVectorGPU/vector-extract.mlir new file mode 100644 index 0000000000..fbb5e510e7 --- /dev/null +++ b/examples/MLIRVectorGPU/vector-extract.mlir @@ -0,0 +1,26 @@ +module attributes {gpu.container_module} { + gpu.module @kernels { + gpu.func @vector_extract() kernel { + + %base = arith.constant dense<[[0, 1, 2], + [10, 11, 12], + [20, 21, 22]]> : vector<3x3xi32> + + %c0 = vector.extract %base[1, 1] : i32 from vector<3x3xi32> + gpu.printf "%d\n" %c0 : i32 + + %w1 = vector.extract %base[1] : vector<3xi32> from vector<3x3xi32> + %w1_0 = vector.extract %w1[0] : i32 from vector<3xi32> + %w1_1 = vector.extract %w1[1] : i32 from vector<3xi32> + %w1_2 = vector.extract %w1[2] : i32 from vector<3xi32> + gpu.printf "( %d, %d, %d )\n" %w1_0, %w1_1, %w1_2 : i32, i32, i32 + gpu.return + } + } + + func.func @main() { + %c1 = arith.constant 1 : index + gpu.launch_func @kernels::@vector_extract blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args() + func.return + } +} diff --git a/examples/MLIRVectorGPU/vector-fma.mlir b/examples/MLIRVectorGPU/vector-fma.mlir new file mode 100644 index 0000000000..da40d84d6e --- /dev/null +++ b/examples/MLIRVectorGPU/vector-fma.mlir @@ -0,0 +1,32 @@ +module attributes {gpu.container_module} { + gpu.module @kernels { + memref.global "private" @gv : memref<4x4xf32> = dense<[[0. , 1. , 2. , 3. 
], + [10., 11., 12., 13.], + [20., 21., 22., 23.], + [30., 31., 32., 33.]]> + gpu.func @vector_fma(%result: memref<4xf32>) kernel { + %mem = memref.get_global @gv : memref<4x4xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %load_vec1 = vector.load %mem[%c0, %c0] : memref<4x4xf32>, vector<4xf32> + %load_vec2 = vector.load %mem[%c1, %c0] : memref<4x4xf32>, vector<4xf32> + %load_vec3 = vector.load %mem[%c2, %c0] : memref<4x4xf32>, vector<4xf32> + %res = vector.fma %load_vec1, %load_vec2, %load_vec3 : vector<4xf32> + vector.store %res, %result[%c0] : memref<4xf32>, vector<4xf32> + gpu.return + } + } + + func.func @main() { + %result = memref.alloc() : memref<4xf32> + %result_cast = memref.cast %result : memref<4xf32> to memref<*xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + gpu.host_register %result_cast : memref<*xf32> + gpu.launch_func @kernels::@vector_fma blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result : memref<4xf32>) + %result_v = vector.load %result[%c0] : memref<4xf32>, vector<4xf32> + vector.print %result_v : vector<4xf32> + func.return + } +} diff --git a/examples/MLIRVectorGPU/vector-gather.mlir b/examples/MLIRVectorGPU/vector-gather.mlir new file mode 100644 index 0000000000..bee5583fcb --- /dev/null +++ b/examples/MLIRVectorGPU/vector-gather.mlir @@ -0,0 +1,99 @@ +module attributes {gpu.container_module} { + gpu.module @kernels { + memref.global "private" @gv0 : memref<8xi32> = dense<[0, 1, 2, 3, 4, 5, 6, 7]> + + memref.global "private" @gv1 : memref<4x4xi32> = dense<[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15]]> + + memref.global "private" @gv2 : memref<8xi32> = dense<[0, 1, 2, 3, 4, 5, 6, 7]> + gpu.func @vector_gather(%result0: memref<4xi32>, %result1: memref<4xi32>, %result2: memref<4xi32>, %result3: memref<4xi32>) kernel { + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + + %base0 = memref.get_global @gv0 : memref<8xi32> + %base1 = memref.get_global @gv1 : memref<4x4xi32> + %base2 = memref.get_global @gv2 : memref<8xi32> + + %pass_thru_4 = arith.constant dense<[2330, 2331, 2332, 2333]> : vector<4xi32> + %pass_thru_8 = arith.constant dense<[2330, 2331, 2332, 2333, 2334, 2335, 2336, 2337]> : vector<8xi32> + %pass_thru_2x2 = arith.constant dense<114> : vector<2x2xi32> + + // normal + %mask0 = arith.constant dense<1> : vector<4xi1> + %index0 = arith.constant dense<[3, 4, 2, 1]> : vector<4xi32> + %v0 = vector.gather %base0[%c0][%index0], %mask0, %pass_thru_4 + : memref<8xi32>, vector<4xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32> + vector.store %v0, %result0[%c0] : memref<4xi32>, vector<4xi32> + + // with mask + %mask1 = arith.constant dense<[1, 0, 1, 0]> : vector<4xi1> + %index1 = arith.constant dense<[3, 4, 2, 1]> : vector<4xi32> + %v1 = vector.gather %base0[%c0][%index1], %mask1, %pass_thru_4 + : memref<8xi32>, vector<4xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32> + vector.store %v1, %result1[%c0] : memref<4xi32>, vector<4xi32> + + %mask2 = arith.constant dense<1> : vector<2x2xi1> + %index2 = arith.constant dense<[[1, 0], [3, 2]]> : vector<2x2xi32> + %v2 = vector.gather %base1[%c1, %c1][%index2], %mask2, %pass_thru_2x2 + : memref<4x4xi32>, vector<2x2xi32>, vector<2x2xi1>, vector<2x2xi32> into vector<2x2xi32> + %v2_shape_casted = vector.shape_cast %v2 : vector<2x2xi32> to vector<4xi32> + vector.store %v2_shape_casted, %result2[%c0] : memref<4xi32>, 
vector<4xi32> + + %mask3 = arith.constant dense<1> : vector<2x2xi1> + %index3 = arith.constant dense<[[-1, -8], [5, 13]]> : vector<2x2xi32> + %v3 = vector.gather %base1[%c1, %c1][%index3], %mask3, %pass_thru_2x2 + : memref<4x4xi32>, vector<2x2xi32>, vector<2x2xi1>, vector<2x2xi32> into vector<2x2xi32> + + // ( ( 4, 0), ( 10, 0 ) ). + // On GPU, if indices are out-of-bound, the elements will be 0, which is different + // from the CPU case. + %v3_shape_casted = vector.shape_cast %v3 : vector<2x2xi32> to vector<4xi32> + vector.store %v3_shape_casted, %result3[%c0] : memref<4xi32>, vector<4xi32> + + gpu.return + } + } + + func.func @main() { + %result0 = memref.alloc() : memref<4xi32> + %result1 = memref.alloc() : memref<4xi32> + %result2 = memref.alloc() : memref<4xi32> + %result3 = memref.alloc() : memref<4xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // register host memory + %result0_cast = memref.cast %result0 : memref<4xi32> to memref<*xi32> + %result1_cast = memref.cast %result1 : memref<4xi32> to memref<*xi32> + %result2_cast = memref.cast %result2 : memref<4xi32> to memref<*xi32> + %result3_cast = memref.cast %result3 : memref<4xi32> to memref<*xi32> + + gpu.host_register %result0_cast : memref<*xi32> + gpu.host_register %result1_cast : memref<*xi32> + gpu.host_register %result2_cast : memref<*xi32> + gpu.host_register %result3_cast : memref<*xi32> + + gpu.launch_func @kernels::@vector_gather blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result0 : memref<4xi32>, %result1 : memref<4xi32>, %result2 : memref<4xi32>, %result3 : memref<4xi32>) + + %result0_v = vector.load %result0[%c0] : memref<4xi32>, vector<4xi32> + vector.print %result0_v : vector<4xi32> + + %result1_v = vector.load %result1[%c0] : memref<4xi32>, vector<4xi32> + vector.print %result1_v : vector<4xi32> + + %result2_v = vector.load %result2[%c0] : memref<4xi32>, vector<4xi32> + %result2_v_reshape = vector.shape_cast %result2_v : vector<4xi32> to vector<2x2xi32> + vector.print %result2_v_reshape : vector<2x2xi32> + + %result3_v = vector.load %result3[%c0] : memref<4xi32>, vector<4xi32> + %result3_v_reshape = vector.shape_cast %result3_v : vector<4xi32> to vector<2x2xi32> + vector.print %result3_v_reshape : vector<2x2xi32> + + func.return + } +} diff --git a/examples/MLIRVectorGPU/vector-insert.mlir b/examples/MLIRVectorGPU/vector-insert.mlir new file mode 100644 index 0000000000..8469c08aa7 --- /dev/null +++ b/examples/MLIRVectorGPU/vector-insert.mlir @@ -0,0 +1,72 @@ +module attributes {gpu.container_module} { + gpu.module @kernels { + gpu.func @vector_insert(%result0: memref<9xi32>, %result1: memref<9xi32>, %result2: memref<9xi32>, %the_base: memref<9xi32>) kernel{ + %base = arith.constant dense<[[0, 1, 2], + [10, 11, 12], + [20, 21, 22]]> : vector<3x3xi32> + // insert a scalar + %c0 = arith.constant 0 : index + %src0 = arith.constant 100 : i32 + %v0 = vector.insert %src0, %base[0, 0] : i32 into vector<3x3xi32> + %v0_shape_casted = vector.shape_cast %v0 : vector<3x3xi32> to vector<9xi32> + vector.store %v0_shape_casted, %result0[%c0] : memref<9xi32>, vector<9xi32> + + // insert a vector + %src1 = arith.constant dense<[101, 102, 103]> : vector<3xi32> + %v1 = vector.insert %src1, %base[1] : vector<3xi32> into vector<3x3xi32> + %v1_shape_casted = vector.shape_cast %v1 : vector<3x3xi32> to vector<9xi32> + vector.store %v1_shape_casted, %result1[%c0] : memref<9xi32>, vector<9xi32> + + // insert a vector with exactly the same rank + %src2 = arith.constant dense<[[201, 202, 203], + 
[211, 212, 213],
+                                     [221, 222, 223]]> : vector<3x3xi32>
+      %v2 = vector.insert %src2, %base[] : vector<3x3xi32> into vector<3x3xi32>
+      %v2_shape_casted = vector.shape_cast %v2 : vector<3x3xi32> to vector<9xi32>
+      vector.store %v2_shape_casted, %result2[%c0] : memref<9xi32>, vector<9xi32>
+
+      %the_base_shape_casted = vector.shape_cast %base : vector<3x3xi32> to vector<9xi32>
+      vector.store %the_base_shape_casted, %the_base[%c0] : memref<9xi32>, vector<9xi32>
+
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %result0 = memref.alloc() : memref<9xi32>
+    %result1 = memref.alloc() : memref<9xi32>
+    %result2 = memref.alloc() : memref<9xi32>
+    %the_base = memref.alloc() : memref<9xi32>
+    %result0_cast = memref.cast %result0 : memref<9xi32> to memref<*xi32>
+    %result1_cast = memref.cast %result1 : memref<9xi32> to memref<*xi32>
+    %result2_cast = memref.cast %result2 : memref<9xi32> to memref<*xi32>
+    %the_base_cast = memref.cast %the_base : memref<9xi32> to memref<*xi32>
+
+    gpu.host_register %result0_cast : memref<*xi32>
+    gpu.host_register %result1_cast : memref<*xi32>
+    gpu.host_register %result2_cast : memref<*xi32>
+    gpu.host_register %the_base_cast : memref<*xi32>
+
+    gpu.launch_func @kernels::@vector_insert blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result0 : memref<9xi32>, %result1 : memref<9xi32>, %result2 : memref<9xi32>, %the_base : memref<9xi32>)
+
+    %result0_v = vector.load %result0[%c0] : memref<9xi32>, vector<9xi32>
+    %result0_v_reshape = vector.shape_cast %result0_v : vector<9xi32> to vector<3x3xi32>
+    vector.print %result0_v_reshape : vector<3x3xi32>
+
+    %result1_v = vector.load %result1[%c0] : memref<9xi32>, vector<9xi32>
+    %result1_v_reshape = vector.shape_cast %result1_v : vector<9xi32> to vector<3x3xi32>
+    vector.print %result1_v_reshape : vector<3x3xi32>
+
+    %result2_v = vector.load %result2[%c0] : memref<9xi32>, vector<9xi32>
+    %result2_v_reshape = vector.shape_cast %result2_v : vector<9xi32> to vector<3x3xi32>
+    vector.print %result2_v_reshape : vector<3x3xi32>
+
+    %the_base_v = vector.load %the_base[%c0] : memref<9xi32>, vector<9xi32>
+    %the_base_v_reshape = vector.shape_cast %the_base_v : vector<9xi32> to vector<3x3xi32>
+    vector.print %the_base_v_reshape : vector<3x3xi32>
+
+    func.return
+  }
+}
diff --git a/examples/MLIRVectorGPU/vector-load-store.mlir b/examples/MLIRVectorGPU/vector-load-store.mlir
new file mode 100644
index 0000000000..5d5c9b6b10
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-load-store.mlir
@@ -0,0 +1,36 @@
+// RUN: buddy-opt %s \
+// RUN:   -convert-vector-to-scf -lower-affine -convert-scf-to-cf \
+// RUN:   -convert-vector-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_load(%arg0: memref<8xf32>, %arg1: memref<3xf32>) kernel {
+      %c0 = arith.constant 0 : index
+      %v0 = vector.load %arg0[%c0] : memref<8xf32>, vector<3xf32>
+      vector.store %v0, %arg1[%c0] : memref<3xf32>, vector<3xf32>
+      gpu.return
+    }
+  }
+  memref.global "private" @gv : memref<8xf32> = dense<[0., 1., 2., 3., 4., 5., 6., 7.]>
+  func.func @main() {
+    %A = memref.get_global @gv : memref<8xf32>
+    %B = memref.alloc() : memref<3xf32>
+    %A_cast = memref.cast %A : memref<8xf32> to memref<*xf32>
+    %B_cast = memref.cast %B : memref<3xf32> to memref<*xf32>
+    %c1 = arith.constant 1 : index
+    gpu.host_register %A_cast : memref<*xf32>
+    gpu.host_register %B_cast : memref<*xf32>
+    gpu.launch_func @kernels::@vector_load blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A : memref<8xf32>, %B : memref<3xf32>)
+
+    call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
+
+    func.return
+  }
+  func.func private @printMemrefF32(%ptr : memref<*xf32>)
+}
diff --git a/examples/MLIRVectorGPU/vector-outerproduct.mlir b/examples/MLIRVectorGPU/vector-outerproduct.mlir
new file mode 100644
index 0000000000..2d6b23f34b
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-outerproduct.mlir
@@ -0,0 +1,41 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_outerproduct(%result0: memref<9xi32>, %result1: memref<3xi32>) kernel {
+      %c0 = arith.constant 0 : index
+      %v0 = arith.constant dense<[1, 2, 3]> : vector<3xi32>
+      %v1 = arith.constant dense<[4, 5, 6]> : vector<3xi32>
+      %v0xv1 = vector.outerproduct %v0, %v1 : vector<3xi32>, vector<3xi32>
+      %v0xv1_shape_casted = vector.shape_cast %v0xv1 : vector<3x3xi32> to vector<9xi32>
+      vector.store %v0xv1_shape_casted, %result0[%c0] : memref<9xi32>, vector<9xi32>
+
+      %s0 = arith.constant 3 : i32
+      %s0xv0 = vector.outerproduct %v0, %s0 : vector<3xi32>, i32
+      %s0xv0_shape_casted = vector.shape_cast %s0xv0 : vector<3xi32> to vector<3xi32>
+      vector.store %s0xv0_shape_casted, %result1[%c0] : memref<3xi32>, vector<3xi32>
+
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %result0 = memref.alloc() : memref<9xi32>
+    %result0_cast = memref.cast %result0 : memref<9xi32> to memref<*xi32>
+    %result1 = memref.alloc() : memref<3xi32>
+    %result1_cast = memref.cast %result1 : memref<3xi32> to memref<*xi32>
+    gpu.host_register %result0_cast : memref<*xi32>
+    gpu.host_register %result1_cast : memref<*xi32>
+
+    gpu.launch_func @kernels::@vector_outerproduct blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result0 : memref<9xi32>, %result1 : memref<3xi32>)
+
+    %result0_v = vector.load %result0[%c0] : memref<9xi32>, vector<9xi32>
+    %result0_v_reshape = vector.shape_cast %result0_v : vector<9xi32> to vector<3x3xi32>
+    vector.print %result0_v_reshape : vector<3x3xi32>
+
+    %result1_v = vector.load %result1[%c0] : memref<3xi32>, vector<3xi32>
+    vector.print %result1_v : vector<3xi32>
+
+    func.return
+  }
+}
diff --git a/examples/MLIRVectorGPU/vector-reduction.mlir b/examples/MLIRVectorGPU/vector-reduction.mlir
new file mode 100644
index 0000000000..442ba39ac3
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-reduction.mlir
@@ -0,0 +1,35 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_reduction() kernel {
+      %v0 = arith.constant dense<[12, 13, 14, 15, 16, 90]> : vector<6xi32>
+      %sum = vector.reduction <add>, %v0 : vector<6xi32> into i32
+      gpu.printf "sum: %d\n" %sum : i32
+      %mul = vector.reduction <mul>, %v0 : vector<6xi32> into i32
+      gpu.printf "mul: %d\n" %mul : i32
+      %xor = vector.reduction <xor>, %v0 : vector<6xi32> into i32
+      gpu.printf "xor: %d\n" %xor : i32
+      %and = vector.reduction <and>, %v0 : vector<6xi32> into i32
+      gpu.printf "and: %d\n" %and : i32
+      %or = vector.reduction <or>, %v0 : vector<6xi32> into i32
+      gpu.printf "or: %d\n" %or : i32
+
+      %v1 = arith.constant dense<[1., 2., 3., 4., 5., 6.]> : vector<6xf32>
+      %sum_f = vector.reduction <add>, %v1 : vector<6xf32> into f32
+      gpu.printf "sum_f: %f\n" %sum_f : f32
+      %mul_f = vector.reduction <mul>, %v1 : vector<6xf32> into f32
+      gpu.printf "mul_f: %f\n" %mul_f : f32
+      %min_f = vector.reduction <minimumf>, %v1 : vector<6xf32> into f32
+      gpu.printf "min_f: %f\n" %min_f : f32
+      %max_f = vector.reduction <maximumf>, %v1 : vector<6xf32> into f32
+      gpu.printf "max_f: %f\n" %max_f : f32
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    gpu.launch_func @kernels::@vector_reduction blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args()
+
+    func.return
+  }
+}
diff --git a/examples/MLIRVectorGPU/vector-shape-cast.mlir b/examples/MLIRVectorGPU/vector-shape-cast.mlir
new file mode 100644
index 0000000000..b2cd53d7e0
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-shape-cast.mlir
@@ -0,0 +1,26 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_shape_cast(%result: memref<6xi32>) kernel {
+      %v0 = arith.constant dense<[[1, 2, 3], [4, 5, 6]]>
+        : vector<2x3xi32>
+      %v1 = vector.shape_cast %v0 : vector<2x3xi32> to vector<6xi32>
+      %c0 = arith.constant 0 : index
+      vector.store %v1, %result[%c0] : memref<6xi32>, vector<6xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %result = memref.alloc() : memref<6xi32>
+    %result_cast = memref.cast %result : memref<6xi32> to memref<*xi32>
+
+    gpu.host_register %result_cast : memref<*xi32>
+    gpu.launch_func @kernels::@vector_shape_cast blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result : memref<6xi32>)
+
+    %result_v = vector.load %result[%c0] : memref<6xi32>, vector<6xi32>
+    vector.print %result_v : vector<6xi32>
+    func.return
+  }
+}
diff --git a/examples/MLIRVectorGPU/vector-splat.mlir b/examples/MLIRVectorGPU/vector-splat.mlir
new file mode 100644
index 0000000000..d6fc548a78
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-splat.mlir
@@ -0,0 +1,23 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @vector_splat(%result: memref<3xf32>) kernel {
+      %c0 = arith.constant 0 : index
+      %c10 = arith.constant 10.0 : f32
+      %v1 = vector.splat %c10 : vector<3xf32>
+      vector.store %v1, %result[%c0] : memref<3xf32>, vector<3xf32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %result = memref.alloc() : memref<3xf32>
+    %result_cast = memref.cast %result : memref<3xf32> to memref<*xf32>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    gpu.host_register %result_cast : memref<*xf32>
+    gpu.launch_func @kernels::@vector_splat blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result : memref<3xf32>)
+    %result_v = vector.load %result[%c0] : memref<3xf32>, vector<3xf32>
+    vector.print %result_v : vector<3xf32>
+    func.return
+  }
+}
diff --git a/examples/MLIRVectorGPU/vector-transpose.mlir b/examples/MLIRVectorGPU/vector-transpose.mlir
new file mode 100644
index 0000000000..c46050a15f
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-transpose.mlir
@@ -0,0 +1,61 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    memref.global "private" @gv0 : memref<3x3xf32> = dense<[[0. , 1. , 2. ],
+                                                            [10., 11., 12.],
+                                                            [20., 21., 22.]]>
+    memref.global "private" @gv1 : memref<3x4x5xf32> = dense<[[[0. , 1. , 2. , 3. , 4. ],
+                                                              [10., 11., 12., 13., 14.],
+                                                              [20., 21., 22., 23., 24.],
+                                                              [30., 31., 32., 33., 34.]],
+                                                             [[40., 41., 42., 43., 44.],
+                                                              [50., 51., 52., 53., 54.],
+                                                              [60., 61., 62., 63., 64.],
+                                                              [70., 71., 72., 73., 74.]],
+                                                             [[80., 81., 82., 83., 84.],
+                                                              [90., 91., 92., 93., 94.],
+                                                              [100., 101., 102., 103., 104.],
+                                                              [110., 111., 112., 113., 114.]]]>
+    gpu.func @vector_transpose(%result0: memref<9xf32>, %result1: memref<60xf32>) kernel {
+      %c0 = arith.constant 0 : index
+      %f0 = arith.constant 0.0 : f32
+      %mem0 = memref.get_global @gv0 : memref<3x3xf32>
+      %v0 = vector.transfer_read %mem0[%c0, %c0], %f0
+          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<3x3xf32>, vector<3x3xf32>
+      %v0_transposed = vector.transpose %v0, [1, 0] : vector<3x3xf32> to vector<3x3xf32>
+      %v0_transposed_cast = vector.shape_cast %v0_transposed : vector<3x3xf32> to vector<9xf32>
+      vector.store %v0_transposed_cast, %result0[%c0] : memref<9xf32>, vector<9xf32>
+
+      %mem1 = memref.get_global @gv1 : memref<3x4x5xf32>
+      %v1 = vector.transfer_read %mem1[%c0, %c0, %c0], %f0
+          {permutation_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>} : memref<3x4x5xf32>, vector<3x4x5xf32>
+      %v1_transposed = vector.transpose %v1, [2, 0, 1] : vector<3x4x5xf32> to vector<5x3x4xf32>
+      %v1_transposed_cast = vector.shape_cast %v1_transposed : vector<5x3x4xf32> to vector<60xf32>
+      vector.store %v1_transposed_cast, %result1[%c0] : memref<60xf32>, vector<60xf32>
+
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %result0 = memref.alloc() : memref<9xf32>
+    %result0_cast = memref.cast %result0 : memref<9xf32> to memref<*xf32>
+    %result1 = memref.alloc() : memref<60xf32>
+    %result1_cast = memref.cast %result1 : memref<60xf32> to memref<*xf32>
+
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    gpu.host_register %result0_cast : memref<*xf32>
+    gpu.host_register %result1_cast : memref<*xf32>
+    gpu.launch_func @kernels::@vector_transpose blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result0 : memref<9xf32>, %result1 : memref<60xf32>)
+
+    %result0_v = vector.load %result0[%c0] : memref<9xf32>, vector<9xf32>
+    %result0_v_reshape = vector.shape_cast %result0_v : vector<9xf32> to vector<3x3xf32>
+
+    %result1_v = vector.load %result1[%c0] : memref<60xf32>, vector<60xf32>
+    %result1_v_reshape = vector.shape_cast %result1_v : vector<60xf32> to vector<5x3x4xf32>
+
+    vector.print %result0_v_reshape : vector<3x3xf32>
+    vector.print %result1_v_reshape : vector<5x3x4xf32>
+    func.return
+  }
+}
diff --git a/examples/MLIRVectorGPU/vector-type-cast.mlir b/examples/MLIRVectorGPU/vector-type-cast.mlir
new file mode 100644
index 0000000000..d72c65bbbd
--- /dev/null
+++ b/examples/MLIRVectorGPU/vector-type-cast.mlir
@@ -0,0 +1,32 @@
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    memref.global "private" @gv0 : memref<2x4xi32> = dense<[[1, 2, 3, 4], [5, 6, 7, 8]]>
+    gpu.func @vector_type_cast(%result : memref<8xi32>) kernel {
+      %mem0 = memref.get_global @gv0 : memref<2x4xi32>
+      %m0 = vector.type_cast %mem0 : memref<2x4xi32> to memref<vector<2x4xi32>>
+      %v0 = memref.load %m0[] : memref<vector<2x4xi32>>
+      %v0_reshape = vector.shape_cast %v0 : vector<2x4xi32> to vector<8xi32>
+      %c0 = arith.constant 0 : index
+      vector.store %v0_reshape, %result[%c0] : memref<8xi32>, vector<8xi32>
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+
+    %result = memref.alloc() : memref<8xi32>
+    %result_cast = memref.cast %result : memref<8xi32> to memref<*xi32>
+
+    gpu.host_register %result_cast : memref<*xi32>
+
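+    // gpu.host_register maps the host allocation into the device address
+    // space (pinned/unified memory), so the kernel can write to %result
+    // directly and no explicit gpu.memcpy is needed; this note assumes a
+    // CUDA runtime with unified addressing, as used by the other examples.
+    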
gpu.launch_func @kernels::@vector_type_cast blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%result : memref<8xi32>) + + %result_v = vector.load %result[%c0] : memref<8xi32>, vector<8xi32> + %result_v_reshape = vector.shape_cast %result_v : vector<8xi32> to vector<2x4xi32> + vector.print %result_v_reshape : vector<2x4xi32> + + func.return + } +} diff --git a/examples/lit.cfg.py b/examples/lit.cfg.py index 5988051497..98b2d8634a 100644 --- a/examples/lit.cfg.py +++ b/examples/lit.cfg.py @@ -58,6 +58,7 @@ 'MLIRSparseTensor', 'MLIRTOSA', 'MLIRTransform', + 'MLIRVectorGPU', 'Pooling', 'RISCVBuddyExt', 'RVVDialect', From 9e9ea472c797bb327f2a7bc29d33a55f68f63301 Mon Sep 17 00:00:00 2001 From: Q Liu <52538137+LIUQyou@users.noreply.github.com> Date: Mon, 6 Jan 2025 11:53:15 +0800 Subject: [PATCH 7/8] [format] Add pre-commit check (#438) * pre-commit check add * update README and precommit check * update README and precommit check and pre-commit install * update README and precommit check and pre-commit install --------- Co-authored-by: LiuQun --- .pre-commit-config.yaml | 59 +++++++++++++++++++++++++++++++++++++++++ README.md | 6 +++++ requirements.txt | 2 +- 3 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..0f65b75427 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,59 @@ +repos: + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v19.1.3 # Use the version of clang-format you have installed + hooks: + - id: clang-format + name: clang-format C++ code + files: \.(cpp|hpp|cc|cxx|h|c|hxx)$ + args: [--style=llvm] # You can set your preferred style here + + # - repo: https://github.com/pocc/pre-commit-hooks + # rev: v1.3.5 # Use the latest stable version + # hooks: + # - id: clang-tidy + # name: clang-tidy C++ code + # files: \.(cpp|hpp|cc|cxx|h|c|hxx)$ + # args: + # - --quiet + # - --checks=*,-clang-diagnostic-*,-clang-analyzer-* + # - --extra-arg=-x + # - --extra-arg=c++ + # - --extra-arg=-std=c++17 + # - --warnings-as-errors=* + # - --extra-arg=-I. + + - repo: https://github.com/psf/black + rev: 24.10.0 # Use the latest stable version + hooks: + - id: black + language_version: python3.10 + + # a comprehensive tool for checking the style and quality of Python code. + # It combines three popular Python tools: + # PyFlakes: Checks for logical errors in the code. + # pycodestyle (formerly known as pep8): Checks for adherence to the PEP 8 style guide. + # McCabe Complexity Checker: Measures the complexity of your code. 
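+  # Usage note (an assumption about the standard pre-commit CLI, not a
+  # repo-specific flag): a single hook can also be run on demand once the
+  # config is in place, e.g. `pre-commit run flake8 --all-files`.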
+  - repo: https://github.com/PyCQA/flake8
+    rev: 6.1.0 # Set the rev to match the desired flake8 version
+    hooks:
+      - id: flake8
+        args:
+          - --max-line-length=88 # Adjust as per your style guide
+          - --ignore=F821,F403,F405,F401,W503,E203,E402,E401,W605,E712,E711,F841
+        # W503: Line break before binary operator
+        # E203: Whitespace before colon
+        # E402: Module level import not at top of file
+        # E401: Multiple imports on one line
+        # W605: Invalid escape sequence
+        # E712: Comparison to True/False
+        # E711: Comparison to None
+        # F841: Unused variable
+
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0 # Updated to the latest version
+    hooks:
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+      - id: check-merge-conflict
+      - id: check-yaml
+      - id: check-added-large-files
diff --git a/README.md b/README.md
index 2e44658b02..771a0a3175 100644
--- a/README.md
+++ b/README.md
@@ -195,6 +195,12 @@ This program should be a drop-in replacement for `mlir-lsp-server`, supporting n
 After modification, your editor should have correct completion and error prompts for new dialects such as `rvv` and `gemmini`.
 
+### pre-commit checks
+
+Once the hooks are installed, the checks configured in `.pre-commit-config.yaml` run on every commit and enforce code format and style with tools such as clang-format, black, and flake8. You can also run them without committing via `pre-commit run --all-files`. This keeps coding standards consistent and catches common errors before changes are pushed.
+
+To get started, install pre-commit (e.g., `pip install pre-commit`), run `pre-commit install` in the repository, and verify that clang-format, black, and flake8 are available; on Linux, clang-format comes from your package manager and the Python tools from pip. To revert unwanted formatting changes, use `git stash`, `git restore .` (all files), or `git restore <file>` (a specific file), or revert the commit through your Git history.
+
 ## Examples
 
 The purpose of the examples is to give users a better understanding of how to use the passes and the interfaces in buddy-mlir. Currently, we provide three types of examples.
diff --git a/requirements.txt b/requirements.txt
index 55726a11f8..28586a4356 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,4 +16,4 @@ PyYAML
 certifi
 idna
 diffusers
-
+pre-commit

From 91bbd57abb93457e71496221f9226cff96df96b4 Mon Sep 17 00:00:00 2001
From: zhanghb97
Date: Tue, 7 Jan 2025 04:06:12 +0000
Subject: [PATCH 8/8] [examples] Set an example for vector dialect to RVV asm.
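
The new target lowers the example to LLVM IR with mlir-translate and then
cross-compiles it with clang for a RISC-V vector target (rv64gcv). A sketch
of the intended usage, assuming the default build layout encoded in the
makefile variables (the output names log.ll and log.s come from the target
itself):

    cd examples/MLIRVector
    make vector-load-asm-rvv   # emits log.ll (LLVM IR) and log.s (RVV assembly)

Whether RVV instructions such as vsetvli actually appear in log.s depends
on the clang version and optimization level.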
--- examples/MLIRVector/makefile | 37 +++++++++++++++++++++------- examples/MLIRVector/vector-load.mlir | 26 ++++++++++--------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/examples/MLIRVector/makefile b/examples/MLIRVector/makefile index ccc9e9af24..a3acbc74ab 100644 --- a/examples/MLIRVector/makefile +++ b/examples/MLIRVector/makefile @@ -1,10 +1,17 @@ #!/bin/bash +MLIR_BUILD_DIR := ../../llvm/build/ +BUDDY_MLIR_BUILD_DIR := ../../build/ BUDDY_OPT := ../../build/bin/buddy-opt MLIR_OPT := ../../llvm/build/bin/mlir-opt MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner LLC := ../../llvm/build/bin/llc OPT_FLAG := -O0 +LOCAL_CLANG := ../../llvm/build/bin/clang + +# RISC-V GNU Toolchain +RISCV_GNU_TOOLCHAIN := ${BUDDY_MLIR_BUILD_DIR}/thirdparty/riscv-gnu-toolchain +RISCV_GNU_TOOLCHAIN_SYSROOT := ${RISCV_GNU_TOOLCHAIN}/sysroot ifeq ($(shell uname),Linux) MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so @@ -32,6 +39,18 @@ vector-load-translate: --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll +vector-load-asm-rvv: + @${MLIR_OPT} ./vector-load.mlir \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + @${LOCAL_CLANG} -c log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu \ + --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -fno-inline -O3 -S \ + -o log.s + run-targets += vector-load-run vector-load-run: @${MLIR_OPT} ./vector-load.mlir \ @@ -298,7 +317,7 @@ vector-splat-run: -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ - -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} vector-insert-lower: @${MLIR_OPT} ./vector-insert.mlir \ @@ -321,8 +340,8 @@ vector-insert-run: -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ - -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} - + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + vector-reduction-lower: @${MLIR_OPT} ./vector-reduction.mlir \ --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ @@ -344,7 +363,7 @@ vector-reduction-run: -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ - -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} vector-outerproduct-lower: @${MLIR_OPT} ./vector-outerproduct.mlir \ @@ -367,7 +386,7 @@ vector-outerproduct-run: -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ - -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} vector-create-mask-lower: @${MLIR_OPT} ./vector-create-mask.mlir \ @@ -389,7 +408,7 @@ vector-create-mask-run: --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ - 
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
+	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
 
 vector-extract-lower:
 	@${MLIR_OPT} ./vector-extract.mlir \
@@ -502,7 +521,7 @@ vector-constant-mask-run:
 	--reconcile-unrealized-casts | \
 	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \
 	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
-
+
 vector-expandload-lower:
 	@${MLIR_OPT} ./vector-expandload.mlir \
 	--convert-vector-to-scf --lower-affine --convert-scf-to-cf \
@@ -523,7 +542,7 @@ vector-expandload-run:
 	--convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \
 	--reconcile-unrealized-casts | \
 	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \
-	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
+	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
 
 vector-compressstore-lower:
 	@${MLIR_OPT} ./vector-compressstore.mlir \
@@ -567,7 +586,7 @@ vector-insert-strided-slice-run:
 	--convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \
 	--reconcile-unrealized-casts | \
 	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \
-	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
+	-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
 
 vector-scatter-lower:
 	@${MLIR_OPT} ./vector-scatter.mlir \
diff --git a/examples/MLIRVector/vector-load.mlir b/examples/MLIRVector/vector-load.mlir
index a9d80fd044..47af66020e 100644
--- a/examples/MLIRVector/vector-load.mlir
+++ b/examples/MLIRVector/vector-load.mlir
@@ -16,6 +16,15 @@ memref.global "private" @gv1 : memref<4x4xi32> = dense<[[0, 1, 2, 3],
 
 memref.global "private" @gv2 : memref<8xi32> = dense<[0, 1, 2, 3, 4, 5, 6, 7]>
 
+func.func @kernel_1(%arg0: memref<8xi32>) {
+  %c0 = arith.constant 0 : index
+  // load normal usage
+  %v0 = vector.load %arg0[%c0] : memref<8xi32>, vector<3xi32>
+  // CHECK: ( 0, 1, 2 )
+  vector.print %v0 : vector<3xi32>
+  return
+}
+
 func.func @main() -> i32 {
 
   // vector.load can load n-D vector from m-D scalar memref or k-D vector memref
@@ -30,12 +39,7 @@ func.func @main() -> i32 {
   %base1 = memref.get_global @gv1 : memref<4x4xi32>
   %base2 = memref.get_global @gv2 : memref<8xi32>
 
-
-  // load normal usage
-  %v0 = vector.load %base0[%c0] : memref<8xi32>, vector<3xi32>
-  // CHECK: ( 0, 1, 2 )
-  vector.print %v0 : vector<3xi32>
-
+  call @kernel_1(%base0) : (memref<8xi32>) -> ()
   // load with m-D memref
 
   // case 1: inside inner-most dimension
@@ -82,14 +86,14 @@ func.func @main() -> i32 {
   %v5 = vector.load %base5[%c1, %c1] : memref<?x?xi32>, vector<8xi32>
   // ( 5, 6, 7, 8, 9, 10, 11, 12 )
   vector.print %v5 : vector<8xi32>
-
+
   // load with dynamic memref
 
   // case 2: out of bound
   // The document says:
-  // Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads.
-  // Support and implementation of out-of-bounds vector loads is target-specific.
-  // No assumptions should be made on the value of elements loaded out of bounds.
+  // Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads.
+  // Support and implementation of out-of-bounds vector loads is target-specific.
+  // No assumptions should be made on the value of elements loaded out of bounds.
   // Not all targets may support out-of-bounds vector loads.
%v6 = vector.load %base5[%c3, %c1] : memref<?x?xi32>, vector<8xi32>
   // ( 13, 14, 15, 0, 1, 2, 3, 4 )
@@ -98,7 +102,7 @@ func.func @main() -> i32 {
 
   // load with unranked memref is not allowed
   %base6 = memref.cast %base1 : memref<4x4xi32> to memref<*xi32>
-  // %v7 = vector.load %base6[%c0, %c0] : memref<*xi32>, vector<8xi32>
+  // %v7 = vector.load %base6[%c0, %c0] : memref<*xi32>, vector<8xi32>
 
   %ret = arith.constant 0 : i32
   return %ret : i32
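
  // When deterministic tail/out-of-bounds behavior is required, a masked load
  // with an explicit pass-through value is the portable alternative; a sketch
  // reusing %base5, %c3, and %c1 from above (%oob_mask and %pad are
  // illustrative names, not part of the original example):
  //   %oob_mask = vector.create_mask %c3 : vector<8xi1>
  //   %pad = arith.constant dense<0> : vector<8xi32>
  //   %safe = vector.maskedload %base5[%c3, %c1], %oob_mask, %pad
  //       : memref<?x?xi32>, vector<8xi1>, vector<8xi32> into vector<8xi32>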