Skip to content

Commit

Permalink
Online softmax implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Batra committed Sep 13, 2024
1 parent c4bd738 commit 10d9836
Showing 1 changed file with 38 additions and 36 deletions.
74 changes: 38 additions & 36 deletions python/perf-kernels/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import triton
import triton.language as tl
from triton.runtime import driver


def is_cuda():
Expand Down Expand Up @@ -52,43 +51,55 @@ def get_autotune_config():

@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
BLOCK_SIZE: tl.constexpr):
def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
BLOCK_SIZE: tl.constexpr):

row_start = tl.program_id(0)
row_step = tl.num_programs(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
for row_idx in tl.range(row_start, n_rows, row_step):
row_start_ptr = input_ptr + row_idx * input_row_stride
row_idx = row_start

#loop 1, find max and sum
m = -float('inf') #Initial value of max
row_sum = 0.0
row_start_ptr = input_ptr + row_idx * input_row_stride
for b in tl.range(0, n_cols, BLOCK_SIZE):
col_offsets = b + tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
#row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
row_block = tl.load(input_ptrs, mask=mask, other=-float('inf')) #load block
m_p = tl.max(row_block, axis=0) #find block max
m_p = tl.maximum(m, m_p) #Find new max across all blocks so far
row_sum = row_sum * tl.exp(m - m_p) #Adjust previous sum
row_sum += tl.sum(tl.exp(row_block - m_p)) #Add to exponentiated sum of this block
m = m_p #save max

output_row_start_ptr = output_ptr + row_idx * output_row_stride
#Loop 2
for b in tl.range(0, n_cols, BLOCK_SIZE):
col_offsets = b + tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
input_ptrs = tl.multiple_of(input_ptrs, (16, ))
row = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * output_row_stride
mask = col_offsets < n_cols
#row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg") #load block
row_block = tl.load(input_ptrs, mask=mask, other=-float('inf')) #load block
#subtract, exponentiate and divide by sum
softmax_output = tl.exp(row_block - m) / row_sum
#store
output_ptrs = output_row_start_ptr + col_offsets
output_ptrs = tl.multiple_of(output_ptrs, (16, ))
#output_ptrs = tl.multiple_of(output_ptrs, (16, ))
tl.store(output_ptrs, softmax_output, mask=mask)


device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]


def softmax(x):
n_rows, n_cols = x.shape
BLOCK_SIZE = triton.next_power_of_2(n_cols)

MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
y = torch.empty_like(x)

#Persistent kernel. Simply, set num of programs equal to number of streaming multi-processors
num_programs = min(NUM_SM, n_rows)
num_programs = n_rows

grid = lambda meta: (num_programs, )
softmax_kernel[grid](
softmax_kernel_online[grid](
y,
x,
x.stride(0),
Expand All @@ -111,17 +122,8 @@ def run_softmax(M, N):


#pytest
@pytest.mark.parametrize('M, N', [
(1823, 781),
(1, 1),
(128, 1),
(1, 128),
(8192, 8192),
(4096, 8192),
(359, 1),
(1, 359),
(1, 131072),
])
@pytest.mark.parametrize('M, N', [(1823, 781), (1, 1), (128, 1), (1, 128), (8192, 8192), (4096, 8192), (359, 1),
(1, 359), (1, 131072), (1, 89999)])
def test_softmax(M, N):
torch.manual_seed(0)
x = torch.randn(M, N, device='cuda')
Expand Down

0 comments on commit 10d9836

Please sign in to comment.