Commit

Benchmark -> Test
Groverkss committed Feb 12, 2024
1 parent 9235c5a commit 64c43d8
Showing 8 changed files with 14 additions and 14 deletions.
core/shark_turbine/kernel/_support/tracing.py (6 changes: 3 additions & 3 deletions)

@@ -390,7 +390,7 @@ def __call__(self, *args, **kwargs):
     def eager_execute(self, args, kwargs):
         ...

-    def benchmark_execute(self, args, kwargs):
+    def test_execute(self, args, kwargs):
         ...


@@ -441,9 +441,9 @@ def launch(self, launchable: Launchable, args, kwargs):
         return launchable.eager_execute(args, kwargs)


-class BenchmarkLaunchContext(LaunchContext):
+class TestLaunchContext(LaunchContext):
     def launch(self, launchable: Launchable, args, kwargs):
-        return launchable.benchmark_execute(args, kwargs)
+        return launchable.test_execute(args, kwargs)


 ###############################################################################
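For orientation, the hook being renamed above is the per-Launchable entry point that a launch context dispatches to. Below is a minimal sketch of that dispatch using only the names visible in this hunk; the real classes in tracing.py also carry tracing and compilation state, and the context is normally entered with a `with` statement rather than called directly.

# Simplified model of the dispatch touched by this rename; not the actual
# shark_turbine implementation.
class Launchable:
    def eager_execute(self, args, kwargs):
        raise NotImplementedError

    def test_execute(self, args, kwargs):  # formerly benchmark_execute
        raise NotImplementedError


class LaunchContext:
    def launch(self, launchable: Launchable, args, kwargs):
        raise NotImplementedError


class TestLaunchContext(LaunchContext):  # formerly BenchmarkLaunchContext
    def launch(self, launchable: Launchable, args, kwargs):
        return launchable.test_execute(args, kwargs)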
core/shark_turbine/kernel/compiler/host_codegen.py (2 changes: 1 addition & 1 deletion)

@@ -60,7 +60,7 @@ def memref_to_tensor(memrefs: list[IrType]):
     return tensors


-def isolated_benchmark_call(
+def isolated_test_call(
     mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str
 ):
     with InsertionPoint(mb.body_block), Location.unknown():
core/shark_turbine/kernel/gen/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1,3 +1,3 @@
 from .thread import *

-from .._support.tracing import BenchmarkLaunchContext
+from .._support.tracing import TestLaunchContext
core/shark_turbine/kernel/gen/thread.py (4 changes: 2 additions & 2 deletions)

@@ -101,7 +101,7 @@ def eager_execute(self, args, kwargs):
             current_thread[-1] = it
             self._eager_function(*bound.args, **bound.kwargs)

-    def benchmark_execute(self, args, kwargs):
+    def test_execute(self, args, kwargs):
         # Trace the function.
         trace = self._trace()
         idxc = IndexingContext.current()
@@ -136,7 +136,7 @@ def benchmark_execute(self, args, kwargs):

         mb.module_op.verify()

-        host_codegen.isolated_benchmark_call(mb, exe, kernel_sig, entrypoint_name)
+        host_codegen.isolated_test_call(mb, exe, kernel_sig, entrypoint_name)

         print(mb.module_op.get_asm())

core/tests/kernel/arith_test.py (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]):
             c = (a * b) / c
             c = c + a - b

-        with tk.gen.BenchmarkLaunchContext():
+        with tk.gen.TestLaunchContext():
             iota_kernel(torch.zeros(17))


core/tests/kernel/dispatch_codegen_test.py (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@ def softmax_kernel(

         input = torch.randn(128, 64)
         output = torch.zeros(128, 64)
-        with tk.gen.BenchmarkLaunchContext():
+        with tk.gen.TestLaunchContext():
             softmax_kernel(input, output)


core/tests/kernel/simple_kernel_test.py (2 changes: 1 addition & 1 deletion)

@@ -23,7 +23,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]):
             out[i] = i

         out = torch.empty(8, dtype=torch.int32)
-        with tk.gen.BenchmarkLaunchContext():
+        with tk.gen.TestLaunchContext():
             iota_kernel(out)
         print(out)

core/tests/kernel/vector_codegen_test.py (8 changes: 4 additions & 4 deletions)

@@ -19,7 +19,7 @@ def iota_kernel(out: tk.lang.OutputBuffer[M]):
             secret_value = ((i * (33 - i) + 4) % 8) // 2
             out[i] = secret_value

-        with tk.gen.BenchmarkLaunchContext():
+        with tk.gen.TestLaunchContext():
             out = torch.zeros(17, dtype=torch.int32)

     def testSoftmaxFx(self):
@@ -33,7 +33,7 @@ def softmax_kernel(
             output_row = numerator / tkl.sum(numerator)
             output[row_index, :] = output_row

-        with tk.gen.BenchmarkLaunchContext():
+        with tk.gen.TestLaunchContext():
             input = torch.randn(128, 64, dtype=torch.float32)
             output = torch.zeros(128, 64, dtype=torch.float32)
             softmax_kernel(input, output)
@@ -55,7 +55,7 @@ def prefetch_sum(i, sum, prefetch):

             output[row_idx, 0] = prefetch_sum[0]

-        with tk.gen.BenchmarkLaunchContext():
+        with tk.gen.TestLaunchContext():
             input = torch.randn(128, 64, dtype=torch.float32)
             output = torch.zeros(128, 64, dtype=torch.float32)
             for_loop_kernel(input, output)
@@ -85,7 +85,7 @@ def body(i, c):

             tkl.store(output, (grid_n, grid_m), body[0])

-        with tk.gen.BenchmarkLaunchContext({BLOCK_SIZE: 32}):
+        with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
             A = torch.randn(512, 1024, dtype=torch.float32)
             B = torch.randn(1024, 2048, dtype=torch.float32)
             output = torch.zeros(512, 2048, dtype=torch.float32)
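The test files above only need to change their `with` statements because the launchable kernels dispatch through whichever launch context is currently active. Here is a standalone sketch of that pattern under the assumption of a simple context stack; the stack handling, the `Kernel` stand-in, and the `iota_kernel` below are illustrative, not the real tracing.py/thread.py code.

# Illustrative, self-contained model of the `with TestLaunchContext():` usage
# pattern seen in the updated tests; every name other than TestLaunchContext,
# launch(), eager_execute(), and test_execute() is hypothetical.
_context_stack = []


class LaunchContext:
    def __enter__(self):
        _context_stack.append(self)
        return self

    def __exit__(self, *exc):
        _context_stack.pop()
        return False

    def launch(self, launchable, args, kwargs):
        raise NotImplementedError


class TestLaunchContext(LaunchContext):
    def launch(self, launchable, args, kwargs):
        return launchable.test_execute(args, kwargs)


class Kernel:
    """Stand-in for a decorated kernel; the real one traces and compiles."""

    def __init__(self, fn):
        self._fn = fn

    def eager_execute(self, args, kwargs):
        return self._fn(*args, **kwargs)

    def test_execute(self, args, kwargs):
        # The real implementation traces the function and emits MLIR here.
        return self._fn(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Route through the active context, which is why the tests only swap
        # BenchmarkLaunchContext -> TestLaunchContext in their `with` blocks.
        return _context_stack[-1].launch(self, args, kwargs)


@Kernel
def iota_kernel(out):
    for i in range(len(out)):
        out[i] = i
    return out


with TestLaunchContext():
    print(iota_kernel([0] * 8))  # [0, 1, 2, 3, 4, 5, 6, 7]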
