diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py
index 60eefb91986f..19764b2138e3 100644
--- a/python/perf-kernels/softmax.py
+++ b/python/perf-kernels/softmax.py
@@ -51,39 +51,66 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
-                          BLOCK_SIZE: tl.constexpr):
+def softmax_kernel(
+    output_ptr,
+    input_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_rows,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
     row_start = tl.program_id(0)
     row_idx = row_start
 
     #loop 1, find max and sum
+    loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1
     m = -float('inf')  #Initial value of max
     row_sum = 0.0
     row_start_ptr = input_ptr + row_idx * input_row_stride
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    for b in tl.range(0, loop_num):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        mask = col_offsets < n_cols
-        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  #load block
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #load block
         m_p = tl.max(row_block, axis=0)  #find block max
         m_p = tl.maximum(m, m_p)  #Find new max across all blocks so far
         row_sum = row_sum * tl.exp(m - m_p)  #Adjust previous sum
         row_sum += tl.sum(tl.exp(row_block - m_p))  #Add to exponentiated sum of this block
         m = m_p  #save max
 
+    col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  #load block
+    m_p = tl.max(row_block, axis=0)  #find block max
+    m_p = tl.maximum(m, m_p)  #Find new max across all blocks so far
+    row_sum = row_sum * tl.exp(m - m_p)  #Adjust previous sum
+    row_sum += tl.sum(tl.exp(row_block - m_p))  #Add to exponentiated sum of this block
+    m = m_p  #save max
+
     output_row_start_ptr = output_ptr + row_idx * output_row_stride
     #Loop 2
-    for b in tl.range(0, n_cols, BLOCK_SIZE):
-        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+    loop_num = tl.cdiv(n_cols, BLOCK_SIZE) - 1
+    for b in tl.range(0, loop_num):
+        col_offsets = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        mask = col_offsets < n_cols
-        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  #load block
+        row_block = tl.load(input_ptrs, cache_modifier=".cg")  #load block
         #subtract, exponentiate and divide by sum
         softmax_output = tl.exp(row_block - m) / row_sum
         #store
         output_ptrs = output_row_start_ptr + col_offsets
-        tl.store(output_ptrs, softmax_output, mask=mask)
+        tl.store(output_ptrs, softmax_output)
+
+    col_offsets = loop_num * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    input_ptrs = row_start_ptr + col_offsets
+    mask = col_offsets < n_cols
+    row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  #load block
+    #subtract, exponentiate and divide by sum
+    softmax_output = tl.exp(row_block - m) / row_sum
+    #store
+    output_ptrs = output_row_start_ptr + col_offsets
+    tl.store(output_ptrs, softmax_output, mask=mask)
 
 
 def softmax(x):
@@ -96,7 +123,7 @@ def softmax(x):
     num_programs = n_rows
     grid = lambda meta: (num_programs, )
 
-    softmax_kernel_online[grid](
+    softmax_kernel[grid](
         y,
         x,
         x.stride(0),
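
For reference, below is a minimal NumPy sketch (not part of the diff) of the online-softmax recurrence that loop 1 implements: whenever a block raises the running max, the running sum is rescaled by exp(m_old - m_new), so one pass over the row yields both the max and the normalizer. The kernel change above additionally peels the last, possibly partial, block out of each loop so that the in-bounds iterations can load and store without a mask. The function name online_softmax_row and the BLOCK_SIZE=4 default are illustrative choices, not names from the file.

import numpy as np

def online_softmax_row(x, BLOCK_SIZE=4):
    # Pass 1: running max and rescaled running sum, one block at a time.
    m = -np.inf
    row_sum = 0.0
    for b in range(0, len(x), BLOCK_SIZE):
        block = x[b:b + BLOCK_SIZE]
        m_new = max(m, block.max())             # new max across all blocks so far
        row_sum = row_sum * np.exp(m - m_new)   # rescale previous sum to the new max
        row_sum += np.exp(block - m_new).sum()  # add this block's exponentiated sum
        m = m_new
    # Pass 2: normalize with the final max and sum.
    return np.exp(x - m) / row_sum

x = np.random.rand(37).astype(np.float32)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
np.testing.assert_allclose(online_softmax_row(x), ref, rtol=1e-5)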