diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 3feca368a..32703788c 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -112,7 +112,7 @@ def liger_cross_entropy_kernel( @triton.jit -def element_mul( +def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, @@ -147,6 +147,68 @@ def element_mul( tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) +def cross_entropy_forward(_input, target, ignore_index): + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + + n_non_ignore = (target != ignore_index).sum().item() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + loss_ptr=loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + BLOCK_SIZE=BLOCK_SIZE, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32, + ) + + loss = torch.sum(loss_1d) / n_non_ignore + return loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + else: + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + return _input + + class LigerCrossEntropyFunction(torch.autograd.Function): """ This class implements a custom autograd function for the Liger Cross Entropy loss. @@ -167,41 +229,7 @@ def forward(ctx, _input, target, ignore_index): Returns: tensor: The computed loss. 
""" - BT, V = _input.shape - n_rows = BT - - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - - # unreduced loss - loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) - - n_non_ignore = (target != ignore_index).sum().item() - - # ensure _input and target are contiguous in the last dimension - if _input.stride(-1) != 1: - _input = _input.contiguous() - if target.stride(-1) != 1: - target = target.contiguous() - - # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory - liger_cross_entropy_kernel[(n_rows,)]( - X_ptr=_input, - X_stride=_input.stride(-2), - Y_ptr=target, - Y_stride=target.stride(-1), # always 1 - loss_ptr=loss_1d, - loss_stride=loss_1d.stride(-1), # always 1 - n_cols=V, - n_non_ignore=n_non_ignore, - ignore_index=ignore_index, - BLOCK_SIZE=BLOCK_SIZE, - # TODO: 32 seems to give the best performance - # Performance is quite sensitive to num_warps - num_warps=32, - ) - - loss = torch.sum(loss_1d) / n_non_ignore - + loss, _input = cross_entropy_forward(_input, target, ignore_index) # TODO: investigation # If we don't detach the _input tensor, the memory will double # Not sure why but seems that there will be a time both grad and value exist but in different location @@ -221,26 +249,7 @@ def backward(ctx, grad_output): tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ (_input,) = ctx.saved_tensors - # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): - pass - - # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place - # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. - else: - BT, V = _input.shape - n_rows = BT - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - - element_mul[(n_rows,)]( - _input, - _input.stride(-2), - grad_output, - V, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - + _input = cross_entropy_backward(_input, grad_output) return ( _input, None, diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index d32d180a8..7b62dbbb1 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -1,7 +1,10 @@ import torch import triton -from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kernel +from liger_kernel.ops.cross_entropy import ( + element_mul_kernel, + liger_cross_entropy_kernel, +) # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling @@ -9,6 +12,163 @@ MAX_FUSED_SIZE = 65536 // 2 +def fused_linear_cross_entropy_forward( + _input, weight, target, bias=None, ignore_index=-100 +): + dtype = ( + torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype + ) + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2( + triton.cdiv(BT, inc_factor) + ) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + + total_n_non_ignore = (target != ignore_index).sum().item() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + n_non_ignore = (target_chunk != ignore_index).sum().item() + + # when doing CE, use the upcasted precision + logits_chunk = logits_chunk.float() + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + # Here we calculate the gradient of logits_chunk in place so we can save memory. + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + # gradient of logits_chunk is computed in-place by the above triton kernel. + # Following HuggingFace model source code, we do the forward and backward + # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge. + # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194) + # Propagating to lm_head's backward, we'll switch back to the original dtype. + logits_chunk = logits_chunk.to(dtype) + + # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V + # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H + # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only + # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. + # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients. 
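+        # for ex: if this chunk holds 1500 non-ignored tokens out of 6000 in the whole batch, the kernel has
+        # already divided the chunk's gradients by 1500, so multiplying by 1500/6000 leaves a net division
+        # by 6000, i.e. the mean over all non-ignored tokens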
+ grad_logits_chunk = logits_chunk * ( + n_non_ignore / total_n_non_ignore + ) # chunk_size x V + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t(), + mat2=_input_chunk, + out=grad_weight, + alpha=n_non_ignore / total_n_non_ignore, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=n_non_ignore / total_n_non_ignore, + ) + + loss = torch.sum(loss_1d) / total_n_non_ignore + return loss, grad_input, grad_weight, grad_bias + + +def fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias +): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + # handle grad_weight + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_bias, + grad_bias.stride(-1), + grad_output, + 1, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + return grad_input, grad_weight, grad_bias + + class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): @staticmethod def forward(ctx, _input, weight, target, bias=None, ignore_index=-100): @@ -27,112 +187,9 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100): bias: (V) where V is the number of classes ignore_index: the index to ignore in the target """ - dtype = ( - torch.get_autocast_gpu_dtype() - if torch.is_autocast_enabled() - else _input.dtype + loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + _input, weight, target, bias, ignore_index ) - device = _input.device - - # inputs have shape: BT x H - # materialized activations will have shape: BT x V - # the increase in memory = BT x V - # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
- # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: - # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor - # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 - BT, H = _input.shape - V = weight.shape[0] - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - - inc_factor = triton.cdiv(V, H) # (V + H - 1) // H - chunk_size = triton.next_power_of_2( - triton.cdiv(BT, inc_factor) - ) # (BT + inc_factor - 1) // inc_factor - num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size - - grad_weight = torch.zeros_like(weight, device=device) - grad_input = torch.zeros_like(_input, device=device) - grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None - # we use fp32 for loss accumulator - loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) - - total_n_non_ignore = (target != ignore_index).sum().item() - - for chunk_id in range(num_chunks): - start_idx = chunk_id * chunk_size - end_idx = min((chunk_id + 1) * chunk_size, BT) - _input_chunk = _input[start_idx:end_idx] # chunk_size x H - - # when doing matmul, use the original precision - logits_chunk = _input_chunk @ weight.t() # chunk_size x V - if bias is not None: - logits_chunk = logits_chunk + bias - target_chunk = target[start_idx:end_idx] # chunk_size, - - n_rows = logits_chunk.shape[0] - - # unreduced loss - loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, - n_non_ignore = (target_chunk != ignore_index).sum().item() - - # when doing CE, use the upcasted precision - logits_chunk = logits_chunk.float() - - # ensure _input and target are contiguous - logits_chunk = logits_chunk.contiguous() - target_chunk = target_chunk.contiguous() - - # Here we calculate the gradient of logits_chunk in place so we can save memory. - liger_cross_entropy_kernel[(n_rows,)]( - X_ptr=logits_chunk, - X_stride=logits_chunk.stride(-2), - Y_ptr=target_chunk, - Y_stride=target_chunk.stride(-1), # always 1 - loss_ptr=loss_1d_slice, - loss_stride=loss_1d_slice.stride(-1), # always 1 - n_cols=V, - n_non_ignore=n_non_ignore, - ignore_index=ignore_index, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - - # gradient of logits_chunk is computed in-place by the above triton kernel. - # Following HuggingFace model source code, we do the forward and backward - # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge. - # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194) - # Propagating to lm_head's backward, we'll switch back to the original dtype. - logits_chunk = logits_chunk.to(dtype) - - # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V - # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H - # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only - # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. - # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients. 
- grad_logits_chunk = logits_chunk * ( - n_non_ignore / total_n_non_ignore - ) # chunk_size x V - grad_input[start_idx:end_idx] = grad_logits_chunk @ weight - torch.addmm( - input=grad_weight, - mat1=logits_chunk.t(), - mat2=_input_chunk, - out=grad_weight, - alpha=n_non_ignore / total_n_non_ignore, - beta=1.0, - ) - - if bias is not None: - torch.add( - input=grad_bias, - other=logits_chunk.sum(dim=0), - out=grad_bias, - alpha=n_non_ignore / total_n_non_ignore, - ) - - loss = torch.sum(loss_1d) / total_n_non_ignore - # downcast to dtype and store for backward ctx.save_for_backward( grad_input.detach(), @@ -144,47 +201,7 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100): @staticmethod def backward(ctx, grad_output): (grad_input, grad_weight, grad_bias) = ctx.saved_tensors - # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): - # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place - # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. - BT, H = grad_input.shape - n_rows = BT - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) - - element_mul[(n_rows,)]( - grad_input, - grad_input.stride(-2), - grad_output, - H, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - - # handle grad_weight - V, H = grad_weight.shape - n_rows = V - - element_mul[(n_rows,)]( - grad_weight, - grad_weight.stride(-2), - grad_output, - H, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - - if grad_bias is not None: - V = grad_bias.shape[0] - n_rows = V - - element_mul[(n_rows,)]( - grad_bias, - grad_bias.stride(-1), - grad_output, - 1, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) return (grad_input, grad_weight, None, grad_bias, None) diff --git a/src/liger_kernel/ops/geglu.py b/src/liger_kernel/ops/geglu.py index 09c761dac..de1a850c4 100644 --- a/src/liger_kernel/ops/geglu.py +++ b/src/liger_kernel/ops/geglu.py @@ -92,54 +92,61 @@ def _geglu_tanh_backward_kernel( tl.store(b + col_offsets, db_row, mask=mask) +def geglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.zeros_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _geglu_tanh_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def geglu_backward(a, b, dc): + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _geglu_tanh_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return a.view(*ori_shape), b.view(*ori_shape) + + class LigerGELUMulFunction(torch.autograd.Function): @staticmethod @ensure_contiguous def forward(ctx, a, b): - ori_shape = a.shape - - n_cols = ori_shape[-1] - a = a.view(-1, n_cols) - b = b.view(-1, n_cols) - c = torch.zeros_like(a) - n_rows = a.shape[0] - - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - - _geglu_tanh_forward_kernel[(n_rows,)]( - a, - b, - c, - c.stride(-2), - n_cols=n_cols, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - ) - + a, b, c = geglu_forward(a, b) 
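+        # a and b come back as the flattened (-1, n_cols) views, so backward can reuse them directly;
+        # geglu_backward later overwrites these buffers in place with da and db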
ctx.save_for_backward(a, b) - - return c.view(*ori_shape) + return c @staticmethod @ensure_contiguous def backward(ctx, dc): - - ori_shape = dc.shape - n_cols = ori_shape[-1] - dc = dc.view(-1, n_cols) a, b = ctx.saved_tensors - n_rows = dc.shape[0] - - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - - _geglu_tanh_backward_kernel[(n_rows,)]( - dc, - a, - b, - dc.stride(-2), - n_cols=n_cols, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - ) - - return a.view(*ori_shape), b.view(*ori_shape) + a, b = geglu_backward(a, b, dc) + return a, b diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index ff2e664c4..c729f6944 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -23,7 +23,7 @@ @triton.jit -def _layer_norm_forward( +def _layer_norm_forward_kernel( Y_ptr, # pointer to output, shape (n_rows, n_cols) Y_row_stride, # stride of each row in output X_ptr, # pointer to input, shape (n_rows, n_cols) @@ -67,7 +67,7 @@ def _layer_norm_forward( @triton.jit -def _layer_norm_backward( +def _layer_norm_backward_kernel( X_ptr, # pointer to input, shape (n_rows, n_cols) W_ptr, # pointer to weights, shape (n_cols,) Mean_ptr, # pointer to mean, shape (n_rows,) @@ -130,94 +130,99 @@ def _layer_norm_backward( tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask) +def layer_norm_forward(X, W, B, eps): + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) + RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) + + assert ( + X.shape[1] == W.shape[0] + ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}" + + _layer_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + B, + B.stride(0), + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps + + +def layer_norm_backward(dY, X, W, B, Mean, RSTD): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + if n_cols > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 + _layer_norm_backward_kernel[grid]( + X, + W, + Mean, + RSTD, + DX, + _DW, + _DB, + dY, + X.stride(0), + DX.stride(0), + _DW.stride(0), + _DB.stride(0), + dY.stride(0), + n_rows, + n_cols, + rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + dtype=triton_dtype, + ) + + DW = _DW.sum(dim=0).to(W.dtype) + DB = _DB.sum(dim=0).to(W.dtype) + + DX = DX.view(*shape) + return DX, DW, DB + + class LigerLayerNormFunction(torch.autograd.Function): @staticmethod @ensure_contiguous def forward(ctx, X, W, B, eps): - shape = X.shape - dim = shape[-1] - X = X.view(-1, dim) - n_rows, 
n_cols = X.shape - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - - Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) - RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) - - assert ( - X.shape[1] == W.shape[0] - ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}" - - _layer_norm_forward[(n_rows,)]( - Y, - Y.stride(0), - X, - X.stride(0), - W, - W.stride(0), - B, - B.stride(0), - Mean, - Mean.stride(0), - RSTD, - RSTD.stride(0), - n_cols, - eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - ) - ctx.eps = eps - ctx.BLOCK_SIZE = BLOCK_SIZE - ctx.num_warps = num_warps - + Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps) ctx.save_for_backward(X, W, B, Mean, RSTD) - return Y.view(*shape) + return Y @staticmethod @ensure_contiguous def backward(ctx, dY): - shape = dY.shape - dim = shape[-1] - dY = dY.view(-1, dim) X, W, B, Mean, RSTD = ctx.saved_tensors - n_rows, n_cols = dY.shape - - DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count - _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) - _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) - - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - if n_cols > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - - rows_per_program = math.ceil(n_rows / sm_count) - grid = (sm_count,) - triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 - _layer_norm_backward[grid]( - X, - W, - Mean, - RSTD, - DX, - _DW, - _DB, - dY, - X.stride(0), - DX.stride(0), - _DW.stride(0), - _DB.stride(0), - dY.stride(0), - n_rows, - n_cols, - rows_per_program, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - dtype=triton_dtype, - ) - - DW = _DW.sum(dim=0).to(W.dtype) - DB = _DB.sum(dim=0).to(W.dtype) - - DX = DX.view(*shape) + DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD) return DX, DW, DB, None diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index 38e4ae573..9a04611eb 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -27,7 +27,7 @@ @triton.jit -def _rms_norm_forward( +def _rms_norm_forward_kernel( Y_ptr, Y_row_stride, X_ptr, @@ -92,7 +92,7 @@ def _rms_norm_forward( @triton.jit -def _rms_norm_backward( +def _rms_norm_backward_kernel( dY_ptr, dY_row_stride, X_ptr, @@ -181,6 +181,91 @@ def _rms_norm_backward( } +def rms_norm_forward(X, W, eps, offset, casting_mode): + if not isinstance(casting_mode, int): + assert ( + casting_mode in _str_to_casting_mode + ), f"Invalid casting mode: {casting_mode}" + casting_mode = _str_to_casting_mode[casting_mode] + else: + assert ( + casting_mode in _str_to_casting_mode.values() + ), f"Invalid casting mode: {casting_mode}" + + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + # r is to cache (1/rms) for each row + # r is always computed/stored in fp32 if we are using Llama or Gemma casting mode + r_dtype = ( + torch.float32 + if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) + else X.dtype + ) + r = torch.empty(n_rows, dtype=r_dtype, device=X.device) + + # Check constraints. 
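+    # X has been flattened to (n_rows, n_cols), so its last dimension must line up with the RMSNorm weight W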
+ assert ( + X.shape[1] == W.shape[0] + ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" + + _rms_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + r, + r.stride(0), + n_cols, + eps, + offset, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return Y.view(*shape), X, r, BLOCK_SIZE, num_warps, casting_mode + + +def rms_norm_backward(dY, X, W, r, eps, offset, casting_mode, BLOCK_SIZE, num_warps): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + dW = torch.empty_like( + X, + dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype), + ) + + # Here we use dY to store the value of dX to save memory + _rms_norm_backward_kernel[(n_rows,)]( + dY, + dY.stride(0), + X, + X.stride(0), + W, + W.stride(0), + r, + r.stride(0), + dW, + dW.stride(0), + n_cols, + eps, + offset, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + dX = dY.view(*shape) + dW = torch.sum(dW, dim=0).to(W.dtype) + return dX, dW + + class LigerRMSNormFunction(torch.autograd.Function): """ Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the @@ -206,61 +291,16 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"): X: (B, T, H) or (BxT, H) W: (H,) """ - if not isinstance(casting_mode, int): - assert ( - casting_mode in _str_to_casting_mode - ), f"Invalid casting mode: {casting_mode}" - casting_mode = _str_to_casting_mode[casting_mode] - else: - assert ( - casting_mode in _str_to_casting_mode.values() - ), f"Invalid casting mode: {casting_mode}" - - shape = X.shape - dim = shape[-1] - X = X.view(-1, dim) - n_rows, n_cols = X.shape - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - - Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - # r is to cache (1/rms) for each row - # r is always computed/stored in fp32 if we are using Llama or Gemma casting mode - r_dtype = ( - torch.float32 - if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) - else X.dtype - ) - r = torch.empty(n_rows, dtype=r_dtype, device=X.device) - - # Check constraints. 
- assert ( - X.shape[1] == W.shape[0] - ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" - - _rms_norm_forward[(n_rows,)]( - Y, - Y.stride(0), - X, - X.stride(0), - W, - W.stride(0), - r, - r.stride(0), - n_cols, - eps, - offset, - casting_mode, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, + Y, X, r, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward( + X, W, eps, offset, casting_mode ) ctx.eps = eps ctx.offset = offset ctx.casting_mode = casting_mode ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps - ctx.save_for_backward(X, W, r) - return Y.view(*shape) + return Y @staticmethod @ensure_contiguous @@ -268,40 +308,16 @@ def backward(ctx, dY): """ Y: (B, T, H) or (BxT, H) """ - - shape = dY.shape - dim = shape[-1] - dY = dY.view(-1, dim) X, W, r = ctx.saved_tensors - n_rows, n_cols = dY.shape - dW = torch.empty_like( - X, - dtype=( - torch.float32 - if ctx.casting_mode == _CASTING_MODE_GEMMA.value - else W.dtype - ), - ) - - # Here we use dY to store the value of dX to save memory - _rms_norm_backward[(n_rows,)]( + dX, dW = rms_norm_backward( dY, - dY.stride(0), X, - X.stride(0), W, - W.stride(0), r, - r.stride(0), - dW, - dW.stride(0), - n_cols, ctx.eps, ctx.offset, ctx.casting_mode, - BLOCK_SIZE=ctx.BLOCK_SIZE, - num_warps=ctx.num_warps, + ctx.BLOCK_SIZE, + ctx.num_warps, ) - dX = dY.view(*shape) - dW = torch.sum(dW, dim=0).to(W.dtype) return dX, dW, None, None, None diff --git a/src/liger_kernel/ops/rope.py b/src/liger_kernel/ops/rope.py index be718a625..f317d88c7 100644 --- a/src/liger_kernel/ops/rope.py +++ b/src/liger_kernel/ops/rope.py @@ -117,6 +117,92 @@ def _triton_rope( tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) +def repo_foward(q, k, cos, sin): + + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_rope[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def rope_backward(dq, dk, cos, sin): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_rope[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + class LigerRopeFunction(torch.autograd.Function): """ Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that @@ -138,50 +224,9 @@ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos size: (1, seq_len, head_dim) sin size: (1, seq_len, head_dim) """ - - # transpose it back to the physical shape because Triton looks at the physical storage - # note: q and k are incontiguous before the transformation and will become contiguous after transpose - q = q.transpose(1, 2) - k = k.transpose(1, 2) - - batch_size, seq_len, n_q_head, head_dim = q.shape - n_kv_head = k.shape[2] - pad_hd = triton.next_power_of_2(head_dim) - pad_n_q_head = triton.next_power_of_2(n_q_head) - pad_n_kv_head = triton.next_power_of_2(n_kv_head) - BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) - - n_row = batch_size * seq_len - - # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous - q = q.contiguous() - k = k.contiguous() - cos = cos.contiguous() - sin = sin.contiguous() - - _triton_rope[(n_row,)]( - q, - q.stride(1), - k, - k.stride(1), - cos, - cos.stride(-2), - sin, - sin.stride(-2), - seq_len, - batch_size, - n_q_head, - n_kv_head, - head_dim, - pad_n_q_head, - pad_n_kv_head, - pad_hd, - BLOCK_SIZE=BLOCK_SIZE, - BACKWARD_PASS=False, - ) - + q, k, cos, sin = repo_foward(q, k, cos, sin) ctx.save_for_backward(cos, sin) - return q.transpose(1, 2), k.transpose(1, 2) + return q, k def backward(ctx, dq, dk): """ @@ -192,43 +237,5 @@ def backward(ctx, dq, dk): """ cos, sin = ctx.saved_tensors - - dq = dq.transpose(1, 2) - dk = dk.transpose(1, 2) - - batch_size, seq_len, n_q_head, head_dim = dq.shape - n_kv_head = dk.shape[2] - pad_hd = triton.next_power_of_2(head_dim) - pad_n_q_head = triton.next_power_of_2(n_q_head) - pad_n_kv_head = triton.next_power_of_2(n_kv_head) - BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) - - n_row = batch_size * seq_len - - # ensure dq and dk are contiguous - dq = dq.contiguous() - dk = dk.contiguous() - - # backward is similar to forward except swapping few ops - _triton_rope[(n_row,)]( - dq, - dq.stride(1), - dk, - dk.stride(1), - cos, - cos.stride(-2), - sin, - sin.stride(-2), - seq_len, - batch_size, - n_q_head, - n_kv_head, - head_dim, - pad_n_q_head, - pad_n_kv_head, - pad_hd, - BLOCK_SIZE=BLOCK_SIZE, - BACKWARD_PASS=True, - ) - - return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None + dq, dk = rope_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None diff --git a/src/liger_kernel/ops/swiglu.py b/src/liger_kernel/ops/swiglu.py index d83625be4..1fa031910 100644 --- a/src/liger_kernel/ops/swiglu.py +++ b/src/liger_kernel/ops/swiglu.py @@ -60,54 +60,61 @@ def _swiglu_backward_kernel( tl.store(b_ptr + col_offsets, db_row, mask=mask) +def swiglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.zeros_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def swiglu_backward(a, b, dc): + + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a.view(*ori_shape), b.view(*ori_shape) + + class LigerSiLUMulFunction(torch.autograd.Function): @staticmethod @ensure_contiguous def forward(ctx, a, b): - ori_shape = a.shape - - n_cols = ori_shape[-1] - a = a.view(-1, n_cols) - b = b.view(-1, n_cols) - c = torch.zeros_like(a) - n_rows = a.shape[0] - - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - - _swiglu_forward_kernel[(n_rows,)]( - a, - b, - c, - c.stride(-2), - n_cols=n_cols, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - ) - + a, b, c = swiglu_forward(a, b) ctx.save_for_backward(a, b) - - return c.view(*ori_shape) + return c @staticmethod @ensure_contiguous def backward(ctx, dc): - - ori_shape = dc.shape - n_cols = ori_shape[-1] - dc = dc.view(-1, n_cols) a, b = ctx.saved_tensors - n_rows = dc.shape[0] - - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - - _swiglu_backward_kernel[(n_rows,)]( - dc, - a, - b, - dc.stride(-2), - 
n_cols=n_cols, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - ) - - return a.view(*ori_shape), b.view(*ori_shape) + a, b = swiglu_backward(a, b, dc) + return a, b
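With the forward/backward logic factored out of the autograd.Function wrappers, the kernels can be exercised directly as plain functions. Below is a minimal sanity-check sketch for the cross-entropy path; it assumes a CUDA device with Triton available and that the liger_kernel package from this branch is importable, and the shapes and tolerances are illustrative rather than taken from the repo's test suite.

import torch
import torch.nn.functional as F

from liger_kernel.ops.cross_entropy import cross_entropy_forward

torch.manual_seed(0)
device = "cuda"  # the Triton kernels require a GPU

BT, V, ignore_index = 1024, 32000, -100
_input = torch.randn(BT, V, device=device)
target = torch.randint(0, V, (BT,), device=device)
target[::7] = ignore_index  # sprinkle in some ignored positions

# cross_entropy_forward overwrites _input with d(loss)/d(input) in place, so keep a reference copy
ref_input = _input.clone().requires_grad_(True)

loss, grad_in = cross_entropy_forward(_input, target, ignore_index)

ref_loss = F.cross_entropy(ref_input, target, ignore_index=ignore_index)  # mean over non-ignored tokens
ref_loss.backward()

assert torch.allclose(loss, ref_loss, atol=1e-5, rtol=1e-4)
assert torch.allclose(grad_in, ref_input.grad, atol=1e-6, rtol=1e-4)
print(f"loss {loss.item():.6f} matches torch reference {ref_loss.item():.6f}")

The same pattern applies to cross_entropy_backward, which simply rescales the stored gradient by grad_output via element_mul_kernel unless grad_output is exactly 1.0.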
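The activation refactors are symmetric, so an analogous sketch can drive geglu_forward and swiglu_forward without touching autograd. The references below assume the kernels implement the usual tanh-approximated GELU and SiLU gating of the first operand by the second, which is what the kernel names and call sites suggest; same environment assumptions as above.

import torch
import torch.nn.functional as F

from liger_kernel.ops.geglu import geglu_forward
from liger_kernel.ops.swiglu import swiglu_forward

device = "cuda"
gate = torch.randn(2, 128, 4096, device=device)
up = torch.randn(2, 128, 4096, device=device)

# GEGLU: tanh-approximated gelu(gate) * up, computed row-wise by the Triton kernel
_, _, geglu_out = geglu_forward(gate, up)
assert torch.allclose(geglu_out, F.gelu(gate, approximate="tanh") * up, atol=1e-5)

# SwiGLU: silu(gate) * up
_, _, swiglu_out = swiglu_forward(gate, up)
assert torch.allclose(swiglu_out, F.silu(gate) * up, atol=1e-5)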