Online softmax implementation #639

Merged: 1 commit, Sep 16, 2024

Changes from all commits
python/perf-kernels/softmax.py: 71 changes (35 additions & 36 deletions)
@@ -5,7 +5,6 @@
 
 import triton
 import triton.language as tl
-from triton.runtime import driver
 
 
 def is_cuda():
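
(The driver import above was only used to query the device's multiprocessor count for the old persistent-kernel launch; the next hunk removes that query along with the kernel it served.)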
@@ -52,43 +51,52 @@ def get_autotune_config():
 
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
-                   BLOCK_SIZE: tl.constexpr):
+def softmax_kernel_online(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
+                          BLOCK_SIZE: tl.constexpr):
 
     row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
-        row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_idx = row_start
+
+    #loop 1, find max and sum
+    m = -float('inf')  #Initial value of max
+    row_sum = 0.0
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    for b in tl.range(0, n_cols, BLOCK_SIZE):
+        col_offsets = b + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
-        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-        row = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")
-        row_minus_max = row - tl.max(row, axis=0)
-        numerator = tl.exp(row_minus_max)
-        denominator = tl.sum(numerator, axis=0)
-        softmax_output = numerator / denominator
-        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        mask = col_offsets < n_cols
+        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  #load block
+        m_p = tl.max(row_block, axis=0)  #find block max
+        m_p = tl.maximum(m, m_p)  #Find new max across all blocks so far
+        row_sum = row_sum * tl.exp(m - m_p)  #Adjust previous sum
+        row_sum += tl.sum(tl.exp(row_block - m_p))  #Add to exponentiated sum of this block
+        m = m_p  #save max
+
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    #Loop 2
+    for b in tl.range(0, n_cols, BLOCK_SIZE):
+        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols
+        row_block = tl.load(input_ptrs, mask=mask, other=-float('inf'), cache_modifier=".cg")  #load block
+        #subtract, exponentiate and divide by sum
+        softmax_output = tl.exp(row_block - m) / row_sum
+        #store
         output_ptrs = output_row_start_ptr + col_offsets
-        output_ptrs = tl.multiple_of(output_ptrs, (16, ))
         tl.store(output_ptrs, softmax_output, mask=mask)
 
 
-device = torch.cuda.current_device()
-properties = driver.active.utils.get_device_properties(device)
-NUM_SM = properties["multiprocessor_count"]
-
-
 def softmax(x):
     n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
     y = torch.empty_like(x)
 
-    #Persistent kernel. Simply, set num of programs equal to number of streaming multi-processors
-    num_programs = min(NUM_SM, n_rows)
+    num_programs = n_rows
 
    grid = lambda meta: (num_programs, )
-    softmax_kernel[grid](
+    softmax_kernel_online[grid](
         y,
         x,
         x.stride(0),
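
The heart of the change is the loop-1 recurrence: the kernel keeps a running max m and a running sum row_sum of exp(x - m), and whenever a block raises the max it rescales the old sum by exp(m_old - m_new) before adding the new block's exponentials. A minimal PyTorch sketch of the same recurrence (illustrative only, not part of the PR; the helper name and block size are made up) shows it reproduces an ordinary softmax:

import math
import torch

def online_softmax_row(row: torch.Tensor, block: int = 4) -> torch.Tensor:
    # Pass 1: streaming max and rescaled sum of exponentials.
    m, s = float('-inf'), 0.0
    for b in range(0, row.numel(), block):
        blk = row[b:b + block]
        m_new = max(m, blk.max().item())           # new max across all blocks so far
        s = s * math.exp(m - m_new)                # rescale the previous sum
        s += torch.exp(blk - m_new).sum().item()   # add this block's exponentials
        m = m_new
    # Pass 2: subtract the final max, exponentiate, divide by the sum.
    return torch.exp(row - m) / s

x = torch.randn(11)
assert torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-6)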
@@ -111,17 +119,8 @@ def run_softmax(M, N):
 
 
 #pytest
-@pytest.mark.parametrize('M, N', [
-    (1823, 781),
-    (1, 1),
-    (128, 1),
-    (1, 128),
-    (8192, 8192),
-    (4096, 8192),
-    (359, 1),
-    (1, 359),
-    (1, 131072),
-])
+@pytest.mark.parametrize('M, N', [(1823, 781), (1, 1), (128, 1), (1, 128), (8192, 8192), (4096, 8192), (359, 1),
+                                  (1, 359), (1, 131072), (1, 89999)])
 def test_softmax(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
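
A note on the new launch parameters: BLOCK_SIZE is now capped by MAX_FUSED_SIZE = 65536 // x.element_size(), so each tl.load covers at most 64 KiB, i.e. 16384 fp32 or 32768 fp16 elements per loop iteration; the widest test row (N = 131072, fp32) therefore takes 8 iterations per pass. A usage sketch mirroring run_softmax/test_softmax (shapes taken from the test list above):

x = torch.randn(1823, 781, device='cuda')
y = softmax(x)                 # launches one Triton program per row
torch.testing.assert_close(y, torch.softmax(x, dim=1))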