Skip to content

Commit

Permalink
[midend/lib/Conversion/ConvVectorization] fix poolingnhwcmax vectoriz…
Browse files Browse the repository at this point in the history
…ation pass and examples about attributes.
  • Loading branch information
FloatingcloudKnight committed Dec 19, 2024
1 parent 7838000 commit 899bba4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
14 changes: 5 additions & 9 deletions examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ module{
func.func private @rtclock() -> f64
func.func private @printMemrefF32(memref<*xf32>)

func.func @pooling_nhwc_max(%a : memref<1x24x24x6xf32>, %b : memref<2x2xf32>, %c : memref<1x12x12x6xf32>) {
func.func @pooling_nhwc_max(%a : memref<?x?x?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?x?x?xf32>) {
linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%a, %b : memref<1x24x24x6xf32>, memref<2x2xf32>)
outs(%c : memref<1x12x12x6xf32>)
ins(%a, %b : memref<?x?x?x?xf32>, memref<?x?xf32>)
outs(%c : memref<?x?x?x?xf32>)
return
}

Expand Down Expand Up @@ -68,12 +68,8 @@ module{
%v1 = call @alloc2_f32(%c2, %c2, %f0) : (index, index, f32) -> memref<?x?xf32>
%v2 = call @alloc_f32(%c1, %c12, %c12, %c6, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%a = memref.cast %v0 : memref<?x?x?x?xf32> to memref<1x24x24x6xf32>
%b = memref.cast %v1 : memref<?x?xf32> to memref<2x2xf32>
%c = memref.cast %v2 : memref<?x?x?x?xf32> to memref<1x12x12x6xf32>

%t0 = call @rtclock() : () -> f64
call @pooling_nhwc_max(%a, %b, %c) : (memref<1x24x24x6xf32>, memref<2x2xf32>, memref<1x12x12x6xf32>) -> ()
call @pooling_nhwc_max(%v0, %v1, %v2) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()
%t1 = call @rtclock() : () -> f64
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
Expand All @@ -82,7 +78,7 @@ module{
// CHECK: [
// CHECK: [
// CHECK: [1{{(, 1)*}}],
%print_C = memref.cast %c : memref<1x12x12x6xf32> to memref<*xf32>
%print_C = memref.cast %v2 : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_C) : (memref<*xf32>) -> ()
%time = arith.subf %t1, %t0 : f64
vector.print %time : f64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,32 @@ class PoolingNhwcMaxVectorizationPattern : public ConversionPattern {
Value input = op->getOperand(0);
Value kernel = op->getOperand(1);
Value output = op->getOperand(2);
auto strides = op->getAttrOfType<mlir::DenseIntElementsAttr>("strides")
.getValues<int64_t>();
// Get strides.
SmallVector<int64_t, 2> strides = {1, 1};
if (op->hasAttr("strides")) {
// 获取 "strides" 属性
if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>("strides")) {
strides.clear(); // 清空默认值
for (auto value : attr.getValues<int64_t>()) {
strides.push_back(value);
}
}
}
bool stride1 = strides[0] != 1;
bool stride2 = strides[1] != 1;
Value strHeight = rewriter.create<arith::ConstantIndexOp>(loc, strides[0]);
Value strWidth = rewriter.create<arith::ConstantIndexOp>(loc, strides[1]);

// Dilations.
auto dilations = op->getAttrOfType<mlir::DenseIntElementsAttr>("dilations")
.getValues<int64_t>();
// // Get dilations.
SmallVector<int64_t, 2> dilations = {1, 1};
if (op->hasAttr("dilations")) {
if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>("dilations")) {
dilations.clear();
for (auto value : attr.getValues<int64_t>()) {
dilations.push_back(value);
}
}
}
bool dilated1 = dilations[0] != 1;
bool dilated2 = dilations[1] != 1;
Value dilHeight = rewriter.create<arith::ConstantIndexOp>(loc, dilations[0]);
Expand Down

0 comments on commit 899bba4

Please sign in to comment.