Llama2 model Operator/Layer level instance extraction
hayden-brown committed Jul 26, 2024
1 parent ec8a179, commit 194f0bd
Showing 15 changed files with 1,473 additions and 0 deletions.
Large diffs are not rendered by default.
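Every rendered file below is a standalone MLIR instance that follows the same measurement pattern: the extracted operator or layer body is bracketed by two @rtclock calls, the result is cast to an unranked tensor and printed through @printMemrefF32, and the elapsed time is printed with vector.print. A minimal sketch of that shared harness (the kernel name and the single tosa.negate body are placeholders, not part of the commit):

func.func private @rtclock() -> f64
func.func private @printMemrefF32(%ptr : tensor<*xf32>)

func.func @kernel_example(%arg0 : tensor<1x40x4096xf32>) {
  %t_start = call @rtclock() : () -> f64

  // Operator/layer body under measurement (a single negate stands in here).
  %result = tosa.negate %arg0 : (tensor<1x40x4096xf32>) -> tensor<1x40x4096xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  %tensor_unranked = tensor.cast %result : tensor<1x40x4096xf32> to tensor<*xf32>
  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
  vector.print %time : f64
  return
}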
@@ -0,0 +1,44 @@
func.func private @rtclock() -> f64

func.func @kernel_fc_layer(%arg0 : tensor<1x40x4096xf32>, %arg1 : tensor<4096x4096xf32>, %arg2 : tensor<4096x4096xf32>, %arg3 : tensor<1x40x4096xf32>) {
  %t_start = call @rtclock() : () -> f64

  %cst_0 = arith.constant dense<0.0> : tensor<40x4096xf32>
  %cst_1 = arith.constant dense<0.0> : tensor<40x4096xf32>

  // Elementwise product of the two activations, then two independent 4096x4096 projections.
  %41 = tosa.mul %arg0, %arg3 {shift = 0 : i8} : (tensor<1x40x4096xf32>, tensor<1x40x4096xf32>) -> tensor<1x40x4096xf32>
  %42 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
  %43 = tosa.transpose %arg1, %42 : (tensor<4096x4096xf32>, tensor<2xi32>) -> tensor<4096x4096xf32>
  %44 = tosa.reshape %41 {new_shape = array<i64: 40, 4096>} : (tensor<1x40x4096xf32>) -> tensor<40x4096xf32>
  %45 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%44, %43 : tensor<40x4096xf32>, tensor<4096x4096xf32>) outs(%cst_0 : tensor<40x4096xf32>) -> tensor<40x4096xf32>
  %46 = tosa.reshape %45 {new_shape = array<i64: 1, 40, 4096>} : (tensor<40x4096xf32>) -> tensor<1x40x4096xf32>

  %47 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
  %48 = tosa.transpose %arg2, %47 : (tensor<4096x4096xf32>, tensor<2xi32>) -> tensor<4096x4096xf32>
  %49 = tosa.reshape %41 {new_shape = array<i64: 40, 4096>} : (tensor<1x40x4096xf32>) -> tensor<40x4096xf32>
  %50 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%49, %48 : tensor<40x4096xf32>, tensor<4096x4096xf32>) outs(%cst_1 : tensor<40x4096xf32>) -> tensor<40x4096xf32>
  %51 = tosa.reshape %50 {new_shape = array<i64: 1, 40, 4096>} : (tensor<40x4096xf32>) -> tensor<1x40x4096xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  %tensor_unranked = tensor.cast %51 : tensor<1x40x4096xf32> to tensor<*xf32>

  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
  vector.print %time : f64

  return
}

func.func @main() {
  %input_tensor_1 = arith.constant dense<3.0> : tensor<1x40x4096xf32>
  %input_tensor_2 = arith.constant dense<2.0> : tensor<4096x4096xf32>
  %input_tensor_3 = arith.constant dense<1.0> : tensor<4096x4096xf32>
  %input_tensor_4 = arith.constant dense<4.0> : tensor<1x40x4096xf32>

  call @kernel_fc_layer(%input_tensor_1, %input_tensor_2, %input_tensor_3, %input_tensor_4) : (tensor<1x40x4096xf32>, tensor<4096x4096xf32>, tensor<4096x4096xf32>, tensor<1x40x4096xf32>) -> ()

  return
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)
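In short, the FC-layer instance forms the elementwise product of its two activation inputs, flattens it to 40x4096, and multiplies it by the transpose of each 4096x4096 weight; only the second projection (%51) is cast and printed. Read as a formula (an interpretation of the code above, not part of the committed file):

$$ z = \mathrm{reshape}(x \odot y) \in \mathbb{R}^{40 \times 4096}, \qquad \mathrm{out}_i = z\,W_i^{\top} \quad (i = 1, 2) $$

with x = %arg0, y = %arg3, W_1 = %arg1, and W_2 = %arg2.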
@@ -0,0 +1,64 @@
#map = affine_map<(d0, d1, d2) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map6 = affine_map<(d0, d1, d2) -> (d0, 0, d1, d2)>
#map7 = affine_map<(d0, d1) -> (0, d0, d1)>

func.func private @rtclock() -> f64

func.func @kernel_ffn(%arg0: tensor<1x40x4096xf32>, %arg9: tensor<4096xf32>, %arg10: tensor<11008x4096xf32>, %arg11: tensor<11008x4096xf32>, %arg12: tensor<4096x11008xf32>) {
  %t_start = call @rtclock() : () -> f64

  // FFN
  %138 = tosa.reshape %arg9 {new_shape = array<i64: 1, 1, 4096>} : (tensor<4096xf32>) -> tensor<1x1x4096xf32>
  %139 = tosa.mul %138, %arg0 {shift = 0 : i8} : (tensor<1x1x4096xf32>, tensor<1x40x4096xf32>) -> tensor<1x40x4096xf32>
  %140 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
  %141 = tosa.transpose %arg10, %140 : (tensor<11008x4096xf32>, tensor<2xi32>) -> tensor<4096x11008xf32>
  %142 = tosa.reshape %139 {new_shape = array<i64: 40, 4096>} : (tensor<1x40x4096xf32>) -> tensor<40x4096xf32>
  %cst_24 = arith.constant dense<0.000000e+00> : tensor<40x11008xf32>
  %143 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%142, %141 : tensor<40x4096xf32>, tensor<4096x11008xf32>) outs(%cst_24 : tensor<40x11008xf32>) -> tensor<40x11008xf32>
  %144 = tosa.reshape %143 {new_shape = array<i64: 1, 40, 11008>} : (tensor<40x11008xf32>) -> tensor<1x40x11008xf32>
  %145 = tosa.sigmoid %144 : (tensor<1x40x11008xf32>) -> tensor<1x40x11008xf32>
  %146 = tosa.mul %144, %145 {shift = 0 : i8} : (tensor<1x40x11008xf32>, tensor<1x40x11008xf32>) -> tensor<1x40x11008xf32>
  %147 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
  %148 = tosa.transpose %arg11, %147 : (tensor<11008x4096xf32>, tensor<2xi32>) -> tensor<4096x11008xf32>
  %149 = tosa.reshape %139 {new_shape = array<i64: 40, 4096>} : (tensor<1x40x4096xf32>) -> tensor<40x4096xf32>
  %cst_25 = arith.constant dense<0.000000e+00> : tensor<40x11008xf32>
  %150 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%149, %148 : tensor<40x4096xf32>, tensor<4096x11008xf32>) outs(%cst_25 : tensor<40x11008xf32>) -> tensor<40x11008xf32>
  %151 = tosa.reshape %150 {new_shape = array<i64: 1, 40, 11008>} : (tensor<40x11008xf32>) -> tensor<1x40x11008xf32>
  %152 = tosa.mul %146, %151 {shift = 0 : i8} : (tensor<1x40x11008xf32>, tensor<1x40x11008xf32>) -> tensor<1x40x11008xf32>
  %153 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
  %154 = tosa.transpose %arg12, %153 : (tensor<4096x11008xf32>, tensor<2xi32>) -> tensor<11008x4096xf32>
  %155 = tosa.reshape %152 {new_shape = array<i64: 40, 11008>} : (tensor<1x40x11008xf32>) -> tensor<40x11008xf32>
  %cst_26 = arith.constant dense<0.000000e+00> : tensor<40x4096xf32>
  %156 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%155, %154 : tensor<40x11008xf32>, tensor<11008x4096xf32>) outs(%cst_26 : tensor<40x4096xf32>) -> tensor<40x4096xf32>
  %157 = tosa.reshape %156 {new_shape = array<i64: 1, 40, 4096>} : (tensor<40x4096xf32>) -> tensor<1x40x4096xf32>
  %158 = tosa.add %arg0, %157 : (tensor<1x40x4096xf32>, tensor<1x40x4096xf32>) -> tensor<1x40x4096xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  %tensor_unranked = tensor.cast %158 : tensor<1x40x4096xf32> to tensor<*xf32>

  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
  vector.print %time : f64

  return
}

func.func @main() {
  %input_tensor = arith.constant dense<3.0> : tensor<1x40x4096xf32>
  %weight1 = arith.constant dense<1.0> : tensor<4096xf32>
  %weight2 = arith.constant dense<1.0> : tensor<11008x4096xf32>
  %weight3 = arith.constant dense<2.0> : tensor<11008x4096xf32>
  %weight4 = arith.constant dense<1.0> : tensor<4096x11008xf32>

  call @kernel_ffn(%input_tensor, %weight1, %weight2, %weight3, %weight4) : (tensor<1x40x4096xf32>, tensor<4096xf32>, tensor<11008x4096xf32>, tensor<11008x4096xf32>, tensor<4096x11008xf32>) -> ()

  return
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)
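This instance matches the SwiGLU feed-forward block of a Llama-style decoder layer: the hidden states are scaled elementwise by the 4096-wide weight vector (%arg9, presumably the RMSNorm weight), projected through the gate and up weights, gated with SiLU, projected back down, and added to the residual. Using the usual names W_gate = %arg10, W_up = %arg11, W_down = %arg12 (an interpretation of the code above, not part of the committed file):

$$ \hat{x} = w \odot x, \qquad y = x + \big(\operatorname{SiLU}(\hat{x} W_{\mathrm{gate}}^{\top}) \odot (\hat{x} W_{\mathrm{up}}^{\top})\big) W_{\mathrm{down}}^{\top} $$

where SiLU(t) = t * sigmoid(t), corresponding to %145 and %146 above.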
@@ -0,0 +1,36 @@
module {
  func.func private @rtclock() -> f64

  func.func @kernel_fpowi(%arg0: tensor<1x32x40x64xf32>) {
    %t_start = call @rtclock() : () -> f64

    // Power operation
    %c2_i32 = arith.constant 2 : i32
    %output_tensor = tensor.empty() : tensor<1x32x40x64xf32>
    %result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x32x40x64xf32>) outs(%output_tensor : tensor<1x32x40x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %0 = math.fpowi %in, %c2_i32 : f32, i32
      linalg.yield %0 : f32
    } -> tensor<1x32x40x64xf32>

    %t_end = call @rtclock() : () -> f64
    %time = arith.subf %t_end, %t_start : f64

    %tensor_unranked = tensor.cast %result : tensor<1x32x40x64xf32> to tensor<*xf32>

    call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
    vector.print %time : f64

    return
  }

  func.func @main() {
    %input_tensor = arith.constant dense<5.0> : tensor<1x32x40x64xf32>

    call @kernel_fpowi(%input_tensor) : (tensor<1x32x40x64xf32>) -> ()

    return
  }

  func.func private @printMemrefF32(%ptr : tensor<*xf32>)
}
@@ -0,0 +1,29 @@
func.func private @rtclock() -> f64

func.func @kernel_matmul(%arg0 : tensor<40x4096xf32>, %arg1 : tensor<4096x4096xf32>, %arg2 : tensor<40x4096xf32>) {
  %t_start = call @rtclock() : () -> f64

  %matmul_result = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<40x4096xf32>, tensor<4096x4096xf32>) outs(%arg2 : tensor<40x4096xf32>) -> tensor<40x4096xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  %tensor_unranked = tensor.cast %matmul_result : tensor<40x4096xf32> to tensor<*xf32>

  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
  vector.print %time : f64

  return
}

func.func @main() {
  %input_tensor_1 = arith.constant dense<3.0> : tensor<40x4096xf32>
  %input_tensor_2 = arith.constant dense<2.0> : tensor<4096x4096xf32>
  %output_tensor = arith.constant dense<0.0> : tensor<40x4096xf32>

  call @kernel_matmul(%input_tensor_1, %input_tensor_2, %output_tensor) : (tensor<40x4096xf32>, tensor<4096x4096xf32>, tensor<40x4096xf32>) -> ()

  return
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)
@@ -0,0 +1,31 @@
module {
  func.func private @rtclock() -> f64

  func.func @kernel_mul(%arg0: tensor<1xf32>, %arg1: tensor<1x40x1xf32>) {
    %t_start = call @rtclock() : () -> f64

    // Perform the multiplication operation
    %mul_result = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1xf32>

    %t_end = call @rtclock() : () -> f64
    %time = arith.subf %t_end, %t_start : f64

    %tensor_unranked = tensor.cast %mul_result : tensor<1x40x1xf32> to tensor<*xf32>

    call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
    vector.print %time : f64

    return
  }

  func.func @main() {
    %input_tensor_1 = arith.constant dense<3.0> : tensor<1xf32>
    %input_tensor_2 = arith.constant dense<2.0> : tensor<1x40x1xf32>

    call @kernel_mul(%input_tensor_1, %input_tensor_2) : (tensor<1xf32>, tensor<1x40x1xf32>) -> ()

    return
  }

  func.func private @printMemrefF32(%ptr: tensor<*xf32>)
}
@@ -0,0 +1,30 @@
module {
  func.func private @rtclock() -> f64

  func.func @kernel_negate(%arg0: tensor<1x32x40x64xf32>) {
    %t_start = call @rtclock() : () -> f64

    // Negate operation
    %negated = tosa.negate %arg0 : (tensor<1x32x40x64xf32>) -> tensor<1x32x40x64xf32>

    %t_end = call @rtclock() : () -> f64
    %time = arith.subf %t_end, %t_start : f64

    %tensor_unranked = tensor.cast %negated : tensor<1x32x40x64xf32> to tensor<*xf32>

    call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
    vector.print %time : f64

    return
  }

  func.func @main() {
    %input_tensor = arith.constant dense<1.0> : tensor<1x32x40x64xf32>

    call @kernel_negate(%input_tensor) : (tensor<1x32x40x64xf32>) -> ()

    return
  }

  func.func private @printMemrefF32(%ptr : tensor<*xf32>)
}
@@ -0,0 +1,30 @@
module {
  func.func private @rtclock() -> f64

  func.func @kernel_reciprocal(%arg0: tensor<1x10xf32>) {
    %t_start = call @rtclock() : () -> f64

    // Reciprocal operation
    %result = tosa.reciprocal %arg0 : (tensor<1x10xf32>) -> tensor<1x10xf32>

    %t_end = call @rtclock() : () -> f64
    %time = arith.subf %t_end, %t_start : f64

    %tensor_unranked = tensor.cast %result : tensor<1x10xf32> to tensor<*xf32>

    call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
    vector.print %time : f64

    return
  }

  func.func @main() {
    %input_tensor = "tosa.const"() {value = dense<[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]> : tensor<1x10xf32>} : () -> tensor<1x10xf32>

    call @kernel_reciprocal(%input_tensor) : (tensor<1x10xf32>) -> ()

    return
  }

  func.func private @printMemrefF32(%ptr : tensor<*xf32>)
}
@@ -0,0 +1,30 @@
module {
  func.func private @rtclock() -> f64

  func.func @kernel_reduce_sum(%arg0: tensor<1x40x4096xf32>) {
    %t_start = call @rtclock() : () -> f64

    // Reduce sum operation
    %result = tosa.reduce_sum %arg0 {axis = 2 : i32} : (tensor<1x40x4096xf32>) -> tensor<1x40x1xf32>

    %t_end = call @rtclock() : () -> f64
    %time = arith.subf %t_end, %t_start : f64

    %tensor_unranked = tensor.cast %result : tensor<1x40x1xf32> to tensor<*xf32>

    call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
    vector.print %time : f64

    return
  }

  func.func @main() {
    %input_tensor = arith.constant dense<1.0> : tensor<1x40x4096xf32>

    call @kernel_reduce_sum(%input_tensor) : (tensor<1x40x4096xf32>) -> ()

    return
  }

  func.func private @printMemrefF32(%ptr : tensor<*xf32>)
}
@@ -0,0 +1,51 @@
#map = affine_map<(d0, d1, d2) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map6 = affine_map<(d0, d1, d2) -> (d0, 0, d1, d2)>
#map7 = affine_map<(d0, d1) -> (0, d0, d1)>

func.func private @rtclock() -> f64

func.func @kernel_rmsnorm(%arg0: tensor<1x40x4096xf32>) {
  %t_start = call @rtclock() : () -> f64

  // RMSNorm operations
  %30 = tensor.empty() : tensor<1x40x4096xf32>
  %c2_i32 = arith.constant 2 : i32
  %31 = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x40x4096xf32>) outs(%30 : tensor<1x40x4096xf32>) {
  ^bb0(%in: f32, %out: f32):
    %4175 = math.fpowi %in, %c2_i32 : f32, i32
    linalg.yield %4175 : f32
  } -> tensor<1x40x4096xf32>
  %32 = tosa.reduce_sum %31 {axis = 2 : i32} : (tensor<1x40x4096xf32>) -> tensor<1x40x1xf32>
  %33 = "tosa.const"() <{value = dense<4.096000e+03> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
  %34 = tosa.reciprocal %33 : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %35 = tosa.mul %34, %32 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
  %36 = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1x40x1xf32>}> : () -> tensor<1x40x1xf32>
  %37 = tosa.add %35, %36 : (tensor<1x40x1xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
  %38 = tosa.rsqrt %37 : (tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
  %39 = tosa.mul %arg0, %38 {shift = 0 : i8} : (tensor<1x40x4096xf32>, tensor<1x40x1xf32>) -> tensor<1x40x4096xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  %tensor_unranked = tensor.cast %39 : tensor<1x40x4096xf32> to tensor<*xf32>

  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
  vector.print %time : f64

  return
}

func.func @main() {
  %input_tensor_1 = arith.constant dense<3.0> : tensor<1x40x4096xf32>

  call @kernel_rmsnorm(%input_tensor_1) : (tensor<1x40x4096xf32>) -> ()

  return
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)
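The operation chain above is RMSNorm without the trailing weight multiplication: each element is squared, the squares are summed over the 4096-wide hidden dimension and divided by 4096 (the dense<4.096000e+03> constant), a small epsilon (9.99999974E-6, roughly 1e-5) is added, and the input is scaled by the reciprocal square root. As a formula (an interpretation of the code above, not part of the committed file):

$$ y_j = \frac{x_j}{\sqrt{\frac{1}{4096} \sum_{k=1}^{4096} x_k^2 + \varepsilon}}, \qquad \varepsilon \approx 10^{-5} $$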
@@ -0,0 +1,28 @@
func.func private @rtclock() -> f64

func.func @kernel_rsqrt(%arg0 : tensor<1x40x1xf32>) {
  %t_start = call @rtclock() : () -> f64

  // rsqrt operation
  %rsqrt_result = tosa.rsqrt %arg0 : (tensor<1x40x1xf32>) -> tensor<1x40x1xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  %tensor_unranked = tensor.cast %rsqrt_result : tensor<1x40x1xf32> to tensor<*xf32>

  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
  vector.print %time : f64

  return
}

func.func @main() {
  %input_tensor = arith.constant dense<3.0> : tensor<1x40x1xf32>

  call @kernel_rsqrt(%input_tensor) : (tensor<1x40x1xf32>) -> ()

  return
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)