Feat: add kl div to readme (#229)
## Summary
Adds the newly implemented KL divergence loss to the README. Finally closes #188.

## Testing Done
No code changes

---------

Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
3 people authored Sep 8, 2024
1 parent 1cdb7f0 commit 9250546
Showing 1 changed file with 3 additions and 2 deletions.
README.md: 5 changes (3 additions, 2 deletions)
@@ -132,7 +132,7 @@ pip install -e .
```
## Getting Started

-There are a couple ways to apply Liger kernels, depending on the level of customization required.
+There are a couple of ways to apply Liger kernels, depending on the level of customization required.

### 1. Use AutoLigerKernelForCausalLM
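
As a minimal sketch of this path (the checkpoint name below is only a placeholder), `AutoLigerKernelForCausalLM` is intended to load a model the same way as Hugging Face's `AutoModelForCausalLM` while applying Liger kernels to supported architectures:

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Placeholder checkpoint path; any causal LM architecture supported by Liger should work here.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
```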

@@ -242,6 +242,7 @@ loss.backward()
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
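
For reference, the loss classes in the table above are intended as drop-in replacements for their `torch.nn` counterparts. A hedged sketch of using the new `LigerKLDIVLoss` alongside `LigerCrossEntropyLoss` (assuming both follow the standard `torch.nn.CrossEntropyLoss` / `torch.nn.KLDivLoss` call signatures; shapes are chosen only for illustration) could look like:

```python
import torch
from liger_kernel.transformers import LigerCrossEntropyLoss, LigerKLDIVLoss

# Cross entropy: same call pattern as torch.nn.CrossEntropyLoss (logits, class indices).
ce = LigerCrossEntropyLoss()
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
targets = torch.randint(0, 32000, (8,), device="cuda")
ce_loss = ce(logits, targets)

# KL divergence: assumed to mirror torch.nn.KLDivLoss (log-probabilities vs. probabilities).
kl = LigerKLDIVLoss(reduction="batchmean")
log_q = torch.log_softmax(torch.randn(8, 32000, device="cuda"), dim=-1)
p = torch.softmax(torch.randn(8, 32000, device="cuda"), dim=-1)
kl_loss = kl(log_q, p)
```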

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -255,7 +256,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
<!-- TODO: verify vocab sizes are accurate -->
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128K vocab size. **This is highly effective for large batch sizes, large sequence lengths, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage. A plain-PyTorch sketch of the chunking idea follows this list.

+- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with the reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for 128K vocab size.
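
The FusedLinearCrossEntropy entry above fuses the LM head with the loss and chunks the input. The snippet below is only a plain-PyTorch illustration of that chunking idea (hypothetical helper name, no Triton fusion), showing why the full `(tokens, vocab)` logits tensor never needs to be materialized at once:

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, lm_head_weight, labels, chunk_size=1024):
    """Hypothetical helper: block-wise head projection + cross entropy.

    Only one (chunk_size, vocab) logits block exists at a time, which is the
    memory-saving idea behind the fused kernel (the real kernel also folds the
    gradient computation into the same pass).
    """
    n_tokens = hidden.shape[0]
    total_loss = hidden.new_zeros(())
    for start in range(0, n_tokens, chunk_size):
        h = hidden[start:start + chunk_size]   # (chunk, hidden_dim)
        logits = h @ lm_head_weight.T          # (chunk, vocab), one block at a time
        total_loss = total_loss + F.cross_entropy(
            logits, labels[start:start + chunk_size], reduction="sum"
        )
    return total_loss / n_tokens
```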

### Experimental Kernels

