Skip to content

Commit

Permalink
[examples] Add iteration patterns for vector dialect.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Oct 12, 2024
1 parent 9a7187c commit 8da4b2c
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions examples/MLIRVector/vector-iteration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ memref.global "private" @gv : memref<4x4xf32> = dense<[[0. , 1. , 2. , 3. ],
[20., 21., 22., 23.],
[30., 31., 32., 33.]]>

// Input buffer for Iteration Pattern 1 (fetched in @main with
// memref.get_global). Its length (10) is not a multiple of the vector step
// (4) used there, so the remainder path is exercised.
memref.global "private" @gv_pat_1 : memref<10xf32> = dense<[0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9.]>
// Input buffer for Iteration Pattern 2; same contents as @gv_pat_1 so both
// patterns start from identical data.
memref.global "private" @gv_pat_2 : memref<10xf32> = dense<[0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9.]>

// External declaration (no body in this file); @main calls it to print an
// unranked f32 memref.
// NOTE(review): presumably resolved from the MLIR runner utility library at
// link/JIT time -- confirm against the build rules for this example.
func.func private @printMemrefF32(memref<*xf32>)

func.func @main() -> i32 {
%mem = memref.get_global @gv : memref<4x4xf32>
%c0 = arith.constant 0 : index
Expand All @@ -27,6 +32,101 @@ func.func @main() -> i32 {
}
// CHECK: ( 0, 33, 72, 117 )
vector.print %sum : vector<4xf32>

// ---------------------------------------------------------------------------
// Iteration Pattern 1
// Main Vector Loop + Scalar Remainder + Fixed Vector Type
// ---------------------------------------------------------------------------

// 1. Get the total length of the workload.
%mem_pat_1 = memref.get_global @gv_pat_1 : memref<10xf32>
%print_mem_pat_1 = memref.cast %mem_pat_1 : memref<10xf32> to memref<*xf32>
%vl_total_pat_1 = memref.dim %mem_pat_1, %c0 : memref<10xf32>

// 2. Set the iteration step (vector size).
%vl_step_pat_1 = arith.constant 4 : index

// 3. Calculate the upper bound for vectorized processing.
//    - Subtracting `vl_step` keeps the last full-vector access inside the
//      buffer (no out-of-bounds access at the vectorization tail).
//    - Adding 1 ensures the final full vector is still processed when the
//      workload length is divisible by the vector size.
%vl_upbound_pat_1_ = arith.subi %vl_total_pat_1, %vl_step_pat_1 : index
%vl_upbound_pat_1 = arith.addi %vl_upbound_pat_1_, %c1 : index

// 4. Perform the vectorization body.
//    The loop carries the index of the first element that has NOT been
//    processed yet. It starts at 0, so if the loop body never runs (workload
//    shorter than one vector) the scalar remainder below still starts at 0
//    instead of skipping the first `vl_step` elements.
%tail_idx_pat_1 = scf.for %i = %c0 to %vl_upbound_pat_1 step %vl_step_pat_1
    iter_args(%iter_init = %c0) -> (index) {
  %load_vec1 = vector.load %mem_pat_1[%i] : memref<10xf32>, vector<4xf32>
  %load_vec2 = vector.load %mem_pat_1[%i] : memref<10xf32>, vector<4xf32>
  %res = arith.addf %load_vec1, %load_vec2 : vector<4xf32>
  vector.store %res, %mem_pat_1[%i] : memref<10xf32>, vector<4xf32>
  // Yield the start of the next chunk; after the last iteration this is the
  // position where tail processing begins.
  %next_idx_pat_1 = arith.addi %i, %vl_step_pat_1 : index
  scf.yield %next_idx_pat_1 : index
}
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9]
call @printMemrefF32(%print_mem_pat_1) : (memref<*xf32>) -> ()

// 5. Process the remainder of the elements with scalar operations, starting
//    at the index carried out of the vector loop.
scf.for %i = %tail_idx_pat_1 to %vl_total_pat_1 step %c1 {
  %ele1 = memref.load %mem_pat_1[%i] : memref<10xf32>
  %ele2 = memref.load %mem_pat_1[%i] : memref<10xf32>
  %res = arith.addf %ele1, %ele2 : f32
  memref.store %res, %mem_pat_1[%i] : memref<10xf32>
}
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
call @printMemrefF32(%print_mem_pat_1) : (memref<*xf32>) -> ()

// ---------------------------------------------------------------------------
// Iteration Pattern 2
// Main Vector Loop + Masked Vector Remainder + Fixed Vector Type
// ---------------------------------------------------------------------------

// 1. Get the total length of the workload.
%mem_pat_2 = memref.get_global @gv_pat_2 : memref<10xf32>
%print_mem_pat_2 = memref.cast %mem_pat_2 : memref<10xf32> to memref<*xf32>
%vl_total_pat_2 = memref.dim %mem_pat_2, %c0 : memref<10xf32>

// 2. Set the iteration step (vector size).
%vl_step_pat_2 = arith.constant 4 : index

// 3. Calculate the upper bound for vectorized processing.
//    - Subtracting `vl_step` keeps the last full-vector access inside the
//      buffer (no out-of-bounds access at the vectorization tail).
//    - Adding 1 ensures the final full vector is still processed when the
//      workload length is divisible by the vector size.
%vl_upbound_pat_2_ = arith.subi %vl_total_pat_2, %vl_step_pat_2 : index
%vl_upbound_pat_2 = arith.addi %vl_upbound_pat_2_, %c1 : index

// 4. Perform the vectorization body.
//    The loop carries the index of the first element that has NOT been
//    processed yet. It starts at 0, so the tail position is also correct when
//    the loop body never runs (workload shorter than one vector).
%tail_idx_pat_2 = scf.for %i = %c0 to %vl_upbound_pat_2 step %vl_step_pat_2
    iter_args(%iter_init = %c0) -> (index) {
  %load_vec1 = vector.load %mem_pat_2[%i] : memref<10xf32>, vector<4xf32>
  %load_vec2 = vector.load %mem_pat_2[%i] : memref<10xf32>, vector<4xf32>
  %res = arith.addf %load_vec1, %load_vec2 : vector<4xf32>
  vector.store %res, %mem_pat_2[%i] : memref<10xf32>, vector<4xf32>
  // Yield the start of the next chunk; after the last iteration this is the
  // position where tail processing begins.
  %next_idx_pat_2 = arith.addi %i, %vl_step_pat_2 : index
  scf.yield %next_idx_pat_2 : index
}
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9]
call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> ()

// 5. Compute the tail size and create the mask and pass-through vector for
//    the remaining elements.
//    The tail size must be measured from the tail start index (total - tail
//    index = 10 - 8 = 2 here). Measuring it from the last vector-loop index
//    would give 6, enable all four mask lanes, and make the masked accesses
//    below touch elements 8..11 of a 10-element buffer.
%tail_size_pat_2 = arith.subi %vl_total_pat_2, %tail_idx_pat_2 : index
%mask_pat_2 = vector.create_mask %tail_size_pat_2 : vector<4xi1>
%pass_thr_vec = arith.constant dense<0.> : vector<4xf32>

// 6. Process the remaining elements using masked vector operations.
//    Lanes that are masked off take their value from the pass-through vector
//    on load and are not written on store.
%ele1 = vector.maskedload %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%ele2 = vector.maskedload %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%res = arith.addf %ele1, %ele2 : vector<4xf32>
vector.maskedstore %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %res : memref<10xf32>, vector<4xi1>, vector<4xf32>
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> ()

%ret = arith.constant 0 : i32
return %ret : i32
}

0 comments on commit 8da4b2c

Please sign in to comment.