Skip to content

Commit

Permalink
[examples] Update vectorization iteration pattern.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Oct 18, 2024
1 parent ec68604 commit e154c3a
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions examples/MLIRVector/vector-iteration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ func.func @main() -> i32 {
%load_vec2 = vector.load %mem_pat_1[%i] : memref<10xf32>, vector<4xf32>
%res = arith.addf %load_vec1, %load_vec2 : vector<4xf32>
vector.store %res, %mem_pat_1[%i] : memref<10xf32>, vector<4xf32>
scf.yield %i : index
%i_next = arith.addi %i, %vl_step_pat_1 : index
scf.yield %i_next : index
}
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9]
call @printMemrefF32(%print_mem_pat_1) : (memref<*xf32>) -> ()

// 5. Calculate the position for tail processing.
%tail_idx_pat_1 = arith.addi %iter_idx_pat_1, %vl_step_pat_1 : index

// 6. Process the remainder of the elements with scalar operations.
scf.for %i = %tail_idx_pat_1 to %vl_total_pat_1 step %c1 {
// 5. Process the remainder of the elements with scalar operations.
scf.for %i = %iter_idx_pat_1 to %vl_total_pat_1 step %c1 {
%ele1 = memref.load %mem_pat_1[%i] : memref<10xf32>
%ele2 = memref.load %mem_pat_1[%i] : memref<10xf32>
%res = arith.addf %ele1, %ele2 : f32
Expand Down Expand Up @@ -105,25 +103,23 @@ func.func @main() -> i32 {
%load_vec2 = vector.load %mem_pat_2[%i] : memref<10xf32>, vector<4xf32>
%res = arith.addf %load_vec1, %load_vec2 : vector<4xf32>
vector.store %res, %mem_pat_2[%i] : memref<10xf32>, vector<4xf32>
scf.yield %i : index
%i_next = arith.addi %i, %vl_step_pat_1 : index
scf.yield %i_next : index
}
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9]
call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> ()

// 5. Calculate the position for tail processing.
%tail_idx_pat_2 = arith.addi %iter_idx_pat_2, %vl_step_pat_2 : index

// 6. Compute the tail size and create mask and pass-through vector for the
// 5. Compute the tail size and create mask and pass-through vector for the
// remaining elements.
%tail_size_pat_2 = arith.subi %vl_total_pat_2, %iter_idx_pat_2 :index
%tail_size_pat_2 = arith.subi %vl_total_pat_2, %iter_idx_pat_2 : index
%mask_pat_2 = vector.create_mask %tail_size_pat_2 : vector<4xi1>
%pass_thr_vec = arith.constant dense<0.> : vector<4xf32>

// 7. Process the remaining elements using masked vector operations.
%ele1 = vector.maskedload %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%ele2 = vector.maskedload %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// 6. Process the remaining elements using masked vector operations.
%ele1 = vector.maskedload %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%ele2 = vector.maskedload %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%res = arith.addf %ele1, %ele2 : vector<4xf32>
vector.maskedstore %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %res : memref<10xf32>, vector<4xi1>, vector<4xf32>
vector.maskedstore %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %res : memref<10xf32>, vector<4xi1>, vector<4xf32>
// CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> ()

Expand Down

0 comments on commit e154c3a

Please sign in to comment.