Add complete test with other params
Tcc0403 committed Dec 2, 2024
1 parent e770182 commit f38e1e2
Showing 2 changed files with 155 additions and 14 deletions.
6 changes: 2 additions & 4 deletions src/liger_kernel/ops/cross_entropy.py
@@ -28,7 +28,6 @@ def liger_cross_entropy_kernel(
     Y_ptr,
     Y_stride,
     weight_ptr,
-    weight_stride,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
@@ -69,7 +68,7 @@ def liger_cross_entropy_kernel(
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
         BLOCK_SIZE (int): The block size for Triton operations.
-        HAS_WEIGHT (bool): The boolean value to dteremine whether assigning weight to each of the classes.
+        HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """

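For reference, the soft-capping that HAS_SOFTCAPPING enables is the same transform the tests below apply to their reference inputs. A minimal sketch in plain PyTorch (an illustration of the formula only, not the Triton kernel code; the helper name is hypothetical):

import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # tanh maps to (-1, 1), so the scaled result stays in (-softcap, +softcap)
    return softcap * torch.tanh(logits / softcap)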
@@ -310,7 +309,6 @@ def cross_entropy_forward(
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
         weight_ptr=weight if weight is not None else _input,  # dummy if None
-        weight_stride=weight.stride(-1) if weight is not None else 0,
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
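The "dummy if None" trick above is presumably needed because a Triton kernel argument must be a real tensor even when the HAS_WEIGHT branch never reads it, so _input stands in when no weight is given. A minimal sketch of that pattern (illustrative names, not the actual kernel body; the constexpr branch is resolved at compile time):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, weight_ptr, n, HAS_WEIGHT: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if HAS_WEIGHT:
        # dead branch is compiled out; weight_ptr is never dereferenced otherwise
        x = x * tl.load(weight_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
w = torch.rand(1024, device="cuda")
scale_kernel[(1,)](x, w, x.numel(), HAS_WEIGHT=True, BLOCK_SIZE=1024)
scale_kernel[(1,)](x, x, x.numel(), HAS_WEIGHT=False, BLOCK_SIZE=1024)  # x as dummy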
@@ -390,7 +388,7 @@ def forward(
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
-        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size C and floating point dtype
+        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
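A note on lse_square_scale above: the z-loss it scales is the squared log-sum-exp of the logits. A standalone sketch of that term (inferred from the docstring and the CrossEntropyWithZLoss reference module in the tests below; the function name is hypothetical):

import torch

def z_loss_term(logits: torch.Tensor, lse_square_scale: float) -> torch.Tensor:
    # z-loss penalizes a large log-partition term, stabilizing training
    lse = torch.logsumexp(logits.to(torch.float32), dim=-1)
    return lse_square_scale * lse**2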
163 changes: 153 additions & 10 deletions test/transformers/test_cross_entropy.py
@@ -20,6 +20,7 @@
 class CrossEntropyWithZLoss(torch.nn.Module):
     def __init__(
         self,
+        weight=None,
         lse_square_scale=0.0,
         reduction="mean",
         ignore_index=-100,
@@ -28,6 +29,7 @@ def __init__(
         dtype=torch.float32,
     ):
         super().__init__()
+        self.weight = weight
         self.lse_square_scale = lse_square_scale
         self.reduction = reduction
         self.ignore_index = ignore_index
@@ -38,10 +40,24 @@ def __init__(
     def forward(self, logits, targets):
         # Loss calculations are all in float32
         logits = logits.to(torch.float32)
+        HAS_WEIGHT = True if self.weight is not None else False
+        if HAS_WEIGHT:
+            self.weight = self.weight.to(torch.float32)
+            if self.ignore_index >= 0 and self.ignore_index < logits.shape[-1]:
+                weight_mask = torch.ones_like(self.weight)
+                weight_mask[self.ignore_index] = 0
+                selected_weight = torch.gather(
+                    self.weight * weight_mask, dim=-1, index=targets
+                )
+                del weight_mask
+            else:
+                selected_weight = torch.gather(self.weight, dim=-1, index=targets)
+            sum_of_non_ignore_weight = selected_weight.sum().item()
         # Standard cross entropy loss
         ce_loss = F.cross_entropy(
             logits,
             targets,
+            weight=self.weight,
             reduction=self.reduction,
             label_smoothing=self.label_smoothing,
             ignore_index=self.ignore_index,
@@ -54,9 +70,14 @@ def forward(self, logits, targets):
         z_loss = torch.where(
             targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0
         )
         z_loss = z_loss.to(logits.dtype)
+        if HAS_WEIGHT:
+            z_loss = z_loss * selected_weight
 
         if self.reduction == "mean":
-            z_loss = z_loss.sum() / (targets != self.ignore_index).sum()
+            if HAS_WEIGHT:
+                z_loss = z_loss.sum() / sum_of_non_ignore_weight
+            else:
+                z_loss = z_loss.sum() / (targets != self.ignore_index).sum()
         elif self.reduction == "sum":
             z_loss = z_loss.sum()
         else:
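The sum_of_non_ignore_weight normalization above mirrors how PyTorch defines the weighted "mean" reduction for cross-entropy: the loss sum is divided by the total class weight of the non-ignored targets, not by their count. A self-contained sketch of that reduction (the helper name is illustrative):

import torch
import torch.nn.functional as F

def weighted_mean_ce(logits, targets, weight, ignore_index=-100):
    # reduction="none" keeps per-token losses, already scaled by weight[target]
    # (ignored positions come back as exactly 0)
    per_token = F.cross_entropy(
        logits, targets, weight=weight, ignore_index=ignore_index, reduction="none"
    )
    mask = targets != ignore_index
    # denominator: summed class weights of the non-ignored targets
    denom = (weight[targets.clamp(min=0)] * mask).sum()
    return per_token.sum() / denom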
@@ -185,13 +206,15 @@ def _test_correctness_with_softcap_once(
 
     _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
     # upcasting to match liger's casting strategy
-    _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True)
+    _input = _tensor.detach().clone().requires_grad_(True)
     _input2 = _tensor.detach().clone().requires_grad_(True)
 
     target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
 
     # downcasting to original dtype
-    output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype)
+    output = torch_ce(
+        softcap * torch.tanh(_input.to(torch.float32) / softcap), target
+    ).to(dtype)
     output2 = target_ce(_input2, target)
 
     assert torch.allclose(output, output2, atol=atol, rtol=rtol)
@@ -322,6 +345,59 @@ def _test_correctness_with_weight_once(
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 
 
+def _test_correctness_with_weight_with_other_params_once(
+    target_ce,
+    B,
+    T,
+    V,
+    reduction,
+    weight,
+    lse_square_scale,
+    ignore_index,
+    label_smoothing,
+    softcap,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyWithZLoss(
+        weight=weight,
+        lse_square_scale=lse_square_scale,
+        ignore_index=ignore_index,
+        reduction=reduction,
+        label_smoothing=label_smoothing,
+        dtype=dtype,
+    )
+
+    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
+    # upcasting to match liger's casting strategy
+    _input = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(
+        1, B * T // 2, (1,)
+    ).item()  # Random number of elements to set to ignore_index
+    indices_to_assign = torch.randperm(B * T)[
+        :num_elements_to_assign
+    ]  # Randomly select indices
+    target[indices_to_assign] = ignore_index
+
+    output = torch_ce(
+        softcap * torch.tanh(_input.to(torch.float32) / softcap), target
+    ).to(dtype)
+    output2 = target_ce(_input2, target)
+    assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    output.backward()
+    output2.backward()
+    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
 def _test_correctness_not_last_layer_once(
     target_ce, B, T, V, reduction, scalar, dtype, atol, rtol
 ):
@@ -717,7 +793,6 @@ def test_correctness_with_z_loss_with_other_params_once(
         (3, 423, 32000),
     ],
 )
-@pytest.mark.parametrize("weight", [0.5, 0.1])
 @pytest.mark.parametrize("reduction", ["sum", "mean"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
@@ -734,16 +809,86 @@ def test_correctness_with_z_loss_with_other_params_once(
         (1.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-def test_correctness_with_weight_once(
-    B, T, V, weight, reduction, scalar, dtype, atol, rtol
-):
+def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, rtol):
     weight = torch.rand(V, device=device, dtype=dtype)
     test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction)
     _test_correctness_with_weight_once(
         test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol
     )
 
 
+@pytest.mark.parametrize(
+    "B, T, V",
+    [
+        (2, 4096, 3200),  # llama2, mistral
+        # weird shapes
+        (3, 423, 3200),
+    ],
+)
+@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize(
+    "ignore_index, lse_square_scale, label_smoothing, softcap",
+    [
+        (-100, 1e-4, 0.1, 30.0),
+        (42, 1e-5, 0.2, 40.0),
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (1.0, torch.float32, 1e-8, 1e-6),
+    ],
+)
+def test_correctness_with_weight_with_other_params_once(
+    B,
+    T,
+    V,
+    reduction,
+    lse_square_scale,
+    ignore_index,
+    label_smoothing,
+    softcap,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+):
+    weight = torch.rand(V, device=device, dtype=torch.float32)  # match softcap casting
+    test_ce = LigerCrossEntropyLoss(
+        weight=weight,
+        lse_square_scale=lse_square_scale,
+        reduction=reduction,
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+        softcap=softcap,
+    )
+    _test_correctness_with_weight_with_other_params_once(
+        test_ce,
+        B,
+        T,
+        V,
+        reduction,
+        weight,
+        lse_square_scale,
+        ignore_index,
+        label_smoothing,
+        softcap,
+        scalar,
+        dtype,
+        atol,
+        rtol,
+    )
+
+
 @pytest.mark.parametrize(
     "B, T, V",
     [
@@ -804,7 +949,6 @@ def test_float32_internal():
         Y_ptr=Y,
         Y_stride=Y.stride(-1),
         weight_ptr=X_bf16,  # dummy ptr, not used
-        weight_stride=X_bf16.stride(-2),
         z_loss_ptr=loss_bf16,  # dummy ptr, not used
         loss_ptr=loss_bf16,
         loss_stride=loss_bf16.stride(-1),
@@ -832,7 +976,6 @@ def test_float32_internal():
         Y_ptr=Y,
         Y_stride=Y.stride(-1),
         weight_ptr=X_fp32,  # dummy ptr, not used
-        weight_stride=X_fp32.stride(-2),
         loss_ptr=loss_fp32,
         z_loss_ptr=loss_fp32,  # dummy ptr, not used
         loss_stride=loss_fp32.stride(-1),
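Taken together, the new test drives the weight tensor through the same code path as ignore_index, lse_square_scale, label_smoothing, and softcap. A usage sketch of the module under test (the import path is assumed from liger-kernel's public API; shapes and hyperparameters are borrowed from the parametrization above):

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

B, T, V = 2, 4096, 3200
weight = torch.rand(V, device="cuda", dtype=torch.float32)
loss_fn = LigerCrossEntropyLoss(
    weight=weight,
    lse_square_scale=1e-4,
    reduction="mean",
    ignore_index=-100,
    label_smoothing=0.1,
    softcap=30.0,
)
logits = torch.randn(B * T, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")
loss = loss_fn(logits, target)
loss.backward()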
