From a23e3d4f6eacb0b815193be04b9eef552d870d16 Mon Sep 17 00:00:00 2001 From: Chen Weiwei <100988241+FloatingcloudKnight@users.noreply.github.com> Date: Fri, 27 Dec 2024 17:17:24 +0800 Subject: [PATCH] [midend] Add poolingnhwcmax vectorization pass, examples, and tests. (#430) --- examples/BuddyNext/makefile | 13 + examples/BuddyNext/pooling-nhwc-max-vec.mlir | 132 +++++++ .../MLIRLinalg/linalg-pooling-nhwc-max.mlir | 97 +++++ examples/MLIRLinalg/makefile | 36 ++ .../ConvVectorization/CMakeLists.txt | 1 + .../PoolingNhwcMaxVectorization.cpp | 368 ++++++++++++++++++ .../pooling-nhwc-max-vectorization.mlir | 32 ++ tools/buddy-opt/buddy-opt.cpp | 3 + 8 files changed, 682 insertions(+) create mode 100644 examples/BuddyNext/pooling-nhwc-max-vec.mlir create mode 100644 examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir create mode 100644 midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp create mode 100644 tests/Conversion/pooling-nhwc-max-vectorization.mlir diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index 83b2364d03..3ee282499c 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -701,3 +701,16 @@ next-ffn-parallel-vec-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +pooling-nhwc-max-vec-run: + @${BUDDY_OPT} ./pooling-nhwc-max-vec.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/BuddyNext/pooling-nhwc-max-vec.mlir b/examples/BuddyNext/pooling-nhwc-max-vec.mlir new file mode 100644 index 0000000000..d6d8a35d14 --- /dev/null +++ 
b/examples/BuddyNext/pooling-nhwc-max-vec.mlir @@ -0,0 +1,132 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops \ +// RUN: -convert-vector-to-scf \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0, d1) -> (d0 + d1)> + +module { + func.func private @rtclock() -> f64 + func.func private @printMemrefF32(memref<*xf32>) + func.func @pooling_nhwc_max(%arg0: memref, %arg1: memref, %arg2: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %vl_step = arith.constant 32 : index + %c0_f32 = arith.constant 0.000000e+00 : f32 + %0 = vector.splat %c0_f32 : vector<32xf32> + %dim = memref.dim %arg1, %c0 : memref + %dim_0 = memref.dim %arg1, %c1 : memref + %dim_1 = memref.dim %arg2, %c0 : memref + %dim_2 = memref.dim %arg2, %c1 : memref + %dim_3 = memref.dim %arg2, %c2 : memref + %dim_4 = memref.dim %arg2, %c3 : memref + + // Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + %dim_4_upbound_tmp = arith.subi %dim_4, %vl_step : index + %dim_4_upbound = arith.addi %dim_4_upbound_tmp, %c1 : index + + %t_start = call @rtclock() : () -> f64 + affine.for %arg3 = #map(%c0) to #map(%dim_1) { + affine.for %arg4 = #map(%c0) to #map(%dim_2) { + affine.for %arg5 = #map(%c0) to #map(%dim_3) { + // Perform the vectorization body. 
+ %iter_idx = scf.for %arg6 = %c0 to %dim_4_upbound + step %vl_step iter_args(%iter_init = %c0) -> (index) { // N + %4 = vector.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<32xf32> + %5 = affine.for %arg7 = #map(%c0) to #map(%dim) iter_args(%arg8 = %4) -> (vector<32xf32>) { + %6 = affine.for %arg9 = #map(%c0) to #map(%dim_0) iter_args(%arg10 = %arg8) -> (vector<32xf32>) { + %in_iter_h = affine.apply #map1 (%arg7, %arg4) + %in_iter_w = affine.apply #map1 (%arg9, %arg5) + %7 = vector.load %arg0[%arg3, %in_iter_h, %in_iter_w, %arg6] : memref, vector<32xf32> + %8 = arith.maximumf %7, %arg10 : vector<32xf32> + affine.yield %8 : vector<32xf32> + } + affine.yield %6 : vector<32xf32> + } + vector.store %5, %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<32xf32> + %iter_next = arith.addi %iter_init, %vl_step : index + scf.yield %iter_next : index + } + // %iter_idx is the number of channels already handled by full strips; + // process the remaining (tail) channels using masked vector operations. + %tail_size = arith.subi %dim_4, %iter_idx : index + %3 = arith.cmpi sgt, %tail_size, %c0 : index + scf.if %3 { + %mask = vector.create_mask %tail_size : vector<32xi1> + %5 = vector.maskedload %arg2[%arg3, %arg4, %arg5, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %6 = affine.for %arg7 = #map(%c0) to #map(%dim) iter_args(%arg8 = %5) -> (vector<32xf32>) { + %8 = arith.addi %arg4, %arg7 : index + %7 = affine.for %arg9 = #map(%c0) to #map(%dim_0) iter_args(%arg10 = %arg8) -> (vector<32xf32>) { + %9 = arith.addi %arg9, %arg5 : index + %10 = vector.maskedload %arg0[%arg3, %8, %9, %iter_idx], %mask, %0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> + %11 = arith.maximumf %10, %arg10 : vector<32xf32> + affine.yield %11 : vector<32xf32> + } + affine.yield %7 : vector<32xf32> + } + vector.maskedstore %arg2[%arg3, %arg4, %arg5, %iter_idx], %mask, %6 : memref, vector<32xi1>, vector<32xf32> + } + } + } + } + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end,
%t_start : f64 + %printed_output = memref.cast %arg2 : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. + vector.print %time : f64 + + return + } + + func.func @main(){ + // Set up dims. + %c1 = arith.constant 1 : index + %cInput = arith.constant 24 : index + %cKernel = arith.constant 2 : index + %cOutput = arith.constant 12 : index + %c6 = arith.constant 6 : index + + // Set Init Value. + %cf1_32 = arith.constant 1.0 : f32 + + %a = memref.alloc(%c1, %cInput, %cInput, %c6) : memref + %b = memref.alloc(%cKernel, %cKernel) : memref + %c = memref.alloc(%c1, %cOutput, %cOutput, %c6) : memref + + linalg.fill ins(%cf1_32 : f32) outs(%a : memref) + linalg.fill ins(%cf1_32 : f32) outs(%b : memref) + linalg.fill ins(%cf1_32 : f32) outs(%c : memref) + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [1{{(, 1)*}}], + call @pooling_nhwc_max(%a, %b, %c) : (memref, memref, memref) -> () + + memref.dealloc %c : memref + memref.dealloc %b : memref + memref.dealloc %a : memref + + return + } +} diff --git a/examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir b/examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir new file mode 100644 index 0000000000..3577c09272 --- /dev/null +++ b/examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir @@ -0,0 +1,97 @@ +// RUN: buddy-opt %s \ +// RUN: -pooling-nhwc-max-vectorization \ +// RUN: -convert-linalg-to-loops \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ 
+// RUN: | FileCheck %s + +module{ + func.func private @rtclock() -> f64 + func.func private @printMemrefF32(memref<*xf32>) + + func.func @pooling_nhwc_max(%a : memref, %b : memref, %c : memref) { + %t_start = call @rtclock() : () -> f64 + + linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} + ins(%a, %b : memref, memref) + outs(%c : memref) + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + %printed_output = memref.cast %c : memref to memref<*xf32> + call @printMemrefF32(%printed_output) : (memref<*xf32>) -> () + + // Print timings. + vector.print %time : f64 + + return + } + + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + scf.for %idx3 = %c0 to %arg3 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref + } + } + } + } + return %0 : memref + } + + func.func @alloc2_f32(%arg0: index, %arg1: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + memref.store %arg4, %0[%idx0, %idx1] : memref + } + } + return %0 : memref + } + + func.func @main(){ + // Set up dims. + %c1 = arith.constant 1 : index + %c24 = arith.constant 24 : index + %c2 = arith.constant 2 : index + %c12 = arith.constant 12 : index + %c6 = arith.constant 6 : index + + // Set Init Value. 
+ %f0 = arith.constant 0.000000e+00 : f32 + %f1 = arith.constant 1.000000e+00 : f32 + + %v0 = call @alloc_f32(%c1, %c24, %c24, %c6, %f1) : (index, index, index, index, f32) -> memref + %v1 = call @alloc2_f32(%c2, %c2, %f0) : (index, index, f32) -> memref + %v2 = call @alloc_f32(%c1, %c12, %c12, %c6, %f0) : (index, index, index, index, f32) -> memref + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [1{{(, 1)*}}], + call @pooling_nhwc_max(%v0, %v1, %v2) : (memref, memref, memref) -> () + + memref.dealloc %v0 : memref + memref.dealloc %v1 : memref + memref.dealloc %v2 : memref + + return + } +} diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index d9a37926f4..865a0b162c 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -497,3 +497,39 @@ linalg-matmul-vectorization-lower: @${BUDDY_OPT} linalg-matmul.mlir \ -matmul-vectorization \ -o log.mlir + +linalg-pooling-nhwc-max-run: + @${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main \ + -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-pooling-nhwc-max-vectorization-lower: + @${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \ + -pooling-nhwc-max-vectorization \ + -o log.mlir + +linalg-pooling-nhwc-max-vectorization-run: + @${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \ + -pooling-nhwc-max-vectorization \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} 
${OPT_FLAG} -e main \ + -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt index fce89520b6..d4cc3ec987 100644 --- a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(CBConvVectorization CBConvVectorization.cpp GEMMPointwiseConv2DNhwcHwcf.cpp PoolingVectorization.cpp + PoolingNhwcMaxVectorization.cpp LINK_LIBS PUBLIC BuddyUtils diff --git a/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp b/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp new file mode 100644 index 0000000000..280d11b226 --- /dev/null +++ b/midend/lib/Conversion/ConvVectorization/PoolingNhwcMaxVectorization.cpp @@ -0,0 +1,368 @@ +//===--------PoolingNhwcMaxVectorization.cpp-------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Pooling Nhwc Max Vectorization. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Utils/Utils.h" + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class PoolingNhwcMaxVectorizationPattern : public ConversionPattern { +public: + explicit PoolingNhwcMaxVectorizationPattern(MLIRContext *context, + int64_t stripParam) + : ConversionPattern(linalg::PoolingNhwcMaxOp::getOperationName(), 1, + context) { + strip = stripParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({strip}, i1); + + // Get input, kernel and output. + Value input = op->getOperand(0); + Value kernel = op->getOperand(1); + Value output = op->getOperand(2); + // Get strides. 
+ SmallVector strides = {1, 1}; + if (op->hasAttr("strides")) { + if (auto attr = + op->getAttrOfType("strides")) { + strides.clear(); + for (auto value : attr.getValues()) { + strides.push_back(value); + } + } + } + bool stride1 = strides[0] != 1; + bool stride2 = strides[1] != 1; + Value strHeight = rewriter.create(loc, strides[0]); + Value strWidth = rewriter.create(loc, strides[1]); + + // Get dilations. + SmallVector dilations = {1, 1}; + if (op->hasAttr("dilations")) { + if (auto attr = + op->getAttrOfType("dilations")) { + dilations.clear(); + for (auto value : attr.getValues()) { + dilations.push_back(value); + } + } + } + bool dilated1 = dilations[0] != 1; + bool dilated2 = dilations[1] != 1; + Value dilHeight = + rewriter.create(loc, dilations[0]); + Value dilWidth = rewriter.create(loc, dilations[1]); + + // Get ElementType of input. + Type elementTy = input.getType().cast().getElementType(); + VectorType vectorTy = mlir::VectorType::get({strip}, elementTy); + + // Get Constants. + const Value c0 = rewriter.create(loc, 0); + const Value c1 = rewriter.create(loc, 1); + const Value c2 = rewriter.create(loc, 2); + const Value c3 = rewriter.create(loc, 3); + const Value vlStep = rewriter.create(loc, strip); + const Value zero = + buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy); + + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); + + // Get Dimensions of Kernel. + Value kernelHeight = rewriter.create(loc, kernel, c0); + Value kernelWidth = rewriter.create(loc, kernel, c1); + + // Get Dimensions of Outputs. + Value batch = rewriter.create(loc, output, c0); + Value height = rewriter.create(loc, output, c1); + Value width = rewriter.create(loc, output, c2); + Value channels = rewriter.create(loc, output, c3); + + // Calculate the upper bound for vectorized processing + // - Subtracting `vlStep` keeps the last full strip within bounds.
+ // - Add 1 to ensure the final loop runs when the workload length + // is divisible by the vector size. + Value upperBoundTmp = rewriter.create(loc, channels, vlStep); + Value upperBound = rewriter.create(loc, upperBoundTmp, c1); + + SmallVector lowerBounds(3, c0); + SmallVector upperBounds{batch, height, width}; + SmallVector steps(3, /*Value=*/1); + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, upperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create strides variables. + Value tmpIvs1 = ivs[1]; + if (stride1) { + tmpIvs1 = builder.create(loc, ivs[1], strHeight); + } + Value tmpIvs2 = ivs[2]; + if (stride2) { + tmpIvs2 = builder.create(loc, ivs[2], strWidth); + } + // Create strip mining loop. + auto iterIdx = builder.create( + loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + Value outputVector = nestedBuilder.create( + loc, vectorTy, output, + ValueRange{ivs[0], ivs[1], ivs[2], iv}); + + auto tmp0 = nestedBuilder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelHeight}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{outputVector}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + // Create dilated[0] variables. + Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelWidth}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + // Create dilated[1] variables.
+ Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = builder.create(loc, iv1, + dilWidth); + } + Value inputWidth = builder.create( + loc, tmpIvs2, tmpIvs4); + Value inputVector = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, + iv}); + // Max + Value resultVector; + if (auto ty = + llvm::dyn_cast(elementTy)) { + resultVector = builder.create( + loc, inputVector, itrArgs1[0]); + } else { + resultVector = builder.create( + loc, inputVector, itrArgs1[0]); + } + builder.create(loc, + resultVector); + }); + nestedBuilder.create( + loc, tmp1.getResult(0)); + }); + builder.create( + loc, tmp0.getResult(0), output, + ValueRange{ivs[0], ivs[1], ivs[2], iv}); + Value idx = + builder.create(loc, itrArgs[0], vlStep); + builder.create(loc, idx); + }); + // Compute the tail size and process the remaining elements + // using masked vector operations. + Value idx = iterIdx.getResult(0); + Value tailSize = builder.create(loc, channels, idx); + Value tailCond = builder.create( + loc, arith::CmpIPredicate::sgt, tailSize, c0); + // Run the masked tail only when some channels remain. + builder.create< + scf::IfOp>(loc, tailCond, [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value tailMask = + builder.create(loc, vectorMaskTy, tailSize); + // Masked load output. + Value maskedOutputVec = builder.create( + loc, vectorTy, output, ValueRange{ivs[0], ivs[1], ivs[2], idx}, + tailMask, passThroughVec); + auto tmp0 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelHeight}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{maskedOutputVec}, + [&](OpBuilder &builder, Location loc, Value iv0, + ValueRange itrArgs0) { + // Create dilated[0] variables.
+ Value tmpIvs3 = iv0; + if (dilated1) { + tmpIvs3 = + builder.create(loc, iv0, dilHeight); + } + Value inputHeight = + builder.create(loc, tmpIvs1, tmpIvs3); + auto tmp1 = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{kernelWidth}, builder.getDimIdentityMap(), + /*Step=*/1, ValueRange{itrArgs0[0]}, + [&](OpBuilder &builder, Location loc, Value iv1, + ValueRange itrArgs1) { + // Calculate the input width index: strided output + // position plus the (dilated) kernel offset. + // Create dilated[1] variables. + Value tmpIvs4 = iv1; + if (dilated2) { + tmpIvs4 = + builder.create(loc, iv1, dilWidth); + } + Value inputWidth = + builder.create(loc, tmpIvs2, tmpIvs4); + // Masked load input. + Value maskedInputVec = builder.create( + loc, vectorTy, input, + ValueRange{ivs[0], inputHeight, inputWidth, idx}, + tailMask, passThroughVec); + // Max + Value resultVec; + if (auto ty = llvm::dyn_cast(elementTy)) { + resultVec = builder.create( + loc, maskedInputVec, itrArgs1[0]); + } else { + resultVec = builder.create( + loc, maskedInputVec, itrArgs1[0]); + } + builder.create(loc, resultVec); + }); + builder.create(loc, tmp1.getResult(0)); + }); + // Masked store the result to output. + builder.create( + loc, output, ValueRange{ivs[0], ivs[1], ivs[2], idx}, tailMask, + tmp0.getResult(0)); + builder.create(loc); + }); + }); + // Remove the original pooling operation. + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t strip; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// PoolingNhwcMaxVectorizationPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling max operations to mixture of +/// Arith + Vector operations.
+namespace { +class PoolingNhwcMaxVectorizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PoolingNhwcMaxVectorizationPass) + StringRef getArgument() const final { + return "pooling-nhwc-max-vectorization"; + } + StringRef getDescription() const final { + return "Pooling_Nhwc_Max vectorization."; + } + PoolingNhwcMaxVectorizationPass() = default; + PoolingNhwcMaxVectorizationPass(const PoolingNhwcMaxVectorizationPass &) {} + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + Option strip{*this, "vector-size", + llvm::cl::desc("Specify vector type size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void PoolingNhwcMaxVectorizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, strip); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerPoolingNhwcMaxVectorizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tests/Conversion/pooling-nhwc-max-vectorization.mlir b/tests/Conversion/pooling-nhwc-max-vectorization.mlir new file mode 100644 index 0000000000..0fa188eeb6 --- /dev/null +++ b/tests/Conversion/pooling-nhwc-max-vectorization.mlir @@ -0,0 +1,32 @@ +// RUN: buddy-opt -pooling-nhwc-max-vectorization %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim_5) { +// CHECK-NEXT: affine.for %arg4 = #map(%c0) to #map(%dim_6) { +// CHECK-NEXT: affine.for %arg5 = #map(%c0) to #map(%dim_7) { +// CHECK-NEXT: %3 = arith.muli %arg4, %c2 : index +// CHECK-NEXT: %4 = arith.muli %arg5, 
%c2_0 : index +// CHECK-NEXT: %5 = scf.for %arg6 = %c0 to %2 step %c16 iter_args(%arg7 = %c0) -> (index) { +// CHECK-NEXT: %8 = vector.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<16xf32> +// CHECK-NEXT: %9 = affine.for %arg8 = #map(%c0) to #map(%dim) iter_args(%arg9 = %8) -> (vector<16xf32>) { +// CHECK-NEXT: %11 = arith.addi %3, %arg8 : index +// CHECK-NEXT: %12 = affine.for %arg10 = #map(%c0) to #map(%dim_4) iter_args(%arg11 = %arg9) -> (vector<16xf32>) { +// CHECK-NEXT: %13 = arith.addi %4, %arg10 : index +// CHECK-NEXT: %14 = vector.load %arg0[%arg3, %11, %13, %arg6] : memref, vector<16xf32> +// CHECK-NEXT: %15 = arith.maximumf %14, %arg11 : vector<16xf32> +// CHECK-NEXT: affine.yield %15 : vector<16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: affine.yield %12 : vector<16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: vector.store %9, %arg2[%arg3, %arg4, %arg5, %arg6] : memref, vector<16xf32> +// CHECK-NEXT: %10 = arith.addi %arg7, %c16 : index +// CHECK-NEXT: scf.yield %10 : index +// CHECK-NEXT: } + +func.func @pooling_nhwc_max(%a : memref, %b : memref, %c : memref) { + linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} + ins(%a, %b : memref, memref) + outs(%c : memref) + return +} diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 08e172f8bc..2c7de877c6 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -55,6 +55,7 @@ void registerConvVectorizationPass(); void registerPointwiseConvToGemmPass(); void registerPointwiseConvToGemmForNhwcFhwcPass(); void registerPoolingVectorizationPass(); +void registerPoolingNhwcMaxVectorizationPass(); void registerLowerBudPass(); void registerLowerDIPPass(); void registerBatchMatMulOptimizePass(); @@ -94,6 +95,8 @@ int main(int argc, char **argv) { mlir::buddy::registerConvVectorizationPass(); // Register Vectorization of Pooling. 
mlir::buddy::registerPoolingVectorizationPass(); + // Register Vectorization of Pooling Nhwc Max. + mlir::buddy::registerPoolingNhwcMaxVectorizationPass(); mlir::buddy::registerLowerBudPass(); mlir::buddy::registerLowerDIPPass(); mlir::buddy::registerLowerDAPPass();