From ba9a70c1feb52422e196e87222f591cb709c4257 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Wed, 25 Sep 2024 01:47:25 +0000 Subject: [PATCH] Use tl.exp2 --- python/perf-kernels/softmax.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py index 19764b2138e3..a549a87461ad 100644 --- a/python/perf-kernels/softmax.py +++ b/python/perf-kernels/softmax.py @@ -75,39 +75,40 @@ def softmax_kernel( row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block m_p = tl.max(row_block, axis=0) #find block max m_p = tl.maximum(m, m_p) #Find new max across all blocks so far - row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum - row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block + row_sum = row_sum * tl.exp2((m - m_p) * 1.44269504) #Adjust previous sum + row_sum += tl.sum(tl.exp2((row_block - m_p) * 1.44269504)) #Add to exponentiated sum of this block m = m_p #save max + #Last iteration with masked load/store col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block m_p = tl.max(row_block, axis=0) #find block max m_p = tl.maximum(m, m_p) #Find new max across all blocks so far - row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum - row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block + row_sum = row_sum * tl.exp2((m - m_p) * 1.44269504) #Adjust previous sum + row_sum += tl.sum(tl.exp2((row_block - m_p) * 1.44269504)) #Add to exponentiated sum of this block m = m_p #save max output_row_start_ptr = output_ptr + row_idx * output_row_stride #Loop 2 - loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1 for b in tl.range(0, loop_num): col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets row_block = tl.load(input_ptrs, cache_modifier=".cg") #load block #subtract, exponentiate and divide by sum - softmax_output = tl.exp(row_block - m) / row_sum + softmax_output = tl.exp2((row_block - m) * 1.44269504) / row_sum #store output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output) + #Last iteration with masked load/store col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets mask = col_offsets < n_cols row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block #subtract, exponentiate and divide by sum - softmax_output = tl.exp(row_block - m) / row_sum + softmax_output = tl.exp2((row_block - m) * 1.44269504) / row_sum #store output_ptrs = output_row_start_ptr + col_offsets tl.store(output_ptrs, softmax_output, mask=mask)