Reduction support for CrossEntropy and Division by 0 Fix (#153)
shivam15s authored Sep 13, 2024
1 parent 7a5d484 commit 53d5934
Showing 7 changed files with 139 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,7 +1,7 @@
.PHONY: test checkstyle test-convergence all


all: test test-convergence checkstyle
all: checkstyle test test-convergence

# Command to run pytest for correctness tests
test:
43 changes: 35 additions & 8 deletions src/liger_kernel/ops/cross_entropy.py
@@ -15,6 +15,7 @@ def liger_cross_entropy_kernel(
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
BLOCK_SIZE: tl.constexpr,
):
"""
@@ -32,6 +33,7 @@ def liger_cross_entropy_kernel(
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the loss ("mean" or "sum").
BLOCK_SIZE (int): The block size for Triton operations.
"""

@@ -83,20 +85,33 @@ def liger_cross_entropy_kernel(
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new

# 4. [Online softmax] second pass: calculate the gradients
# 4. [Online Softmax] Second pass: compute gradients
# For 'mean' reduction, gradients are normalized by the number of non-ignored elements (N)
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# N is the number of non-ignored elements in the batch
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
#
# For 'sum' reduction, no normalization is applied:
# dx_y = softmax(x_y) - 1
# dx_i = softmax(x_i), i != y
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V), V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
# = dx_i - (1 - label_smoothing)

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
if reduction == "mean":
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
@@ -123,9 +138,16 @@ def liger_cross_entropy_kernel(
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
loss = loss * (1 - label_smoothing) + smooth_loss

# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
loss = loss / n_non_ignore

# 6. Specially handle the i == y case, where `dx_y = softmax(x_y) - (1 - label_smoothing) / N` for "mean" reduction (no division by N for "sum")
X_y = tl.load(X_ptr + y)
X_y += -(1 - label_smoothing) / (n_non_ignore)
if reduction == "mean":
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)

tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)
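
The gradient formulas in the comments above can be sanity-checked in eager PyTorch. The following is a small sketch (not part of this commit, and not the Triton kernel) that compares the analytic softmax(x) - one_hot(y) gradient, with and without the 1/N factor, against autograd under both reductions; the tensor sizes are arbitrary.

import torch
import torch.nn.functional as F

x = torch.randn(3, 5, requires_grad=True)
y = torch.tensor([0, 2, 4])

for reduction in ("mean", "sum"):
    x.grad = None
    F.cross_entropy(x, y, reduction=reduction).backward()
    p = torch.softmax(x, dim=-1).detach()
    expected = p.clone()
    expected[torch.arange(3), y] -= 1.0  # dx_y = softmax(x_y) - 1, dx_i = softmax(x_i)
    if reduction == "mean":
        expected = expected / 3  # divide by N only for 'mean'
    torch.testing.assert_close(x.grad, expected)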
@@ -173,7 +195,7 @@ def element_mul_kernel(
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
BT, V = _input.shape
n_rows = BT

@@ -202,13 +224,14 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
reduction=reduction,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32,
)

loss = torch.sum(loss_1d) / n_non_ignore
loss = torch.sum(loss_1d)
return loss, _input


@@ -243,7 +266,9 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
def forward(
ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -253,12 +278,13 @@ def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
Returns:
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing
_input, target, ignore_index, label_smoothing, reduction
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
@@ -285,4 +311,5 @@ def backward(ctx, grad_output):
None,
None,
None,
None,
)
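
For reference, the semantics that `cross_entropy_forward` now implements (per-row losses summed on the host, with the `1/n_non_ignore` normalization moved into the kernel for "mean") can be expressed in a few lines of eager PyTorch. This is an illustrative sketch under assumed shapes and ignore_index, not code from this commit:

import torch

def naive_cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
    mask = target != ignore_index
    n_non_ignore = mask.sum()
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    picked = log_probs[torch.arange(target.numel()), target.clamp(min=0)]
    loss = -(picked * mask).sum()
    if reduction == "mean":
        loss = loss / n_non_ignore  # normalized only for 'mean'
    return loss

logits = torch.randn(4, 10)
target = torch.tensor([1, 3, -100, 7])
for reduction in ("mean", "sum"):
    torch.testing.assert_close(
        naive_cross_entropy(logits, target, reduction=reduction),
        torch.nn.functional.cross_entropy(
            logits, target, ignore_index=-100, reduction=reduction
        ),
    )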
41 changes: 31 additions & 10 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -13,7 +13,13 @@


def fused_linear_cross_entropy_forward(
_input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
_input,
weight,
target,
bias=None,
ignore_index=-100,
label_smoothing=0.0,
reduction="mean",
):
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
@@ -84,6 +90,7 @@ def fused_linear_cross_entropy_forward(
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
reduction=reduction,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)
@@ -100,9 +107,15 @@ def fused_linear_cross_entropy_forward(
# additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
# on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
# Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
grad_logits_chunk = logits_chunk * (
n_non_ignore / total_n_non_ignore
) # chunk_size x V

if reduction == "mean":
alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
else:
alpha = 1.0

loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
grad_logits_chunk = logits_chunk * alpha # chunk_size x V

grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

if grad_weight is not None:
@@ -111,7 +124,7 @@ def fused_linear_cross_entropy_forward(
mat1=logits_chunk.t(),
mat2=_input_chunk,
out=grad_weight,
alpha=n_non_ignore / total_n_non_ignore,
alpha=alpha,
beta=1.0,
)

@@ -120,10 +133,10 @@ def fused_linear_cross_entropy_forward(
input=grad_bias,
other=logits_chunk.sum(dim=0),
out=grad_bias,
alpha=n_non_ignore / total_n_non_ignore,
alpha=alpha,
)

loss = torch.sum(loss_1d) / total_n_non_ignore
loss = torch.sum(loss_1d)
return loss, grad_input, grad_weight, grad_bias


@@ -179,7 +192,14 @@ def fused_linear_cross_entropy_backward(
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
ctx,
_input,
weight,
target,
bias=None,
ignore_index=-100,
label_smoothing=0.0,
reduction="mean",
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -196,9 +216,10 @@ def forward(
bias: (V) where V is the number of classes
ignore_index: the index to ignore in the target
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction: The reduction to apply to the output: "none" | "mean" | "sum".
"""
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
_input, weight, target, bias, ignore_index, label_smoothing
_input, weight, target, bias, ignore_index, label_smoothing, reduction
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -214,4 +235,4 @@ def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
)
return (grad_input, grad_weight, None, grad_bias, None, None)
return (grad_input, grad_weight, None, grad_bias, None, None, None)
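
The commit title's "Division by 0 Fix" is the `total_n_non_ignore > 0` guard above: when every token in the batch is ignored, the old `n_non_ignore / total_n_non_ignore` scaling would divide by zero. A hypothetical standalone version of the new per-chunk scaling factor (variable names mirror the diff, but this helper does not exist in the repository):

def chunk_scale(n_non_ignore: int, total_n_non_ignore: int, reduction: str) -> float:
    # For 'mean', each chunk contributes in proportion to its share of
    # non-ignored tokens; guard against an all-ignored batch (0 / 0).
    if reduction == "mean":
        return n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
    # For 'sum' (and 'none'), chunks are not rescaled.
    return 1.0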
7 changes: 6 additions & 1 deletion src/liger_kernel/transformers/cross_entropy.py
@@ -9,8 +9,13 @@ def __init__(self, *args, **kwargs):
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
assert self.reduction in {
    "mean",
    "sum",
    "none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"

def forward(self, _input, target):
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing
_input, target, self.ignore_index, self.label_smoothing, self.reduction
)
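
With the new assertion in place, `reduction` passes through the module just as with `torch.nn.CrossEntropyLoss`. A usage sketch (the Triton kernel needs a CUDA device; the shapes here are illustrative, not from this commit):

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(ignore_index=-100, reduction="sum")
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")
loss = loss_fn(logits, target)
loss.backward()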
8 changes: 7 additions & 1 deletion src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -11,5 +11,11 @@ def __init__(self, *args, **kwargs):

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCrossEntropyFunction.apply(
_input, lin_weight, target, bias, self.ignore_index, self.label_smoothing
_input,
lin_weight,
target,
bias,
self.ignore_index,
self.label_smoothing,
self.reduction,
)
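
Correspondingly, the fused module forwards `self.reduction` to the autograd function. A usage sketch (CUDA required; the hidden size and vocabulary size are made-up numbers):

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

loss_fn = LigerFusedLinearCrossEntropyLoss(reduction="mean")
hidden = torch.randn(16, 4096, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(32000, 4096, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (16,), device="cuda")
loss = loss_fn(lm_head_weight, hidden, target)  # forward(lin_weight, _input, target)
loss.backward()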