Commit 73368d5
fixed more lit tests
Signed-off-by: Boyana Norris <[email protected]>
brnorris03 committed Feb 23, 2025
1 parent c804965 commit 73368d5
Showing 4 changed files with 112 additions and 68 deletions.
4 changes: 2 additions & 2 deletions test/mlir/conversion/krnl_to_llvm/reshape.mlir
@@ -7,7 +7,7 @@ func.func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
 "func.return"(%0) : (tensor<*xf32>) -> ()

 // CHECK-LABEL: llvm.func @test_reshape
-// CHECK: [[OLD_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: [[OLD_MEMREF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: [[INSERT_1_:%.+]] = llvm.insertvalue {{.*}}, [[OLD_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: [[INSERT_2_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_1_]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: [[INSERT_3_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_2_]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -17,7 +17,7 @@ func.func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
 // CHECK-DAG:[[INSERT_7_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_6_]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

 // COM: Check that there is no copy but only a new MemRef with a new view, i.e. new sizes and strides.
-// CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+// CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
 // CHECK: [[INSERT_8_:%.+]] = llvm.insertvalue {{.*}}, [[NEW_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
 // CHECK-DAG: [[INSERT_9_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_8_]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
 // CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : index) : i64
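The only change in this file is the seed value for memref descriptor construction: llvm.mlir.undef is replaced by llvm.mlir.poison, which recent MLIR uses as the placeholder for uninitialized descriptor slots. A minimal sketch of the pattern these CHECK lines match, with illustrative SSA names (%base and the constant are hypothetical, not taken from the generated IR):

// Sketch only: populate a rank-2 memref descriptor starting from poison.
%d0 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%d1 = llvm.insertvalue %base, %d0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // allocated pointer
%d2 = llvm.insertvalue %base, %d1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // aligned pointer
%c0 = llvm.mlir.constant(0 : index) : i64
%d3 = llvm.insertvalue %c0, %d2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>   // offset

Only the seed op changes; the insertvalue chain that fills pointers, offset, sizes, and strides is untouched, which is why the surrounding CHECK lines stay the same. The COM line above also documents what the test verifies: the new rank-4 descriptor reuses the same pointers and only installs new sizes and strides, i.e. the reshape performs no data copy.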
22 changes: 14 additions & 8 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
@@ -5,7 +5,7 @@ func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
 "func.return"(%0) : (tensor<10x10xf32>) -> ()
 // CHECK-LABEL: func @test_relu
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<10x10xf32>) -> tensor<10x10xf32>
 // CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
 // CHECK-NEXT: }
 }
@@ -17,7 +17,7 @@ func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 "func.return"(%0) : (tensor<*xf32>) -> ()
 // CHECK-LABEL: func @test_relu_dynamic
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<?x10xf32>) -> tensor<?x10xf32>
 // CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32>
 // CHECK-NEXT: }
 }
@@ -60,7 +60,8 @@ func.func @test_add_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
 "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
 // CHECK-LABEL: func.func @test_add_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
 // CHECK: [[VAR_1_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
 // CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
 }
@@ -83,7 +84,8 @@ func.func @test_sub_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
 "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
 // CHECK-LABEL: func.func @test_sub_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
 // CHECK: [[VAR_1_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
 // CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
 }
@@ -106,7 +108,8 @@ func.func @test_div_broadcast(%arg0: tensor<13x21x1xi32>, %arg1: tensor<1xi32>) -> tensor<13x21x1xi32> {
 "func.return"(%0) : (tensor<13x21x1xi32>) -> ()
 // CHECK-LABEL: func @test_div_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<13x21x1xi32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xi32>, !tosa.shape<3>) -> tensor<1x1x1xi32>
 // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.int_div [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi32>, tensor<1x1x1xi32>) -> tensor<13x21x1xi32>
 }

@@ -118,7 +121,8 @@ func.func @test_div_decomposed(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
 // CHECK-LABEL: func @test_div_decomposed
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
 // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
-// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
 }

 // -----
@@ -129,6 +133,8 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> {
 // CHECK-LABEL: func @test_div_decomposed_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
 // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<1xf32>) -> tensor<1xf32>
-// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
-// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
+// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
 }
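The updated CHECK lines in this file track three op-signature changes in newer TOSA: tosa.clamp carries min_val/max_val attributes in place of the min_fp/max_fp and min_int/max_int pairs, tosa.reshape takes its target shape as a !tosa.shape operand produced by tosa.const_shape rather than a new_shape attribute, and tosa.mul takes its shift as an explicit i8 tensor operand rather than an attribute. A self-contained sketch combining the new forms as they appear in the tests (function and value names are illustrative only, not part of the test file):

// Sketch: the post-change TOSA op forms, on hypothetical inputs.
func.func @example(%x: tensor<13x21x1xf32>, %y: tensor<1xf32>) -> tensor<13x21x1xf32> {
  // Clamp with the renamed min/max attributes.
  %r = tosa.clamp %x {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
  // The target shape is now an SSA operand of type !tosa.shape<3>.
  %s = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
  %yr = tosa.reshape %y, %s : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
  // Mul takes the shift as a third operand (zero for floating point).
  %zero = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %m = tosa.mul %r, %yr, %zero : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
  return %m : tensor<13x21x1xf32>
}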
(Diffs for the remaining two changed files did not load in this view.)
