diff --git a/python/perf-kernels/softmax.py b/python/perf-kernels/softmax.py
index bd00f24c42fc..b3924e55d281 100644
--- a/python/perf-kernels/softmax.py
+++ b/python/perf-kernels/softmax.py
@@ -5,7 +5,6 @@
 import triton
 import triton.language as tl
-from triton.runtime import driver
 
 
 def is_cuda():
@@ -52,43 +51,55 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
-                   BLOCK_SIZE: tl.constexpr):
+def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
+                          BLOCK_SIZE: tl.constexpr):
+
     row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
-        row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_idx = row_start
+
+    # loop 1: find the row max and the rescaled sum of exponentials
+    m = -float('inf')  # initial value of the running max
+    row_sum = 0.0
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    for b in tl.range(0, n_cols, BLOCK_SIZE):
+        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols
+        #row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  # load block
+        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'))  # load block
+        m_p = tl.max(row_block, axis=0)  # block max
+        m_p = tl.maximum(m, m_p)  # new max across all blocks so far
+        row_sum = row_sum * tl.exp(m - m_p)  # rescale the previous sum to the new max
+        row_sum += tl.sum(tl.exp(row_block - m_p))  # add this block's exponentiated sum
+        m = m_p  # save the max
+
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    # loop 2: normalize and store
+    for b in tl.range(0, n_cols, BLOCK_SIZE):
+        col_offsets = b + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-        row = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")
-        row_minus_max = row - tl.max(row, axis=0)
-        numerator = tl.exp(row_minus_max)
-        denominator = tl.sum(numerator, axis=0)
-        softmax_output = numerator / denominator
-        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        mask = col_offsets < n_cols
+        #row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  # load block
+        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'))  # load block
+        # subtract the max, exponentiate and divide by the sum
+        softmax_output = tl.exp(row_block - m) / row_sum
+        # store
         output_ptrs = output_row_start_ptr + col_offsets
-        output_ptrs = tl.multiple_of(output_ptrs, (16, ))
+        #output_ptrs = tl.multiple_of(output_ptrs, (16, ))
         tl.store(output_ptrs, softmax_output, mask=mask)
 
 
-device = torch.cuda.current_device()
-properties = driver.active.utils.get_device_properties(device)
-NUM_SM = properties["multiprocessor_count"]
-
-
 def softmax(x):
     n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    MAX_FUSED_SIZE = 65536 // x.element_size()  # cap one block at 64 KiB of elements
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
     y = torch.empty_like(x)
-    #Persistent kernel. Simply, set num of programs equal to number of streaming multi-processors
-    num_programs = min(NUM_SM, n_rows)
+    num_programs = n_rows  # one program per row
     grid = lambda meta: (num_programs, )
-    softmax_kernel[grid](
+    softmax_kernel_online[grid](
         y,
         x,
         x.stride(0),
@@ -111,17 +122,8 @@ def run_softmax(M, N):
 
 
 #pytest
-@pytest.mark.parametrize('M, N', [
-    (1823, 781),
-    (1, 1),
-    (128, 1),
-    (1, 128),
-    (8192, 8192),
-    (4096, 8192),
-    (359, 1),
-    (1, 359),
-    (1, 131072),
-])
+@pytest.mark.parametrize('M, N', [(1823, 781), (1, 1), (128, 1), (1, 128), (8192, 8192), (4096, 8192), (359, 1),
+                                  (1, 359), (1, 131072), (1, 89999)])
 def test_softmax(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
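
Reviewer note: the two kernel loops implement the online-softmax recurrence. Whenever the running max m grows to m_p, the accumulated sum of exponentials is rescaled by exp(m - m_p) before the new block's terms are added, so the result matches the usual max-subtracted softmax without ever holding the whole row in registers. A minimal PyTorch sketch of that recurrence, checked against torch.softmax, follows; the helper name online_softmax_row and the block size are illustrative only and not part of the patch.

import math

import torch


def online_softmax_row(row, block_size=16):
    # loop 1: stream over the row in blocks, keeping a running max `m` and a
    # running sum of exponentials `row_sum`, rescaled each time the max grows
    m, row_sum = float('-inf'), 0.0
    for b in range(0, row.numel(), block_size):
        block = row[b:b + block_size]
        m_p = max(m, block.max().item())                # new running max
        row_sum = row_sum * math.exp(m - m_p)           # rescale the previous sum
        row_sum += torch.exp(block - m_p).sum().item()  # add this block's terms
        m = m_p
    # loop 2: exponentiate against the final max and divide by the sum
    return torch.exp(row - m) / row_sum


torch.manual_seed(0)
x = torch.randn(1000)
assert torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-6)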
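
On the host side, the 64 KiB cap bounds the tile instead of rounding the whole row up to a power of two: for float32 inputs, MAX_FUSED_SIZE = 65536 // 4 = 16384 elements, so the widest test row, (1, 131072), is swept in 131072 / 16384 = 8 blocks per loop rather than one 131072-wide tile. The new (1, 89999) test case exercises a non-power-of-two width that previously would have rounded up to a 131072-element block.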