Use SW pipelining with loops
Rahul Batra committed Sep 25, 2024
1 parent ba9a70c commit a542825
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions python/perf-kernels/softmax.py
@@ -39,6 +39,15 @@ def get_hip_autotune_config():
         triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
         triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
         triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
+        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=2),
     ]
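
Side note on the added configurations: together with the existing `num_stages=1` entries, the nine new `num_stages=2` entries let the autotuner compare pipelined and unpipelined compilations for every `waves_per_eu` / `num_warps` pair. A minimal, self-contained sketch of how such a config list is typically consumed (the demo kernel, its `key`, and the shapes are illustrative assumptions, not code from this repository):

```python
import torch
import triton
import triton.language as tl

# Demo config grid in the same spirit as get_hip_autotune_config(): the autotuner
# benchmarks every (waves_per_eu, num_warps, num_stages) combination on the first
# launch and caches the fastest one per distinct value of the `key` arguments.
DEMO_CONFIGS = [
    triton.Config({'waves_per_eu': wpe}, num_warps=nw, num_stages=ns)
    for wpe in (1, 2, 4) for nw in (4, 8, 16) for ns in (1, 2)
]


@triton.autotune(configs=DEMO_CONFIGS, key=['n_cols'])  # waves_per_eu is an AMD/ROCm-specific hint
@triton.jit
def copy_rows_kernel(out_ptr, in_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(in_ptr + row * n_cols + cols, mask=mask)  # assumes contiguous rows
    tl.store(out_ptr + row * n_cols + cols, x, mask=mask)


x = torch.randn(64, 1024, device='cuda')
y = torch.empty_like(x)
copy_rows_kernel[(x.shape[0], )](y, x, x.shape[1], BLOCK_SIZE=1024)
```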


@@ -69,7 +78,7 @@ def softmax_kernel(
     m = -float('inf') #Initial value of max
     row_sum = 0.0
     row_start_ptr = input_ptr + row_idx * input_row_stride
-    for b in tl.range(0, loop_num):
+    for b in tl.range(0, loop_num, num_stages=2):
         col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block
@@ -92,7 +101,7 @@ def softmax_kernel(

     output_row_start_ptr = output_ptr + row_idx * output_row_stride
     #Loop 2
-    for b in tl.range(0, loop_num):
+    for b in tl.range(0, loop_num, num_stages=2):
         col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block
@@ -118,6 +127,7 @@ def softmax(x):
     n_rows, n_cols = x.shape

     MAX_FUSED_SIZE = 65536 // x.element_size()
+    #MAX_FUSED_SIZE = 16384 // x.element_size()
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
     y = torch.empty_like(x)
