Use mask during load for Softmax
Rahul Batra committed Sep 24, 2024
1 parent 96b3d37 commit 8939325
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions python/perf-kernels/softmax.py
@@ -51,39 +51,66 @@ def get_autotune_config():
 
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
-                          BLOCK_SIZE: tl.constexpr):
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
 
     row_start = tl.program_id(0)
     row_idx = row_start
 
     #loop 1, find max and sum
+    loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1
     m = -float('inf') #Initial value of max
     row_sum = 0.0
     row_start_ptr = input_ptr + row_idx * input_row_stride
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    for b in tl.range(0, loop_num):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        mask = col_offsets < n_cols
-        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
+        row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block
         m_p = tl.max(row_block, axis=0) #find block max
         m_p = tl.maximum(m, m_p) #Find new max across all blocks so far
         row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum
         row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block
         m = m_p #save max
 
+    col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
+    m_p = tl.max(row_block, axis=0) #find block max
+    m_p = tl.maximum(m, m_p) #Find new max across all blocks so far
+    row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum
+    row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block
+    m = m_p #save max
+
     output_row_start_ptr = output_ptr + row_idx * output_row_stride
     #Loop 2
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1
+    for b in tl.range(0, loop_num):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        mask = col_offsets < n_cols
-        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
+        row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block
         #subtract, exponentiate and divide by sum
         softmax_output = tl.exp(row_block - m) / row_sum
         #store
         output_ptrs = output_row_start_ptr + col_offsets
-        tl.store(output_ptrs, softmax_output, mask=mask)
+        tl.store(output_ptrs, softmax_output)
+
+    col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
+    #subtract, exponentiate and divide by sum
+    softmax_output = tl.exp(row_block - m) / row_sum
+    #store
+    output_ptrs = output_row_start_ptr + col_offsets
+    tl.store(output_ptrs, softmax_output, mask=mask)
 
 
 def softmax(x):
@@ -96,7 +123,7 @@ def softmax(x):
     num_programs = n_rows
 
     grid = lambda meta: (num_programs, )
-    softmax_kernel_online[grid](
+    softmax_kernel[grid](
        y,
        x,
        x.stride(0),
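
For reference, loop 1 in the kernel above is the standard online-softmax recurrence: keep a running max m and a running sum that is rescaled by exp(m - m_p) whenever the max grows, so the row never has to fit in one block. A minimal NumPy sketch (not part of the commit; BLOCK and the row contents are illustrative) that mirrors the block-wise update and checks it against a direct softmax:

import numpy as np

# Block-wise online softmax on one row, mirroring loop 1 (running max + rescaled sum)
# and loop 2 (normalize with the final max and sum).
BLOCK = 128                                     # illustrative block size
row = np.random.randn(1000).astype(np.float32)  # illustrative row

m, row_sum = -np.inf, 0.0
for start in range(0, row.size, BLOCK):
    block = row[start:start + BLOCK]      # slicing handles the ragged tail block
    m_p = max(m, float(block.max()))      # new running max across all blocks so far
    row_sum = row_sum * np.exp(m - m_p)   # rescale the previous sum to the new max
    row_sum += np.exp(block - m_p).sum()  # add this block's exponentiated sum
    m = m_p

out = np.exp(row - m) / row_sum  # loop 2: subtract max, exponentiate, divide by sum
ref = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
assert np.allclose(out, ref, atol=1e-6)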

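As a usage sketch for the updated launch at the bottom of the file (assumptions not shown in this diff: the module imports torch and triton as the other perf-kernels do, is importable as softmax.py, and softmax expects a 2-D tensor on the GPU; the shape and tolerance are illustrative):

import torch
from softmax import softmax  # hypothetical import path for python/perf-kernels/softmax.py

x = torch.randn(1823, 781, device='cuda', dtype=torch.float32)  # odd column count to exercise the masked tail block
y_triton = softmax(x)              # one program per row, as set up by the grid lambda
y_torch = torch.softmax(x, dim=1)  # reference row-wise softmax
assert torch.allclose(y_triton, y_torch, atol=1e-6, rtol=0)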