Skip to content

Commit

Permalink
Merge pull request #3 from FloatingcloudKnight/vector
Browse files Browse the repository at this point in the history
[midend/lib/Conversion/ConvVectorization] Add cond2dnhwcfhwc pass and poolingnhwcmax pass
  • Loading branch information
FloatingcloudKnight authored Oct 23, 2024
2 parents 3ff7b40 + 58aa900 commit ce0a6d9
Show file tree
Hide file tree
Showing 32 changed files with 4,017 additions and 93 deletions.
77 changes: 55 additions & 22 deletions examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,24 @@
// RUN: | FileCheck %s

// Using `8` as the vector size.
#map = affine_map<(d0) -> (d0 floordiv 8)>
#map = affine_map<(d0) -> (d0 ceildiv 16)>
#map0 = affine_map<(d0, d1, d2, d3) -> (d2)>
#map1 = affine_map<(d0, d1) -> (d0 + d1)>
#map2 = affine_map<(d0, d1) -> (d0 + d1 * 8)>
#map3 = affine_map<(d0) -> (d0 * 8)>
#map2 = affine_map<(d0, d1) -> (d0 + d1 * 16)>
#map3 = affine_map<(d0) -> (d0 * 16)>

module {
func.func private @printMemrefF32(memref<*xf32>)
func.func private @rtclock() -> f64

func.func @conv_2d_nhwc_fhwc(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
%f0 = arith.constant 0. : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32 = arith.constant 16 : index
%f0 = arith.constant 0.000000e+00 : f32
%0 = vector.splat %f0 : vector<16xf32>
%n = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
%h_i = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
%w_i = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
Expand All @@ -45,22 +47,53 @@ module {
affine.for %idx_f = %c0 to %f {
affine.for %idx_c = %c0 to %c {
affine.for %idx_h_o = %c0 to %h_o {
affine.for %idx_h_k = %c0 to %h_k {
affine.for %idx_w_k = %c0 to %w_k {
affine.for %idx_w_o = %c0 to #map(%w_o) {
%kernel_ele = memref.load %arg1[%idx_f, %idx_h_k, %idx_w_k, %idx_c] : memref<?x?x?x?xf32>
%kernel_vec = vector.broadcast %kernel_ele : f32 to vector<8xf32>
%in_iter_h = affine.apply #map1 (%idx_h_k, %idx_h_o)
%in_iter_w = affine.apply #map2 (%idx_w_k, %idx_w_o)
%out_iter_w = affine.apply #map3 (%idx_w_o)
%input_vec = vector.transfer_read %arg0[%idx_n, %in_iter_h, %in_iter_w, %idx_c], %f0
{ permutation_map = #map0 } : memref<?x?x?x?xf32>, vector<8xf32>
%output_vec = vector.transfer_read %arg2[%idx_n, %idx_h_o, %out_iter_w, %idx_f], %f0
{ permutation_map = #map0 } : memref<?x?x?x?xf32>, vector<8xf32>
%res_vec = vector.fma %kernel_vec, %input_vec, %output_vec : vector<8xf32>
vector.transfer_write %res_vec, %arg2[%idx_n, %idx_h_o, %out_iter_w, %idx_f]
{ permutation_map = #map0 } : vector<8xf32>, memref<?x?x?x?xf32>
affine.for %idx_w_o = %c0 to #map(%w_o) {
%1 = arith.muli %idx_w_o, %c32 : index
%2 = arith.subi %w_o, %1 : index
%3 = arith.cmpi sge, %2, %c32 : index
scf.if %3 {
// %arg2[%n, %h_o, %w_o*16, %f]
%output_vec = vector.transfer_read %arg2[%idx_n, %idx_h_o, %1, %idx_f], %f0
{ permutation_map = #map0 } : memref<?x?x?x?xf32>, vector<16xf32>
%5 = affine.for %idx_h_k = %c0 to %h_k iter_args(%arg8 = %output_vec) -> (vector<16xf32>) { // %h_k
%6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%arg10 = %arg8) -> (vector<16xf32>) { // %w_k
// %arg1[%f, %h_k, %w_k, %c]
%kernel_ele = memref.load %arg1[%idx_f, %idx_h_k, %idx_w_k, %idx_c] : memref<?x?x?x?xf32>
%kernel_vec = vector.broadcast %kernel_ele : f32 to vector<16xf32>
%in_iter_h = affine.apply #map1 (%idx_h_k, %idx_h_o)
%in_iter_w = affine.apply #map2 (%idx_w_k, %idx_w_o)
// %arg0[%n, %h_k+%h_o, %w_k+%w_o*16, %c]
%input_vec = vector.transfer_read %arg0[%idx_n, %in_iter_h, %in_iter_w, %idx_c], %f0
{ permutation_map = #map0 } : memref<?x?x?x?xf32>, vector<16xf32>
%res_vec = vector.fma %kernel_vec, %input_vec, %arg10 : vector<16xf32>
affine.yield %res_vec : vector<16xf32>
}
affine.yield %6 : vector<16xf32>
}
vector.transfer_write %5, %arg2[%idx_n, %idx_h_o, %1, %idx_f]
{ permutation_map = #map0 } : vector<16xf32>, memref<?x?x?x?xf32>
} else {
%9 = vector.create_mask %2 : vector<16xi1>
// %arg2[%n, %h_o, %w_o*16, %f]
%output_vec = vector.transfer_read %arg2[%idx_n, %idx_h_o, %1, %idx_f], %f0, %9
{ permutation_map = #map0 } : memref<?x?x?x?xf32>, vector<16xf32>
%5 = affine.for %idx_h_k = %c0 to %h_k iter_args(%arg8 = %output_vec) -> (vector<16xf32>) { // %h_k
%6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%arg10 = %arg8) -> (vector<16xf32>) { // %w_k
// %arg1[%f, %h_k, %w_k, %c]
%kernel_ele = memref.load %arg1[%idx_f, %idx_h_k, %idx_w_k, %idx_c] : memref<?x?x?x?xf32>
%kernel_vec = vector.broadcast %kernel_ele : f32 to vector<16xf32>
%in_iter_h = affine.apply #map1 (%idx_h_k, %idx_h_o)
%in_iter_w = affine.apply #map2 (%idx_w_k, %idx_w_o)
// %arg0[%n, %h_k+%h_o, %w_k+%w_o*16, %c]
%input_vec = vector.transfer_read %arg0[%idx_n, %in_iter_h, %in_iter_w, %idx_c], %f0, %9
{ permutation_map = #map0 } : memref<?x?x?x?xf32>, vector<16xf32>
%res_vec = vector.fma %kernel_vec, %input_vec, %arg10 : vector<16xf32>
affine.yield %res_vec : vector<16xf32>
}
affine.yield %6 : vector<16xf32>
}
vector.transfer_write %5, %arg2[%idx_n, %idx_h_o, %1, %idx_f], %9
{ permutation_map = #map0 } : vector<16xf32>, memref<?x?x?x?xf32>
}
}
}
Expand Down Expand Up @@ -107,8 +140,8 @@ module {
// %v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
// %v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v0 = call @alloc_f32(%c1, %c28, %c28, %c5, %f2) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v1 = call @alloc_f32(%c6, %c5, %c5, %c5, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%t_start = call @rtclock() : () -> f64
Expand All @@ -121,7 +154,7 @@ module {
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [150{{(, 150)*}}],
// CHECK: [750{{(, 750)*}}],
%print_v2 = memref.cast %v2 : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_v2) : (memref<*xf32>) -> ()

Expand Down
129 changes: 129 additions & 0 deletions examples/BuddyConvolution/conv2d-nhwc-fhwc-vec.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// RUN: buddy-opt %s \
// RUN: -convert-vector-to-scf \
// RUN: -lower-affine \
// RUN: -arith-bufferize \
// RUN: -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -O3 -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

// Using `8` as the vector size.
#map = affine_map<(d0) -> (d0 floordiv 1)>
#map0 = affine_map<(d0, d1, d2, d3) -> (d2)>
#map1 = affine_map<(d0, d1) -> (d0 + d1)>
module {
func.func private @printMemrefF32(memref<*xf32>)
func.func private @rtclock() -> f64

func.func @conv_2d_nhwc_fhwc(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
%f0 = arith.constant 0. : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%vec1 = vector.splat %f0 : vector<16xf32>
%n = memref.dim %arg0, %c0 : memref<?x?x?x?xf32>
%h_i = memref.dim %arg0, %c1 : memref<?x?x?x?xf32>
%w_i = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
%c = memref.dim %arg0, %c3 : memref<?x?x?x?xf32>
%f = memref.dim %arg1, %c0 : memref<?x?x?x?xf32>
%h_k = memref.dim %arg1, %c1 : memref<?x?x?x?xf32>
%w_k = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
%h_o = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
%w_o = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>

// Output is NHoWoF
affine.for %idx_n = %c0 to %n {
affine.for %idx_h_o = %c0 to %h_o {
affine.for %idx_w_o = %c0 to %w_o {
%tmp8 = affine.for %idx_h_k = %c0 to %h_k iter_args(%tmp9 = %vec1) -> (vector<16xf32>) {
%tmp6 = affine.for %idx_w_k = %c0 to %w_k iter_args(%tmp7 = %vec1) -> (vector<16xf32>) {
%in_iter_h = affine.apply #map1 (%idx_h_k, %idx_h_o)
%in_iter_w = affine.apply #map1 (%idx_w_k, %idx_w_o)
%input_vec = affine.vector_load %arg0[%idx_n, %in_iter_h, %in_iter_w, %c0] : memref<?x?x?x?xf32>, vector<6xf32>
%tmp0 = affine.for %idx_f = %c0 to %f iter_args(%tmp1 = %vec1) -> (vector<16xf32>) {
%kernel_vec = affine.vector_load %arg1[%idx_f, %idx_h_k, %idx_w_k, %c0] : memref<?x?x?x?xf32>, vector<6xf32>
%tmp_vec0 = arith.mulf %kernel_vec, %input_vec : vector<6xf32>
%tmp_val = vector.reduction <add>, %tmp_vec0 : vector<6xf32> into f32
%tmp4 = vector.insert %tmp_val, %tmp1[%idx_f] : f32 into vector<16xf32>
affine.yield %tmp4 : vector<16xf32>
}
%tmp5 = arith.addf %tmp7, %tmp0 : vector<16xf32>
affine.yield %tmp5 : vector<16xf32>
}
%tmp5 = arith.addf %tmp9, %tmp6 : vector<16xf32>
affine.yield %tmp5 : vector<16xf32>
}
affine.vector_store %tmp8, %arg2[%idx_n, %idx_h_o, %idx_w_o, %c0] : memref<?x?x?x?xf32>, vector<16xf32>
}
}
}
return
}

func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
scf.for %idx0 = %c0 to %arg0 step %c1 {
scf.for %idx1 = %c0 to %arg1 step %c1 {
scf.for %idx2 = %c0 to %arg2 step %c1 {
scf.for %idx3 = %c0 to %arg3 step %c1 {
memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref<?x?x?x?xf32>
}
}
}
}
return %0 : memref<?x?x?x?xf32>
}

func.func @main() {
%f0 = arith.constant 0.000000e+00 : f32
%f2 = arith.constant 2.000000e+00 : f32
%f3 = arith.constant 3.000000e+00 : f32

%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c5 = arith.constant 5 : index
%c6 = arith.constant 6 : index
%c8 = arith.constant 8 : index
%c12 = arith.constant 12 : index
%c16 = arith.constant 16 : index
%c24 = arith.constant 24 : index
%c28 = arith.constant 28 : index

%v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%t_start = call @rtclock() : () -> f64
call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) -> ()
%t_end = call @rtclock() : () -> f64

// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [900{{(, 900)*}}],
%print_v2 = memref.cast %v2 : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_v2) : (memref<*xf32>) -> ()

%time = arith.subf %t_end, %t_start : f64
vector.print %time : f64

memref.dealloc %v0 : memref<?x?x?x?xf32>
memref.dealloc %v1 : memref<?x?x?x?xf32>
memref.dealloc %v2 : memref<?x?x?x?xf32>

return
}
}
27 changes: 14 additions & 13 deletions examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ module {
func.func private @printMemrefF32(memref<*xf32>)
func.func private @rtclock() -> f64

func.func @conv_2d_nhwc_fhwc(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
linalg.conv_2d_nhwc_fhwc ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
outs (%arg2: memref<?x?x?x?xf32>)
func.func @conv_2d_nhwc_fhwc(%arg0: memref<1x12x12x6xf32>, %arg1: memref<16x5x5x6xf32>, %arg2: memref<1x8x8x16xf32>) {
linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
ins (%arg0, %arg1: memref<1x12x12x6xf32>, memref<16x5x5x6xf32>)
outs (%arg2: memref<1x8x8x16xf32>)
return
}

Expand Down Expand Up @@ -54,17 +55,17 @@ module {
%c16 = arith.constant 16 : index
%c24 = arith.constant 24 : index
%c28 = arith.constant 28 : index

// %v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
// %v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
// %v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%a = memref.cast %v0 : memref<?x?x?x?xf32> to memref<1x12x12x6xf32>
%b = memref.cast %v1 : memref<?x?x?x?xf32> to memref<16x5x5x6xf32>
%c = memref.cast %v2 : memref<?x?x?x?xf32> to memref<1x8x8x16xf32>

%t_start = call @rtclock() : () -> f64
call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) -> ()
call @conv_2d_nhwc_fhwc(%a, %b, %c) : (memref<1x12x12x6xf32>, memref<16x5x5x6xf32>, memref<1x8x8x16xf32>) -> ()
%t_end = call @rtclock() : () -> f64

// All the elements of the MemRef are the same,
Expand All @@ -73,8 +74,8 @@ module {
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [150{{(, 150)*}}],
%print_v2 = memref.cast %v2 : memref<?x?x?x?xf32> to memref<*xf32>
// CHECK: [900{{(, 900)*}}],
%print_v2 = memref.cast %c : memref<1x8x8x16xf32> to memref<*xf32>
call @printMemrefF32(%print_v2) : (memref<*xf32>) -> ()

%time = arith.subf %t_end, %t_start : f64
Expand Down
31 changes: 31 additions & 0 deletions examples/BuddyConvolution/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,34 @@ conv2d-nhwc-fhwc-opt-aot:
-L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \
-o a.out
@LD_LIBRARY_PATH=${MLIR_LIB} ./a.out

conv2d-nhwc-fhwc-vec-run:
@${BUDDY_OPT} ./conv2d-nhwc-fhwc-vec.mlir \
-convert-vector-to-scf \
-lower-affine \
-arith-bufferize \
-convert-scf-to-cf \
-convert-vector-to-llvm \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

conv2d-nhwc-fhwc-vec-aot:
@${BUDDY_OPT} ./conv2d-nhwc-fhwc-vec.mlir \
-convert-vector-to-scf \
-lower-affine \
-arith-bufferize \
-convert-scf-to-cf \
-convert-vector-to-llvm \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll
${CLANG} log.ll -O3 \
-L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \
-o a.out
@LD_LIBRARY_PATH=${MLIR_LIB} ./a.out
3 changes: 2 additions & 1 deletion examples/BuddyLeNet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ add_custom_command(
-linalg-bufferize
-batchmatmul-optimize
-convert-linalg-to-affine-loops
-conv2d-nhwc-fhwc-vectorization
-pooling-nhwc-max-vectorization
-lower-affine
-func-bufferize-dynamic-offset
-arith-bufferize
Expand All @@ -51,7 +53,6 @@ add_custom_command(
VERBATIM)

add_library(LENET STATIC subgraph0.o forward.o)

SET_TARGET_PROPERTIES(LENET PROPERTIES LINKER_LANGUAGE C)

add_executable(buddy-lenet-run buddy-lenet-main.cpp)
Expand Down
Loading

0 comments on commit ce0a6d9

Please sign in to comment.