[tf-frontend] add new pattern for layer_norm and l2_norm #317

Closed · wants to merge 5 commits
2 changes: 2 additions & 0 deletions frontends/tf-frontend/tf_mlir_ext/numerical/numerical_test.py
@@ -64,6 +64,7 @@ def get_config(config: str):
"layer_norm_without_gamma": 6,
"layer_norm_without_beta": 6,
"layer_norm_multi_dim": 6,
"layer_norm_multi_dim_v2": 6,
"layer_norm_swap_add": 6,
"layer_norm_swap_mul": 6,
"layer_norm_swap_squarediff": 5,
@@ -80,6 +81,7 @@ def get_config(config: str):
"l2_norm_V2": 3,
"l2_norm_V2_swap_mul": 3,
"l2_norm_V3": 6,
"l2_norm_couple_with_batchmatmulv2": 5,
"onehot_case0": 6,
},
"black_list": [
@@ -291,20 +291,38 @@ func.func @layer_norm_multi_dim(%arg0: tensor<2x8x4xf32>) -> tensor<2x8x4xf32> {
%9 = "tf.AddV2"(%6, %8) {device = ""} : (tensor<2x8x4xf32>, tensor<2x8x4xf32>) -> tensor<2x8x4xf32>
func.return %9 : tensor<2x8x4xf32>
}
-// CHECK-LABLE: func.func @layer_norm_multi_dim(%arg0: tensor<2x8x4xf32>) -> tensor<2x8x4xf32> {
-// CHECK-LABLE: %cst = "tf.Const"() <{value = dense<1.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
-// CHECK-LABLE: %cst_0 = "tf.Const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
-// CHECK-LABLE: %cst_1 = "tf.Const"() <{value = dense<[16, 4]> : tensor<2xi64>}> : () -> tensor<2xi64>
-// CHECK-LABLE: %cst_2 = "tf.Const"() <{value = dense<[2, 8, 4]> : tensor<3xi64>}> : () -> tensor<3xi64>
-// CHECK-LABLE: %cst_3 = "tf.Const"() <{value = dense<[[[1.000000e-01, 2.000000e-01, 3.000000e-01, 4.000000e-01]], [[5.000000e-01, 6.000000e-01, 0.699999988, 8.000000e-01]]]> : tensor<2x1x4xf32>}> : () -> tensor<2x1x4xf32>
-// CHECK-LABLE: %cst_4 = "tf.Const"() <{value = dense<[[[0.00999999977, 2.000000e-02, 3.000000e-02, 4.000000e-02]], [[5.000000e-02, 6.000000e-02, 7.000000e-02, 8.000000e-02]]]> : tensor<2x1x4xf32>}> : () -> tensor<2x1x4xf32>
-// CHECK-LABLE: %0 = "tf.Reshape"(%arg0, %cst_1) : (tensor<2x8x4xf32>, tensor<2xi64>) -> tensor<16x4xf32>
-// CHECK-LABLE: %1 = mhlo.custom_call @byteir.layer_norm(%0, %cst, %cst_0) {backend_config = "", byteir_attrs = {axis = [1], epsilon = 9.9999999747524271E-7 : f64}} : (tensor<16x4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<16x4xf32>
-// CHECK-LABLE: %2 = "tf.Reshape"(%1, %cst_2) : (tensor<16x4xf32>, tensor<3xi64>) -> tensor<2x8x4xf32>
-// CHECK-LABLE: %3 = "tf.Mul"(%2, %cst_3) : (tensor<2x8x4xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
-// CHECK-LABLE: %4 = "tf.Add"(%3, %cst_4) : (tensor<2x8x4xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
-// CHECK-LABLE: return %4 : tensor<2x8x4xf32>
-// CHECK-LABLE: }
+// CHECK-LABEL: func.func @layer_norm_multi_dim(%arg0: tensor<2x8x4xf32>) -> tensor<2x8x4xf32> {
+// CHECK: %0 = "tf.Reshape"(%arg0, %cst_1) : (tensor<2x8x4xf32>, tensor<2xi64>) -> tensor<16x4xf32>
+// CHECK-NEXT: %1 = mhlo.custom_call @byteir.layer_norm(%0, %cst, %cst_0) {backend_config = "", byteir_attrs = {axis = [1], epsilon = 9.9999999747524271E-7 : f64}} : (tensor<16x4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<16x4xf32>
+// CHECK-NEXT: %2 = "tf.Reshape"(%1, %cst_2) : (tensor<16x4xf32>, tensor<3xi64>) -> tensor<2x8x4xf32>
+// CHECK-NEXT: %3 = "tf.Mul"(%2, %cst_3) : (tensor<2x8x4xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
+// CHECK-NEXT: %4 = "tf.Add"(%3, %cst_4) : (tensor<2x8x4xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
+// CHECK-NEXT: return %4 : tensor<2x8x4xf32>
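The rewritten checks above show the shape of the fusion: normalization runs over the last axis only, so the pattern can collapse the leading 2x8 dims, call @byteir.layer_norm on 16x4 with unit gamma and zero beta, then restore the shape and apply the per-batch 2x1x4 gamma/beta with plain tf.Mul/tf.Add. A minimal NumPy sketch of the invariance this relies on (an editor's illustration, not code from this PR):

import numpy as np

def last_axis_layer_norm(x, eps=9.99999997e-7):
    # Layer norm over the last axis with gamma = ones and beta = zeros,
    # i.e. what the @byteir.layer_norm custom call computes here.
    m = x.mean(axis=-1, keepdims=True)
    v = np.square(x - m).mean(axis=-1, keepdims=True)
    return (x - m) / np.sqrt(v + eps)

x = np.random.default_rng(0).standard_normal((2, 8, 4)).astype(np.float32)
direct = last_axis_layer_norm(x)
via_reshape = last_axis_layer_norm(x.reshape(16, 4)).reshape(2, 8, 4)
# Per-row normalization is unaffected by collapsing the leading dims.
assert np.allclose(direct, via_reshape)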

+func.func @layer_norm_multi_dim_v2(%arg0: tensor<2x8x4xf32>) -> tensor<2x8x4xf32> {
+%cst = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+%cst_1 = "tf.Const"() <{value = dense<9.99999997E-7> : tensor<f32>}> : () -> tensor<f32>
+%cst_2 = "tf.Const"() <{value = dense<[[[0.1, 0.2, 0.3, 0.4]], [[0.5, 0.6, 0.7, 0.8]]]> : tensor<2x1x4xf32>}> : () -> tensor<2x1x4xf32>
+%cst_3 = "tf.Const"() <{value = dense<[[[0.01, 0.02, 0.03, 0.04]], [[0.05, 0.06, 0.07, 0.08]]]> : tensor<2x1x4xf32>}> : () -> tensor<2x1x4xf32>
+%0 = "tf.Mean"(%arg0, %cst) <{keep_dims = true}> {device = ""} : (tensor<2x8x4xf32>, tensor<1xi32>) -> tensor<2x8x1xf32>
+%1 = "tf.SquaredDifference"(%arg0, %0) {device = ""} : (tensor<2x8x4xf32>, tensor<2x8x1xf32>) -> tensor<2x8x4xf32>
+%2 = "tf.Mean"(%1, %cst) <{keep_dims = true}> {device = ""} : (tensor<2x8x4xf32>, tensor<1xi32>) -> tensor<2x8x1xf32>
+%3 = "tf.AddV2"(%2, %cst_1) {device = ""} : (tensor<2x8x1xf32>, tensor<f32>) -> tensor<2x8x1xf32>
+%4 = "tf.Rsqrt"(%3) {device = ""} : (tensor<2x8x1xf32>) -> tensor<2x8x1xf32>
+%5 = "tf.Mul"(%4, %cst_2) {device = ""} : (tensor<2x8x1xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
+%6 = "tf.Mul"(%arg0, %5) {device = ""} : (tensor<2x8x4xf32>, tensor<2x8x4xf32>) -> tensor<2x8x4xf32>
+%7 = "tf.Mul"(%5, %0) {device = ""} : (tensor<2x8x4xf32>, tensor<2x8x1xf32>) -> tensor<2x8x4xf32>
+%8 = "tf.Sub"(%cst_3, %7) {device = ""} : (tensor<2x1x4xf32>, tensor<2x8x4xf32>) -> tensor<2x8x4xf32>
+%9 = "tf.AddV2"(%6, %8) {device = ""} : (tensor<2x8x4xf32>, tensor<2x8x4xf32>) -> tensor<2x8x4xf32>
+func.return %9 : tensor<2x8x4xf32>
+}
+// CHECK-LABEL: func.func @layer_norm_multi_dim_v2(%arg0: tensor<2x8x4xf32>) -> tensor<2x8x4xf32> {
+// CHECK: %0 = "tf.Reshape"(%arg0, %cst_1) : (tensor<2x8x4xf32>, tensor<2xi64>) -> tensor<16x4xf32>
+// CHECK-NEXT: %1 = mhlo.custom_call @byteir.layer_norm(%0, %cst, %cst_0) {backend_config = "", byteir_attrs = {axis = [1], epsilon = 9.9999999747524271E-7 : f64}} : (tensor<16x4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<16x4xf32>
+// CHECK-NEXT: %2 = "tf.Reshape"(%1, %cst_2) : (tensor<16x4xf32>, tensor<3xi64>) -> tensor<2x8x4xf32>
+// CHECK-NEXT: %3 = "tf.Mul"(%2, %cst_3) : (tensor<2x8x4xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
+// CHECK-NEXT: %4 = "tf.Add"(%3, %cst_4) : (tensor<2x8x4xf32>, tensor<2x1x4xf32>) -> tensor<2x8x4xf32>
+// CHECK-NEXT: return %4 : tensor<2x8x4xf32>
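The new layer_norm_multi_dim_v2 case is the algebraically "distributed" form of the same computation: with s = gamma * rsqrt(var + eps), the graph computes x * s + (beta - s * mean), which expands to gamma * (x - mean) * rsqrt(var + eps) + beta, so it can be matched to the same custom call. A NumPy check of that identity on the test's shapes (illustrative only, not PR code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4)).astype(np.float32)
gamma = rng.standard_normal((2, 1, 4)).astype(np.float32)
beta = rng.standard_normal((2, 1, 4)).astype(np.float32)
eps = 9.99999997e-7

m = x.mean(axis=2, keepdims=True)                 # tf.Mean
v = np.square(x - m).mean(axis=2, keepdims=True)  # tf.SquaredDifference + tf.Mean
s = gamma / np.sqrt(v + eps)                      # tf.Rsqrt then tf.Mul with gamma
v2_form = x * s + (beta - s * m)                  # the graph matched above
canonical = gamma * (x - m) / np.sqrt(v + eps) + beta
print(np.max(np.abs(v2_form - canonical)))        # ~0 up to float32 rounding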

func.func @layer_norm_swap_add(%arg0: tensor<2x32x3xf32>) -> tensor<2x32x3xf32> {
%cst_15 = "tf.Const"() {value = dense<9.99999997E-7> : tensor<f32>} : () -> tensor<f32>
@@ -541,7 +559,7 @@ func.func @l2_norm_V1(%arg0: tensor<1x32x3xf32>) -> tensor<1x32x3xf32> {
// CHECK-LABEL: func.func @l2_norm_V1(%arg0: tensor<1x32x3xf32>) -> tensor<1x32x3xf32> {
// CHECK: mhlo.custom_call
// CHECK-SAME: @byteir.l2_norm
-// CHECK-SAME: byteir_attrs = {axis = [2], epsilon = 9.9999999747524271E-7 : f64}
+// CHECK-SAME: byteir_attrs = {axis = [2], eps_outside_sqrt = false, epsilon = 9.9999999747524271E-7 : f64}
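eps_outside_sqrt is the new attribute this PR threads through byteir.l2_norm: the pre-existing patterns (false) correspond to x / sqrt(sum(x^2) + eps), while the batchmatmulv2 pattern added below (true) corresponds to x / (sqrt(sum(x^2)) + eps). A sketch of the two conventions in NumPy (the function names are illustrative, not byteir API):

import numpy as np

def l2_norm_eps_inside(x, eps, axis=-1):
    # eps_outside_sqrt = false: eps is added before the square root
    return x / np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True) + eps)

def l2_norm_eps_outside(x, eps, axis=-1):
    # eps_outside_sqrt = true: tf.Sqrt first, then tf.AddV2 with eps
    return x / (np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True)) + eps)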

func.func @l2_norm_V1_with_multiplyer(%arg0: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
%cst = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
@@ -557,7 +575,7 @@ func.func @l2_norm_V1_with_multiplyer(%arg0: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
}
// CHECK-LABEL: func.func @l2_norm_V1_with_multiplyer(%arg0: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// CHECK-NEXT: %0 = mhlo.constant dense<9.34093475> : tensor<2x4x8xf32>
-// CHECK-NEXT: %1 = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [2], epsilon = 9.999999960041972E-13 : f64}} : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK-NEXT: %1 = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [2], eps_outside_sqrt = false, epsilon = 9.999999960041972E-13 : f64}} : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
// CHECK-NEXT: %2 = mhlo.multiply %1, %0 : tensor<2x4x8xf32>

func.func @l2_norm_V1_swap_mul(%54: tensor<1x64xf32>) -> tensor<1x64xf32> {
@@ -586,7 +604,7 @@ func.func @l2_norm_V2(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> {
// CHECK-LABEL: @l2_norm_V2
// CHECK: mhlo.custom_call
// CHECK-SAME: @byteir.l2_norm
-// CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64}
+// CHECK-SAME: byteir_attrs = {axis = [1], eps_outside_sqrt = false, epsilon = 0.000000e+00 : f64}

func.func @l2_norm_V2_swap_mul(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> {
%cst_5 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
@@ -600,7 +618,7 @@ func.func @l2_norm_V2_swap_mul(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> {
// CHECK-LABEL: @l2_norm_V2_swap_mul
// CHECK: mhlo.custom_call
// CHECK-SAME: @byteir.l2_norm
-// CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64}
+// CHECK-SAME: byteir_attrs = {axis = [1], eps_outside_sqrt = false, epsilon = 0.000000e+00 : f64}

func.func @l2_norm_V3(%15: tensor<1x100x512xf32>) -> tensor<1x100x512xf32> {
%cst_96 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
@@ -613,7 +631,28 @@ func.func @l2_norm_V3(%15: tensor<1x100x512xf32>) -> tensor<1x100x512xf32> {
// CHECK-LABEL: @l2_norm_V3
// CHECK: mhlo.custom_call
// CHECK-SAME: @byteir.l2_norm
-// CHECK-SAME: byteir_attrs = {axis = [2], epsilon = 0.000000e+00 : f64}
+// CHECK-SAME: byteir_attrs = {axis = [2], eps_outside_sqrt = false, epsilon = 0.000000e+00 : f64}

+func.func @l2_norm_couple_with_batchmatmulv2(%arg0: tensor<1x8x16xf32>, %arg1: tensor<1x32x16xf32>) -> tensor<1x8x32xf32> {
+%cst = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+%cst_1 = "tf.Const"() <{value = dense<1.013280e-06> : tensor<f32>}> : () -> tensor<f32>
+%0 = "tf.Square"(%arg0) {device = ""} : (tensor<1x8x16xf32>) -> tensor<1x8x16xf32>
+%1 = "tf.Sum"(%0, %cst) <{keep_dims = true}> {device = ""} : (tensor<1x8x16xf32>, tensor<1xi32>) -> tensor<1x8x1xf32>
+%2 = "tf.Sqrt"(%1) {device = ""} : (tensor<1x8x1xf32>) -> tensor<1x8x1xf32>
+%3 = "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = true, grad_x = false, grad_y = false}> {device = ""} : (tensor<1x8x16xf32>, tensor<1x32x16xf32>) -> tensor<1x8x32xf32>
+%4 = "tf.Square"(%arg1) {device = ""} : (tensor<1x32x16xf32>) -> tensor<1x32x16xf32>
+%5 = "tf.Sum"(%4, %cst) <{keep_dims = true}> {device = ""} : (tensor<1x32x16xf32>, tensor<1xi32>) -> tensor<1x32x1xf32>
+%6 = "tf.Sqrt"(%5) {device = ""} : (tensor<1x32x1xf32>) -> tensor<1x32x1xf32>
+%7 = "tf.BatchMatMulV2"(%2, %6) <{adj_x = false, adj_y = true, grad_x = false, grad_y = false}> {device = ""} : (tensor<1x8x1xf32>, tensor<1x32x1xf32>) -> tensor<1x8x32xf32>
+%8 = "tf.AddV2"(%7, %cst_1) {device = ""} : (tensor<1x8x32xf32>, tensor<f32>) -> tensor<1x8x32xf32>
+%9 = "tf.RealDiv"(%3, %8) {device = ""} : (tensor<1x8x32xf32>, tensor<1x8x32xf32>) -> tensor<1x8x32xf32>
+return %9 : tensor<1x8x32xf32>
+}
+// CHECK-LABEL: func.func @l2_norm_couple_with_batchmatmulv2(%arg0: tensor<1x8x16xf32>, %arg1: tensor<1x32x16xf32>) -> tensor<1x8x32xf32> {
+// CHECK: %0 = mhlo.custom_call @byteir.l2_norm(%arg0) {backend_config = "", byteir_attrs = {axis = [2], eps_outside_sqrt = true, epsilon = 1.0132799843631801E-6 : f64}} : (tensor<1x8x16xf32>) -> tensor<1x8x16xf32>
+// CHECK-NEXT: %1 = mhlo.custom_call @byteir.l2_norm(%arg1) {backend_config = "", byteir_attrs = {axis = [2], eps_outside_sqrt = true, epsilon = 1.0132799843631801E-6 : f64}} : (tensor<1x32x16xf32>) -> tensor<1x32x16xf32>
+// CHECK-NEXT: %2 = "tf.BatchMatMulV2"(%0, %1) <{adj_x = false, adj_y = true}> : (tensor<1x8x16xf32>, tensor<1x32x16xf32>) -> tensor<1x8x32xf32>
+// CHECK-NEXT: return %2 : tensor<1x8x32xf32>
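The matched graph is a batched cosine similarity: a.bT divided by the product of the row norms plus eps. The rewrite normalizes each operand separately (with eps_outside_sqrt = true) and matmuls the results, which is only approximately equal, since (|a| + eps) * (|b| + eps) differs from |a| * |b| + eps by eps * (|a| + |b|) + eps^2; that is presumably why numerical_test.py registers this case at 5 rather than 6, assuming those integers are decimal-digit accuracy thresholds. A NumPy sketch of the gap (editor's illustration, not PR code):

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 8, 16)).astype(np.float32)
b = rng.standard_normal((1, 32, 16)).astype(np.float32)
eps = 1.01328e-6  # matches %cst_1 above

na = np.sqrt(np.sum(a * a, axis=2, keepdims=True))   # 1x8x1,  like %2
nb = np.sqrt(np.sum(b * b, axis=2, keepdims=True))   # 1x32x1, like %6

original = (a @ b.transpose(0, 2, 1)) / (na @ nb.transpose(0, 2, 1) + eps)
rewritten = (a / (na + eps)) @ (b / (nb + eps)).transpose(0, 2, 1)
print(np.max(np.abs(original - rewritten)))  # on the order of eps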

func.func @dynamic_mask_stitch(%arg0: tensor<4x4xf32>, %arg1: tensor<4xi32>) -> tensor<?x4xf32> {
%cst = "tf.Const"() {value = dense<[[-0.916170597, -0.884184718, 1.60242105, -1.19678485], [0.33643803, -0.431175768, 1.71861267, 0.126368985], [-1.07191086, -1.00517535, -0.666032254, 0.776807785], [1.53380013, 0.83925873, -0.24277249, 1.53341103]]> : tensor<4x4xf32>} : () -> tensor<4x4xf32>