Merge branch 'linkedin:main' into skip-xpu
faaany authored Jan 22, 2025
2 parents 5c80957 + e8d3cc7 commit 6db6518
Showing 37 changed files with 1,522 additions and 385 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/docs.yml
@@ -0,0 +1,28 @@
+name: Publish documentation
+on:
+  push:
+    branches:
+      - gh-pages
+permissions:
+  contents: write
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Configure Git Credentials
+        run: |
+          git config user.name github-actions[bot]
+          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.x
+      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
+      - uses: actions/cache@v4
+        with:
+          key: mkdocs-material-${{ env.cache_id }}
+          path: .cache
+          restore-keys: |
+            mkdocs-material-
+      - run: pip install mkdocs-material
+      - run: mkdocs gh-deploy --force
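
A note on the caching step above: `date --utc '+%V'` prints the ISO week number, so the `mkdocs-material-${{ env.cache_id }}` key rotates weekly, while `restore-keys` lets a new week seed from the previous week's cache. A minimal Python sketch of the same key scheme (illustrative only, not part of the workflow):

```python
# Illustrative sketch of the cache key the workflow builds with `date --utc '+%V'`.
# %V is the ISO 8601 week number (01-53), so the key changes once per week.
from datetime import datetime, timezone

cache_id = datetime.now(timezone.utc).strftime("%V")  # e.g. "04" in late January
print(f"mkdocs-material-{cache_id}")
```
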
1 change: 1 addition & 0 deletions .gitignore
@@ -18,3 +18,4 @@ uv.lock

# Benchmark images
benchmark/visualizations
+.vscode/
19 changes: 16 additions & 3 deletions Makefile
@@ -1,4 +1,4 @@
-.PHONY: test checkstyle test-convergence all
+.PHONY: test checkstyle test-convergence all serve build clean


all: checkstyle test test-convergence
@@ -7,8 +7,7 @@ all: checkstyle test test-convergence
test:
	python -m pytest --disable-warnings test/ --ignore=test/convergence

-# Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
-# Subsequent commands still run if the previous fails, but return failure at the end
+# Command to run ruff for linting and formatting code
checkstyle:
	ruff check . --fix; ruff_check_status=$$?; \
	ruff format .; ruff_format_status=$$?; \
@@ -39,3 +38,17 @@ run-benchmarks:
		python $$script; \
	fi; \
	done
+
+# MkDocs Configuration
+MKDOCS = mkdocs
+CONFIG_FILE = mkdocs.yml
+
+# MkDocs targets
+serve:
+	$(MKDOCS) serve -f $(CONFIG_FILE)
+
+build:
+	$(MKDOCS) build -f $(CONFIG_FILE)
+
+clean:
+	rm -rf site/
10 changes: 5 additions & 5 deletions README.md
@@ -59,8 +59,8 @@

<details>
<summary>Latest News 🔥</summary>
-- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!

+- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
- [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
@@ -253,7 +253,7 @@ loss.backward()
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
-| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


@@ -307,11 +307,11 @@ loss.backward()
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)

## Sponsorship and Collaboration

- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
-- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
+- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
45 changes: 10 additions & 35 deletions benchmark/scripts/benchmark_cpo_loss.py
@@ -11,44 +11,13 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

-from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.utils import infer_device

device = infer_device()

+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


-class TorchLMHeadCPO(torch.nn.Module):
-    """Ground truth implementation of the linear fused with torch based cross entropy loss.
-    :param H: hidden size
-    :param V: vocab size
-    :param ignore_index: index to ignore
-    :param reduction: reduction method
-    """
-
-    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
-        from test.chunked_loss.test_cpo_loss import HFCPOLoss
-
-        super().__init__()
-        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
-        self.cpo_loss = HFCPOLoss().get_batch_loss_metrics
-
-    def forward(self, x, y):
-        return self.cpo_loss(x, self.lin.weight, y)
-
-
-class LigerLMHeadCPO(torch.nn.Module):
-    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
-        super().__init__()
-        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
-        self.cpo_loss = LigerFusedLinearCPOFunction.apply
-
-    def forward(self, x, y):
-        return self.cpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################
@@ -57,15 +26,18 @@ def forward(self, x, y):
def bench_memory_fused_linear_cpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
+    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
+    from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
+
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

-    torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -96,6 +68,9 @@ def full():
def bench_speed_fused_linear_cpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
+    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
+    from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
+
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
@@ -104,8 +79,8 @@ def bench_speed_fused_linear_cpo_loss(
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

-    torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
-    liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
+    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
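
The change above drops the script-local `TorchLMHeadCPO`/`LigerLMHeadCPO` classes in favor of the test-suite versions, wrapped in lambdas that keep only element `[0]` because those heads return a `(loss, aux_outputs)` tuple. A self-contained sketch of that wrapper pattern (`TinyHead` is a stand-in, not the real test head):

```python
# Stand-in for the test-suite LM heads: forward() returns (loss, aux),
# and the benchmark wraps the call in a lambda that keeps only the loss.
import torch
import torch.nn.functional as F


class TinyHead(torch.nn.Module):
    def __init__(self, H: int, V: int):
        super().__init__()
        self.lin = torch.nn.Linear(H, V, bias=False)

    def forward(self, x, target):
        logits = self.lin(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
        return loss, logits  # tuple, like the test heads


head = lambda x, target: TinyHead(H=64, V=128)(x, target)[0]  # keep the loss only

x = torch.randn(2, 8, 64, requires_grad=True)
target = torch.randint(128, (2, 8))
head(x, target).backward()  # scalar loss, as the fwd()/full() closures expect
```
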
88 changes: 28 additions & 60 deletions benchmark/scripts/benchmark_dpo_loss.py
@@ -1,71 +1,27 @@
+import os
+import sys

import torch
import triton

-from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

-from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.utils import infer_device

device = infer_device()


-class TorchDPOLoss(torch.nn.Module):
-    def __init__(
-        self,
-        H: int,
-        V: int,
-        dtype: torch.dtype,
-        beta: float = 0.1,
-        ignore_index: int = -100,
-        bias: bool = False,
-    ):
-        super().__init__()
-        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
-        self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)
-
-    def forward(self, x, target):
-        return self.dpo_loss.get_batch_loss_metrics(
-            x,
-            self.lin.weight,
-            target,
-            self.lin.bias if hasattr(self.lin, "bias") else None,
-        )
-
-
-class LigerDPOLoss(torch.nn.Module):
-    def __init__(
-        self,
-        H: int,
-        V: int,
-        dtype: torch.dtype,
-        beta: float = 0.1,
-        ignore_index: int = -100,
-        bias: bool = False,
-    ):
-        super().__init__()
-        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
-        self.beta = beta
-        self.ignore_index = ignore_index
-
-    def forward(self, x, target):
-        return LigerFusedLinearDPOFunction.apply(
-            x,
-            self.lin.weight,
-            target,
-            self.lin.bias if hasattr(self.lin, "bias") else None,
-            self.ignore_index,
-            self.beta,
-            True,
-        )
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
+    from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO
+
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
@@ -76,11 +32,16 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider

-    torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
-    liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
+    torch_dpo_loss = lambda x, ref_x, target: TorchLMHeadDPO(
+        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
+    ).to(device)(x, ref_x, target)[0]
+    liger_dpo_loss = lambda x, ref_x, target: LigerLMHeadDPO(
+        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
+    ).to(device)(x, ref_x, target)[0]

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
+    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

@@ -91,9 +52,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:

    def fwd():
        if provider == "liger":
-            return liger_dpo_loss(_input, target)
+            return liger_dpo_loss(_input, ref_input, target)
        elif provider == "huggingface":
-            return torch_dpo_loss(_input, target)
+            return torch_dpo_loss(_input, ref_input, target)

    def full():
        y = fwd()
@@ -108,6 +69,9 @@ def full():


def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
+    from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO
+
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
@@ -119,12 +83,16 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

-    torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
-    liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
+    torch_dpo_loss = lambda x, ref_x, target: TorchLMHeadDPO(
+        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
+    ).to(device)(x, ref_x, target)[0]
+    liger_dpo_loss = lambda x, ref_x, target: LigerLMHeadDPO(
+        H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
+    ).to(device)(x, ref_x, target)[0]

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)

+    ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)

@@ -135,9 +103,9 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:

    def fwd():
        if provider == "liger":
-            return liger_dpo_loss(_input, target)
+            return liger_dpo_loss(_input, ref_input, target)
        elif provider == "huggingface":
-            return torch_dpo_loss(_input, target)
+            return torch_dpo_loss(_input, ref_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
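
Why the benchmarks above now pass a `ref_input`: DPO scores the policy model's log-probabilities against those of a frozen reference model, so the test-suite heads take reference activations alongside the policy input. A minimal, generic DPO loss sketch under assumed shapes, not the Liger fused kernel:

```python
# Generic DPO loss over per-sequence log-probs of chosen/rejected responses.
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta: float = 0.1):
    # Reward margins are measured relative to the frozen reference model.
    chosen = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected = beta * (policy_rejected_logps - ref_rejected_logps)
    return -F.logsigmoid(chosen - rejected).mean()


# Toy usage: a batch of 4 preference pairs with random log-probs.
p_c, p_r, r_c, r_r = (torch.randn(4) for _ in range(4))
print(dpo_loss(p_c, p_r, r_c, r_r))
```

The reference terms are constants with respect to the policy parameters, which is why the benchmark can create `ref_input` with `requires_grad=False`.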