Extract forward/backward core computation bits outside of torch autograd context for easy reuse (#178)

## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
Extract the forward/backward core computation bits outside of the torch autograd
context for easy reuse. This is beneficial for the Lightning Thunder
integration and for reusing the kernels in other contexts.

Double-checked speed and memory usage; both are within the variance range, with
no degradation.
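
For illustration, here is a minimal sketch of the direct reuse this enables, assuming the package is importable as `liger_kernel`, a CUDA device is available, and using arbitrary example shapes:

```python
# Illustrative sketch only (not part of this diff): call the extracted
# helpers directly, outside of torch.autograd.Function.
import torch

from liger_kernel.ops.cross_entropy import (
    cross_entropy_backward,
    cross_entropy_forward,
)

BT, V = 8, 32000  # (batch * seq_len, vocab size); arbitrary example shapes
logits = torch.randn(BT, V, device="cuda")
target = torch.randint(0, V, (BT,), device="cuda")

# Forward pass: returns the reduced loss; `logits` is overwritten in place
# with the (unscaled) gradient, per the memory-saving trick in the kernel.
loss, grad_buffer = cross_entropy_forward(logits, target, ignore_index=-100)

# Backward pass: scales the stored gradient by an upstream grad_output
# (the scaling is skipped inside the helper when grad_output == 1.0).
grad_output = torch.tensor(2.0, device="cuda")
grad_input = cross_entropy_backward(grad_buffer, grad_output)
print(loss.item(), grad_input.shape)
```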

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
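
As a complement to the checklist, the speed/memory double-check mentioned in the summary can be approximated with a quick script along these lines (an illustrative sketch only, with hypothetical shapes; the `make` targets above remain the authoritative checks):

```python
# Rough single-GPU timing and peak-memory check for the fused cross entropy
# (illustrative sketch; results vary by hardware and shapes).
import time

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

BT, V = 4096, 32000  # arbitrary example shapes
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()

loss = LigerCrossEntropyFunction.apply(logits, target, -100)  # ignore_index=-100
loss.backward()

torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1e3
peak_mib = torch.cuda.max_memory_allocated() / 2**20
print(f"fwd+bwd: {elapsed_ms:.2f} ms, peak memory: {peak_mib:.1f} MiB")
```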

---------

Co-authored-by: Shao Tang <[email protected]>
qingquansong and lancerts authored Aug 31, 2024
1 parent ff24de8 commit 51060b0
Showing 7 changed files with 597 additions and 529 deletions.
121 changes: 65 additions & 56 deletions src/liger_kernel/ops/cross_entropy.py
@@ -112,7 +112,7 @@ def liger_cross_entropy_kernel(


@triton.jit
def element_mul(
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
@@ -147,6 +147,68 @@ def element_mul(
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


def cross_entropy_forward(_input, target, ignore_index):
BT, V = _input.shape
n_rows = BT

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)

n_non_ignore = (target != ignore_index).sum().item()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
_input = _input.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()

# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
loss_ptr=loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32,
)

loss = torch.sum(loss_1d) / n_non_ignore
return loss, _input


def cross_entropy_backward(_input, grad_output):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass

# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

element_mul_kernel[(n_rows,)](
_input,
_input.stride(-2),
grad_output,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)

return _input


class LigerCrossEntropyFunction(torch.autograd.Function):
"""
This class implements a custom autograd function for the Liger Cross Entropy loss.
@@ -167,41 +229,7 @@ def forward(ctx, _input, target, ignore_index):
Returns:
tensor: The computed loss.
"""
BT, V = _input.shape
n_rows = BT

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)

n_non_ignore = (target != ignore_index).sum().item()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
_input = _input.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()

# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
liger_cross_entropy_kernel[(n_rows,)](
X_ptr=_input,
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
loss_ptr=loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32,
)

loss = torch.sum(loss_1d) / n_non_ignore

loss, _input = cross_entropy_forward(_input, target, ignore_index)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -221,26 +249,7 @@ def backward(ctx, grad_output):
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(_input,) = ctx.saved_tensors
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
pass

# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
else:
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

element_mul[(n_rows,)](
_input,
_input.stride(-2),
grad_output,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)

_input = cross_entropy_backward(_input, grad_output)
return (
_input,
None,