diff --git a/core/shark_turbine/kernel/_support/tracing.py b/core/shark_turbine/kernel/_support/tracing.py index 92d898f47..061f2dd9e 100644 --- a/core/shark_turbine/kernel/_support/tracing.py +++ b/core/shark_turbine/kernel/_support/tracing.py @@ -390,7 +390,7 @@ def __call__(self, *args, **kwargs): def eager_execute(self, args, kwargs): ... - def benchmark_execute(self, args, kwargs): + def test_execute(self, args, kwargs): ... @@ -441,9 +441,9 @@ def launch(self, launchable: Launchable, args, kwargs): return launchable.eager_execute(args, kwargs) -class BenchmarkLaunchContext(LaunchContext): +class TestLaunchContext(LaunchContext): def launch(self, launchable: Launchable, args, kwargs): - return launchable.benchmark_execute(args, kwargs) + return launchable.test_execute(args, kwargs) ############################################################################### diff --git a/core/shark_turbine/kernel/compiler/host_codegen.py b/core/shark_turbine/kernel/compiler/host_codegen.py index df841425e..0a35ff87e 100644 --- a/core/shark_turbine/kernel/compiler/host_codegen.py +++ b/core/shark_turbine/kernel/compiler/host_codegen.py @@ -60,7 +60,7 @@ def memref_to_tensor(memrefs: list[IrType]): return tensors -def isolated_benchmark_call( +def isolated_test_call( mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str ): with InsertionPoint(mb.body_block), Location.unknown(): diff --git a/core/shark_turbine/kernel/gen/__init__.py b/core/shark_turbine/kernel/gen/__init__.py index e2d7ed139..f5bee0e1e 100644 --- a/core/shark_turbine/kernel/gen/__init__.py +++ b/core/shark_turbine/kernel/gen/__init__.py @@ -1,3 +1,3 @@ from .thread import * -from .._support.tracing import BenchmarkLaunchContext +from .._support.tracing import TestLaunchContext diff --git a/core/shark_turbine/kernel/gen/thread.py b/core/shark_turbine/kernel/gen/thread.py index 0d545d05e..99ed315d1 100644 --- a/core/shark_turbine/kernel/gen/thread.py +++ b/core/shark_turbine/kernel/gen/thread.py @@ -101,7 +101,7 @@ def eager_execute(self, args, kwargs): current_thread[-1] = it self._eager_function(*bound.args, **bound.kwargs) - def benchmark_execute(self, args, kwargs): + def test_execute(self, args, kwargs): # Trace the function. trace = self._trace() idxc = IndexingContext.current() @@ -136,7 +136,7 @@ def benchmark_execute(self, args, kwargs): mb.module_op.verify() - host_codegen.isolated_benchmark_call(mb, exe, kernel_sig, entrypoint_name) + host_codegen.isolated_test_call(mb, exe, kernel_sig, entrypoint_name) print(mb.module_op.get_asm()) diff --git a/core/tests/kernel/arith_test.py b/core/tests/kernel/arith_test.py index 04767f4ea..f9977e00e 100644 --- a/core/tests/kernel/arith_test.py +++ b/core/tests/kernel/arith_test.py @@ -46,7 +46,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]): c = (a * b) / c c = c + a - b - with tk.gen.BenchmarkLaunchContext(): + with tk.gen.TestLaunchContext(): iota_kernel(torch.zeros(17)) diff --git a/core/tests/kernel/dispatch_codegen_test.py b/core/tests/kernel/dispatch_codegen_test.py index 661ce5820..8bd21c7ca 100644 --- a/core/tests/kernel/dispatch_codegen_test.py +++ b/core/tests/kernel/dispatch_codegen_test.py @@ -34,7 +34,7 @@ def softmax_kernel( input = torch.randn(128, 64) output = torch.zeros(128, 64) - with tk.gen.BenchmarkLaunchContext(): + with tk.gen.TestLaunchContext(): softmax_kernel(input, output) diff --git a/core/tests/kernel/simple_kernel_test.py b/core/tests/kernel/simple_kernel_test.py index df0e52d0c..4d0b94557 100644 --- a/core/tests/kernel/simple_kernel_test.py +++ b/core/tests/kernel/simple_kernel_test.py @@ -23,7 +23,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]): out[i] = i out = torch.empty(8, dtype=torch.int32) - with tk.gen.BenchmarkLaunchContext(): + with tk.gen.TestLaunchContext(): iota_kernel(out) print(out) diff --git a/core/tests/kernel/vector_codegen_test.py b/core/tests/kernel/vector_codegen_test.py index 5420fefae..7a9c375ce 100644 --- a/core/tests/kernel/vector_codegen_test.py +++ b/core/tests/kernel/vector_codegen_test.py @@ -19,7 +19,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]): secret_value = ((i * (33 - i) + 4) % 8) // 2 out[i] = secret_value - with tk.gen.BenchmarkLaunchContext(): + with tk.gen.TestLaunchContext(): out = torch.zeros(17, dtype=torch.int32) def testSoftmaxFx(self): @@ -33,7 +33,7 @@ def softmax_kernel( output_row = numerator / tkl.sum(numerator) output[row_index, :] = output_row - with tk.gen.BenchmarkLaunchContext(): + with tk.gen.TestLaunchContext(): input = torch.randn(128, 64, dtype=torch.float32) output = torch.zeros(128, 64, dtype=torch.float32) softmax_kernel(input, output) @@ -55,7 +55,7 @@ def prefetch_sum(i, sum, prefetch): output[row_idx, 0] = prefetch_sum[0] - with tk.gen.BenchmarkLaunchContext(): + with tk.gen.TestLaunchContext(): input = torch.randn(128, 64, dtype=torch.float32) output = torch.zeros(128, 64, dtype=torch.float32) for_loop_kernel(input, output) @@ -85,7 +85,7 @@ def body(i, c): tkl.store(output, (grid_n, grid_m), body[0]) - with tk.gen.BenchmarkLaunchContext({BLOCK_SIZE: 32}): + with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}): A = torch.randn(512, 1024, dtype=torch.float32) B = torch.randn(1024, 2048, dtype=torch.float32) output = torch.zeros(512, 2048, dtype=torch.float32)