Skip to content

Commit

Permalink
examples/BuddyConvolution/:
Browse files Browse the repository at this point in the history
linalg-pooling-nhwc-max.mlir: Linalg Dialect handwritten mlir;
conv2d-nhwc-fhwc-vec.mlir: Vector Dialect handwritten mlir file;
makefile update related commands.

midend/lib/Conversion/ConvVectorizationn/: PoolingNhwcMaxVectorization.cpp implements vectorisation.

tests/Conversion/:
pooling-nhwc-max-vectorisation.mlir: a test file.

Before vectorisation optimization: 1.71661e-05, after vectorization optimization: 1.09673e-05 (vectorization size is 32), speedup ratio is 0.454. The amount of data in the example is small and the optimization effect is not obvious.
  • Loading branch information
FloatingcloudKnight committed Dec 25, 2024
1 parent 899bba4 commit 330b2bb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
17 changes: 10 additions & 7 deletions examples/BuddyNext/pooling-nhwc-max-vec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module {
%dim_4_upbound_tmp = arith.subi %dim_4, %vl_step : index
%dim_4_upbound = arith.addi %dim_4_upbound_tmp, %c1 : index

%t_start = call @rtclock() : () -> f64
affine.for %arg3 = #map(%c0) to #map(%dim_1) {
affine.for %arg4 = #map(%c0) to #map(%dim_2) {
affine.for %arg5 = #map(%c0) to #map(%dim_3) {
Expand Down Expand Up @@ -83,6 +84,14 @@ module {
}
}
}
%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64
%printed_output = memref.cast %arg2 : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()

// Print timings.
vector.print %time : f64

return
}

Expand All @@ -105,20 +114,14 @@ module {
linalg.fill ins(%cf1_32 : f32) outs(%b : memref<?x?xf32>)
linalg.fill ins(%cf1_32 : f32) outs(%c : memref<?x?x?x?xf32>)

%t0 = call @rtclock() : () -> f64
call @pooling_nhwc_max(%a, %b, %c) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()
%t1 = call @rtclock() : () -> f64
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [1{{(, 1)*}}],
%print_C = memref.cast %c : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_C) : (memref<*xf32>) -> ()
%time = arith.subf %t1, %t0 : f64
vector.print %time : f64
call @pooling_nhwc_max(%a, %b, %c) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()

memref.dealloc %c : memref<?x?x?x?xf32>
memref.dealloc %b : memref<?x?xf32>
Expand Down
21 changes: 13 additions & 8 deletions examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,20 @@ module{
func.func private @printMemrefF32(memref<*xf32>)

func.func @pooling_nhwc_max(%a : memref<?x?x?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?x?x?xf32>) {
%t_start = call @rtclock() : () -> f64

linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%a, %b : memref<?x?x?x?xf32>, memref<?x?xf32>)
outs(%c : memref<?x?x?x?xf32>)

%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64
%printed_output = memref.cast %c : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()

// Print timings.
vector.print %time : f64

return
}

Expand Down Expand Up @@ -68,21 +79,15 @@ module{
%v1 = call @alloc2_f32(%c2, %c2, %f0) : (index, index, f32) -> memref<?x?xf32>
%v2 = call @alloc_f32(%c1, %c12, %c12, %c6, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%t0 = call @rtclock() : () -> f64
call @pooling_nhwc_max(%v0, %v1, %v2) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()
%t1 = call @rtclock() : () -> f64
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [1{{(, 1)*}}],
%print_C = memref.cast %v2 : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_C) : (memref<*xf32>) -> ()
%time = arith.subf %t1, %t0 : f64
vector.print %time : f64

call @pooling_nhwc_max(%v0, %v1, %v2) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()

memref.dealloc %v0 : memref<?x?x?x?xf32>
memref.dealloc %v1 : memref<?x?xf32>
memref.dealloc %v2 : memref<?x?x?x?xf32>
Expand Down

0 comments on commit 330b2bb

Please sign in to comment.