Set z_loss_1d=None when return_z_loss=False in cross_entropy_loss to avoid tl.store fail when triton_interpret=1 (for tl.device_print etc.) (#508)

For [issue-507](#507)

## Summary
In the cross_entropy_loss kernel, `tl.store(loss_ptr, loss)` fails when `return_z_loss=False` and `triton_interpret=1`, because `z_loss_1d` was assigned the `loss_1d` tensor as a dummy pointer, aliasing the two buffers. Setting `z_loss_1d = None` in this situation removes the aliasing, and the store works correctly.
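
For context, a minimal sketch of the kind of reproduction linked in the issue (the exact snippet is in the linked comment; the tensor shapes and the `LigerCrossEntropyLoss` setup below are illustrative, not taken from the issue):

```python
# Illustrative reproduction sketch; shapes and values are made up.
# TRITON_INTERPRET must be set before triton is imported.
import os

os.environ["TRITON_INTERPRET"] = "1"

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss()  # return_z_loss defaults to False
logits = torch.randn(4, 128, device="cuda", requires_grad=True)
target = torch.randint(0, 128, (4,), device="cuda")

# Before this fix, the interpreter path failed inside the kernel's
# tl.store(loss_ptr, loss); with z_loss_1d=None it runs cleanly.
print(loss_fn(logits, target))
```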

## Testing Done
I tested it on my own code and on [this minimal example](#507 (comment)); both work correctly.

Hardware Type: T4 GPU
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
wa008 authored Jan 4, 2025
1 parent bf48d8d commit 5c5a7b4
Showing 1 changed file with 3 additions and 2 deletions.
src/liger_kernel/ops/cross_entropy.py: 3 additions & 2 deletions
```diff
@@ -95,7 +95,8 @@ def liger_cross_entropy_kernel(
         return

     loss_ptr += program_id * loss_stride
-    z_loss_ptr += program_id * loss_stride
+    if RETURN_Z_LOSS == _TRUE:
+        z_loss_ptr += program_id * loss_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
```
```diff
@@ -296,7 +297,7 @@ def cross_entropy_forward(
     if return_z_loss == _TRUE.value:
         z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     else:
-        z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
+        z_loss_1d = None  # set None when return_z_loss == False

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
```
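
Why passing `None` is safe on the kernel side: `RETURN_Z_LOSS` is a compile-time flag (`tl.constexpr`), so `z_loss_ptr` is only advanced and stored to inside a branch that is skipped entirely when the flag is false, both under normal compilation and under `TRITON_INTERPRET=1`. A standalone sketch of that guarded-store pattern (a simplification, not the actual liger kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def guarded_store_sketch(loss_ptr, z_loss_ptr, RETURN_Z_LOSS: tl.constexpr):
    # The loss store is always valid: loss_ptr is a real buffer.
    tl.store(loss_ptr, 1.0)
    if RETURN_Z_LOSS:
        # Only reached when the constexpr flag is set, so the host
        # may pass z_loss_ptr=None when RETURN_Z_LOSS is False.
        tl.store(z_loss_ptr, 2.0)


loss = torch.zeros(1, device="cuda")
guarded_store_sketch[(1,)](loss, None, RETURN_Z_LOSS=False)  # no z-loss buffer needed
```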