From 53d5934c796d7ecdcbdf9790dd9fec89a2205149 Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Fri, 13 Sep 2024 11:53:07 -0700 Subject: [PATCH] Reduction support for CrossEntropy and Division by 0 Fix (#153) --- Makefile | 2 +- src/liger_kernel/ops/cross_entropy.py | 43 ++++++++++--- .../ops/fused_linear_cross_entropy.py | 41 ++++++++++--- .../transformers/cross_entropy.py | 7 ++- .../fused_linear_cross_entropy.py | 8 ++- test/transformers/test_cross_entropy.py | 61 ++++++++++--------- .../test_fused_linear_cross_entropy.py | 34 ++++++++--- 7 files changed, 139 insertions(+), 57 deletions(-) diff --git a/Makefile b/Makefile index 200390e75..f0120bd21 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: test checkstyle test-convergence all -all: test test-convergence checkstyle +all: checkstyle test test-convergence # Command to run pytest for correctness tests test: diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 901809d4d..66e03ae4a 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -15,6 +15,7 @@ def liger_cross_entropy_kernel( n_non_ignore, ignore_index, label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time BLOCK_SIZE: tl.constexpr, ): """ @@ -32,6 +33,7 @@ def liger_cross_entropy_kernel( n_non_ignore (int): The number of non-ignored elements in the batch. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The string for the reduction to apply BLOCK_SIZE (int): The block size for Triton operations. """ @@ -83,20 +85,33 @@ def liger_cross_entropy_kernel( d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new - # 4. [Online softmax] second pass: calculate the gradients + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) # dx_y = (softmax(x_y) - 1) / N # dx_i = softmax(x_i) / N, i != y - # N is the number of non ignored elements in the batch # For label smoothing: # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N # = dx_i - (1 - label_smoothing) / N + # + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) + # = dx_i - (1 - label_smoothing) + for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + if reduction == "mean": + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + else: + X_block = tl.exp(X_block - m) / d - eps + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in @@ -123,9 +138,16 @@ def liger_cross_entropy_kernel( smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) loss = loss * (1 - label_smoothing) + smooth_loss + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + loss = loss / n_non_ignore + # 6. 
Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` X_y = tl.load(X_ptr + y) - X_y += -(1 - label_smoothing) / (n_non_ignore) + if reduction == "mean": + X_y += -(1 - label_smoothing) / (n_non_ignore) + else: + X_y += -(1 - label_smoothing) tl.store(loss_ptr, loss) tl.store(X_ptr + y, X_y) @@ -173,7 +195,7 @@ def element_mul_kernel( tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) -def cross_entropy_forward(_input, target, ignore_index, label_smoothing): +def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): BT, V = _input.shape n_rows = BT @@ -202,13 +224,14 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing): n_non_ignore=n_non_ignore, ignore_index=ignore_index, label_smoothing=label_smoothing, + reduction=reduction, BLOCK_SIZE=BLOCK_SIZE, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps num_warps=32, ) - loss = torch.sum(loss_1d) / n_non_ignore + loss = torch.sum(loss_1d) return loss, _input @@ -243,7 +266,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0): + def forward( + ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean" + ): """ The forward pass of the Liger Cross Entropy loss. @@ -253,12 +278,13 @@ def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0): target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The reduction to apply to the output: "none" | "mean | "sum". Returns: tensor: The computed loss. """ loss, _input = cross_entropy_forward( - _input, target, ignore_index, label_smoothing + _input, target, ignore_index, label_smoothing, reduction ) # TODO: investigation # If we don't detach the _input tensor, the memory will double @@ -285,4 +311,5 @@ def backward(ctx, grad_output): None, None, None, + None, ) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 6fb519e83..27706f110 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -13,7 +13,13 @@ def fused_linear_cross_entropy_forward( - _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0 + _input, + weight, + target, + bias=None, + ignore_index=-100, + label_smoothing=0.0, + reduction="mean", ): dtype = ( torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype @@ -84,6 +90,7 @@ def fused_linear_cross_entropy_forward( n_non_ignore=n_non_ignore, ignore_index=ignore_index, label_smoothing=label_smoothing, + reduction=reduction, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) @@ -100,9 +107,15 @@ def fused_linear_cross_entropy_forward( # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients. 
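# For illustration only (the numbers here are hypothetical): with reduction="mean", a chunk
# holding 3 non-ignored tokens out of 12 in the whole batch gets alpha = 3/12 = 0.25, so the
# chunk gradients the kernel normalized by 3 end up effectively divided by 12, the batch-wide
# count. With reduction="sum" no rescaling is needed, so alpha = 1.0. If every target equals
# ignore_index, total_n_non_ignore is 0 and alpha falls back to 0.0, which is the
# division-by-zero fix referenced in the patch title.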
- grad_logits_chunk = logits_chunk * ( - n_non_ignore / total_n_non_ignore - ) # chunk_size x V + + if reduction == "mean": + alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0 + else: + alpha = 1.0 + + loss_1d[start_idx:end_idx] = loss_1d_slice * alpha + grad_logits_chunk = logits_chunk * alpha # chunk_size x V + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight if grad_weight is not None: @@ -111,7 +124,7 @@ def fused_linear_cross_entropy_forward( mat1=logits_chunk.t(), mat2=_input_chunk, out=grad_weight, - alpha=n_non_ignore / total_n_non_ignore, + alpha=alpha, beta=1.0, ) @@ -120,10 +133,10 @@ def fused_linear_cross_entropy_forward( input=grad_bias, other=logits_chunk.sum(dim=0), out=grad_bias, - alpha=n_non_ignore / total_n_non_ignore, + alpha=alpha, ) - loss = torch.sum(loss_1d) / total_n_non_ignore + loss = torch.sum(loss_1d) return loss, grad_input, grad_weight, grad_bias @@ -179,7 +192,14 @@ def fused_linear_cross_entropy_backward( class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): @staticmethod def forward( - ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0 + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + label_smoothing=0.0, + reduction="mean", ): """ Fusing the last linear layer with cross-entropy loss @@ -196,9 +216,10 @@ def forward( bias: (V) where V is the number of classes ignore_index: the index to ignore in the target label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction: reduction to apply """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, weight, target, bias, ignore_index, label_smoothing + _input, weight, target, bias, ignore_index, label_smoothing, reduction ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -214,4 +235,4 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None) + return (grad_input, grad_weight, None, grad_bias, None, None, None) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 0adb1cc87..b2457481b 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -9,8 +9,13 @@ def __init__(self, *args, **kwargs): assert (self.label_smoothing >= 0) and ( self.label_smoothing <= 1 ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" + assert self.reduction in { + "mean", + "sum", + "none", + }, f"reduction must be one of 'mean', 'sum', or 'none'. 
Got: {self.reduction}" def forward(self, _input, target): return LigerCrossEntropyFunction.apply( - _input, target, self.ignore_index, self.label_smoothing + _input, target, self.ignore_index, self.label_smoothing, self.reduction ) diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 99da00d43..0e3331565 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -11,5 +11,11 @@ def __init__(self, *args, **kwargs): def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCrossEntropyFunction.apply( - _input, lin_weight, target, bias, self.ignore_index, self.label_smoothing + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.label_smoothing, + self.reduction, ) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 40da8479c..1a970573e 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -1,4 +1,4 @@ -from test.utils import supports_bfloat16 +from test.utils import set_seed, supports_bfloat16 import pytest import torch @@ -8,12 +8,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy -SLEEP_SECONDS = 0.1 +set_seed(42) -def _test_correctness_once(target_ce, B, T, V, scalar, dtype, atol, rtol): - torch.manual_seed(0) - torch_ce = CrossEntropyLoss() +def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol): + + torch_ce = CrossEntropyLoss(reduction=reduction) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) @@ -31,10 +31,10 @@ def _test_correctness_once(target_ce, B, T, V, scalar, dtype, atol, rtol): def _test_correctness_with_ignore_index_once( - target_ce, B, T, V, ignore_index, scalar, dtype, atol, rtol + target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): - torch.manual_seed(0) - torch_ce = CrossEntropyLoss(ignore_index=ignore_index) + + torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) @@ -64,7 +64,7 @@ def _test_correctness_with_ignore_index_once( def _test_correctness_with_label_smoothing_once( target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol ): - torch.manual_seed(0) + torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar @@ -86,7 +86,7 @@ def _test_correctness_with_label_smoothing_once( def _test_correctness_with_label_smoothing_with_ignore_index_once( target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol ): - torch.manual_seed(0) + torch_ce = CrossEntropyLoss( ignore_index=ignore_index, label_smoothing=label_smoothing ) @@ -117,10 +117,10 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( def _test_correctness_not_last_layer_once( - target_ce, B, T, V, scalar, dtype, atol, rtol + target_ce, B, T, V, reduction, scalar, dtype, atol, rtol ): - torch.manual_seed(0) - torch_ce = CrossEntropyLoss() + + torch_ce = CrossEntropyLoss(reduction=reduction) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) @@ -141,7 +141,6 @@ def 
_test_correctness_not_last_layer_once( def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): - torch.manual_seed(0) _input = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar @@ -178,6 +177,7 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): (3, 423, 32000), ], ) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ @@ -217,9 +217,9 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, reason="Needs 16GB+ GPU memory.", ) -def test_correctness(B, T, V, scalar, dtype, atol, rtol): - liger_ce = LigerCrossEntropyLoss() - _test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol) +def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol): + liger_ce = LigerCrossEntropyLoss(reduction=reduction) + _test_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol) @pytest.mark.parametrize( @@ -255,6 +255,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): (3, 423, 32000, -123), ], ) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ @@ -295,11 +296,11 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): reason="Needs 16GB+ GPU memory.", ) def test_correctness_with_ignore_index( - B, T, V, ignore_index, scalar, dtype, atol, rtol + B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): - liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index) + liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) _test_correctness_with_ignore_index_once( - liger_ce, B, T, V, ignore_index, scalar, dtype, atol, rtol + liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ) @@ -432,6 +433,7 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( (3, 423, 32000), ], ) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ @@ -451,9 +453,11 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, reason="Needs 16GB+ GPU memory.", ) -def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol): - liger_ce = LigerCrossEntropyLoss() - _test_correctness_not_last_layer_once(liger_ce, B, T, V, scalar, dtype, atol, rtol) +def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol): + liger_ce = LigerCrossEntropyLoss(reduction=reduction) + _test_correctness_not_last_layer_once( + liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol + ) ############################################################################# @@ -461,9 +465,9 @@ def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol): ############################################################################# -def _full_pass_once(B, T, V): - torch.manual_seed(0) - liger_ce = LigerCrossEntropyLoss() +def _full_pass_once(B, T, V, reduction): + + liger_ce = LigerCrossEntropyLoss(reduction=reduction) _input = torch.randn( B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16 @@ -485,11 +489,12 @@ def _full_pass_once(B, T, V): (8, 16384, 128256), # _input = 32GB, total = ~64GB ], ) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.skipif( torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000, reason="Needs 64GB+ GPU memory.", ) -def 
test_large_no_exception(B, T, V): +def test_large_no_exception(B, T, V, reduction): # The large inputs were hitting cuda illegal memory access because of # https://github.com/triton-lang/triton/issues/1058 - _full_pass_once(B, T, V) + _full_pass_once(B, T, V, reduction) diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index b8e9f76dd..57e2cf534 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -32,13 +32,16 @@ def __init__( bias: bool = False, ignore_index: int = -100, label_smoothing: float = 0.0, + reduction: str = "mean", ): super().__init__() self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) self.ce_loss = torch.nn.CrossEntropyLoss( - ignore_index=ignore_index, reduction="mean", label_smoothing=label_smoothing + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, ) def forward(self, x, y): @@ -55,6 +58,7 @@ def __init__( bias: bool = False, ignore_index: int = -100, label_smoothing: float = 0.0, + reduction: str = "mean", ): super().__init__() self.lin = torch.nn.Linear( @@ -62,7 +66,7 @@ def __init__( ) self.ce_loss = LigerFusedLinearCrossEntropyLoss( ignore_index=ignore_index, - reduction="mean", + reduction=reduction, label_smoothing=label_smoothing, ) @@ -87,21 +91,35 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize( - "scalar, dtype, atol, rtol", + "reduction, scalar, dtype, atol, rtol", [ - (1.0, torch.bfloat16, 5e-3, 5e-2), - (1.0, torch.float32, 1e-5, 5e-4), + ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2), + ("mean", 1.0, torch.float32, 1e-5, 5e-4), + ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), + ("sum", 1.0, torch.float32, 1e-3, 5e-2), ], ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("label_smoothing", [0, 0.1]) -def test_correctness(B, T, H, V, scalar, dtype, bias, label_smoothing, atol, rtol): +def test_correctness( + B, T, H, V, scalar, dtype, bias, label_smoothing, reduction, atol, rtol +): device = "cuda" torch_lm_head_ce = TorchLMHeadCE( - H=H, V=V, bias=bias, label_smoothing=label_smoothing, dtype=dtype + H=H, + V=V, + bias=bias, + label_smoothing=label_smoothing, + reduction=reduction, + dtype=dtype, ).to(device) liger_lm_head_ce = LigerLMHeadCE( - H=H, V=V, bias=bias, label_smoothing=label_smoothing, dtype=dtype + H=H, + V=V, + bias=bias, + label_smoothing=label_smoothing, + reduction=reduction, + dtype=dtype, ).to(device) # init the linear in all CEs with the same weights
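Note on the reduction semantics added above: with reduction="mean" the kernel divides both the per-row loss and the logit gradients by the number of non-ignored targets, while "sum" leaves them unnormalized, so the two differ only by that factor. The following is a minimal sketch using plain PyTorch (no Liger kernels, CPU is fine) that checks this relationship; shapes and seed are arbitrary placeholders:

import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
BT, V, ignore_index = 8, 16, -100
target = torch.randint(0, V, (BT,))
target[0] = ignore_index  # one ignored position

logits_mean = torch.randn(BT, V, requires_grad=True)
logits_sum = logits_mean.detach().clone().requires_grad_(True)

CrossEntropyLoss(reduction="mean", ignore_index=ignore_index)(logits_mean, target).backward()
CrossEntropyLoss(reduction="sum", ignore_index=ignore_index)(logits_sum, target).backward()

n_non_ignore = (target != ignore_index).sum()
# "sum" gradients are exactly n_non_ignore times the "mean" gradients
assert torch.allclose(logits_sum.grad, logits_mean.grad * n_non_ignore, atol=1e-6)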
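Example usage of the new reduction argument, following the call patterns in the tests above. Import paths match the files touched in this patch; a CUDA device is assumed because the losses run Triton kernels, and the sizes are placeholders rather than recommended settings:

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

B, T, H, V = 2, 128, 256, 4096
target = torch.randint(0, V, (B * T,), device="cuda")

# Plain cross entropy over pre-computed logits.
logits = torch.randn(B * T, V, device="cuda", requires_grad=True)
loss = LigerCrossEntropyLoss(reduction="sum")(logits, target)
loss.backward()

# Fused variant: pass the LM-head weight and the pre-projection hidden states,
# so the logits are never materialized for the full batch at once.
hidden = torch.randn(B * T, H, device="cuda", requires_grad=True)
lm_head = torch.nn.Linear(H, V, bias=False).to("cuda")
flce = LigerFusedLinearCrossEntropyLoss(reduction="mean")
loss = flce(lm_head.weight, hidden, target)
loss.backward()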