Skip to content

Commit

Permalink
Use tl.exp2
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Batra committed Sep 25, 2024
1 parent 8939325 commit ba9a70c
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions python/perf-kernels/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,39 +75,40 @@ def softmax_kernel(
row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block
m_p = tl.max(row_block, axis=0) #find block max
m_p = tl.maximum(m, m_p) #Find new max across all blocks so far
row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum
row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block
row_sum = row_sum * tl.exp2((m - m_p) * 1.44269504) #Adjust previous sum
row_sum += tl.sum(tl.exp2((row_block - m_p) * 1.44269504)) #Add to exponentiated sum of this block
m = m_p #save max

#Last iteration with masked load/store
col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
m_p = tl.max(row_block, axis=0) #find block max
m_p = tl.maximum(m, m_p) #Find new max across all blocks so far
row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum
row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block
row_sum = row_sum * tl.exp2((m - m_p) * 1.44269504) #Adjust previous sum
row_sum += tl.sum(tl.exp2((row_block - m_p) * 1.44269504)) #Add to exponentiated sum of this block
m = m_p #save max

output_row_start_ptr = output_ptr + row_idx * output_row_stride
#Loop 2
loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1
for b in tl.range(0, loop_num):
col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block
#subtract, exponentiate and divide by sum
softmax_output = tl.exp(row_block - m) / row_sum
softmax_output = tl.exp2((row_block - m) * 1.44269504) / row_sum
#store
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output)

#Last iteration with masked load/store
col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
#subtract, exponentiate and divide by sum
softmax_output = tl.exp(row_block - m) / row_sum
softmax_output = tl.exp2((row_block - m) * 1.44269504) / row_sum
#store
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
Expand Down

0 comments on commit ba9a70c

Please sign in to comment.