diff --git a/.gitignore b/.gitignore index 123eae3ad..6eb421731 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ site/ .venv/ venv/ .ipynb_checkpoints/ +.vscode/ # Misc .DS_Store diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 000000000..64f16e10d --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,30 @@ +## Benchmarking Liger Kernels + +Follow these steps to benchmark and visualize kernel performance: + +1. Create a benchmark script + - Add your script under `benchmark/scripts/` + - Name it according to the kernel (e.g., `benchmark_.py`) + +2. Run the benchmark + - Results will be saved to `benchmark/data/all_benchmark_data.csv` + + Example: Benchmarking KTO Loss + ```bash + cd benchmark + python scripts/benchmark_kto_loss.py + ``` + +3. Visualize results + - Use the visualization script with appropriate parameters + + Example: Visualizing KTO Loss benchmark results + ```bash + python benchmarks_visualizer.py \ + --kernel-name kto_loss \ + --metric-name memory \ + --kernel-operation-mode full + ``` + +4. View results + - Generated plots will be saved in `benchmark/visualizations/` \ No newline at end of file diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 4e966cab2..7b182f657 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -715,3 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,7.841599941253662,7.801983833312988,7.849664211273193,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,15.568096160888672,15.555737495422363,16.054176330566406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,31.145376205444336,30.750951766967773,31.5398006439209,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,61.49708938598633,61.49708938598633,61.49708938598633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2 +kto_loss,liger,forward,speed,ms,B,Batch Size 
(B),32,122.01449584960938,122.01449584960938,122.01449584960938,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.892335891723633,7.8687615394592285,8.03729248046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,14.16302490234375,13.813311576843262,15.860223770141602,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,25.56470489501953,25.564167022705078,25.641658782958984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,53.0928955078125,53.0928955078125,53.0928955078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2 +kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,108.76080322265625,108.76080322265625,108.76080322265625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),2,8.662687301635742,8.488287925720215,9.611334800720215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),4,18.40096092224121,17.99224281311035,18.57883644104004,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.09159851074219,31.708070755004883,32.475128173828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),16,69.30239868164062,69.30239868164062,69.30239868164062,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2 +kto_loss,liger,full,speed,ms,B,Batch Size (B),32,124.2437744140625,124.2437744140625,124.2437744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,11.449472427368164,11.407564163208008,11.773555755615234,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, 
""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,20.871471405029297,20.862951278686523,20.879276275634766,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,41.16409683227539,40.760780334472656,41.567413330078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,77.720703125,77.720703125,77.720703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2 +kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,156.25794982910156,156.25794982910156,156.25794982910156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2027.48583984375,2027.48583984375,2027.48583984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),4,2789.736328125,2789.736328125,2789.736328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),8,2801.751953125,2801.751953125,2801.751953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),16,2825.783203125,2825.783203125,2825.783203125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2 +kto_loss,liger,full,memory,MB,B,Batch Size (B),32,2873.845703125,2873.845703125,2873.845703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,3786.7373046875,3786.7373046875,3786.7373046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.25390625,5544.25390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", 
""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 +kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 diff --git a/benchmark/scripts/benchmark_kto_loss.py b/benchmark/scripts/benchmark_kto_loss.py new file mode 100644 index 000000000..fa88d0a51 --- /dev/null +++ b/benchmark/scripts/benchmark_kto_loss.py @@ -0,0 +1,314 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss +from liger_kernel.utils import infer_device + +device = infer_device() +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + use_bias: bool = False, + use_ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + from test.chunked_loss.test_kto_loss import HFKTOLoss + + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype) + self.KTO_loss = HFKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + weight=self.lin.weight, + _input=x, + target=y, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + preference_labels=preference_labels, + kl=kl, + ) + + +class LigerLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + use_bias: bool = False, + use_ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype) + self.KTO_loss = LigerFusedLinearKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ) + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + _input=x, + lin_weight=self.lin.weight, + target=y, + preference_labels=preference_labels, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + kl=kl, + ) + + +def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = 
input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_kto_loss = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=bias, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + liger_kto_loss = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=bias, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Target shape: [B, T] + target = torch.randint(V, (B, T), dtype=torch.long, device=device) + + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + # Add ignore_index tokens to simulate padding + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + # Add ref_x with the same shape as _input + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) + + def fwd(): + if provider == "liger": + return liger_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + ) + elif provider == "huggingface": + return torch_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + ) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + B = input.x + T = input.extra_benchmark_config["T"] + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + beta = input.extra_benchmark_config["beta"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + + torch_kto_loss = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + beta=beta, + ignore_index=ignore_index, + bias=bias, + ).to(device) + liger_kto_loss = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + beta=beta, + ignore_index=ignore_index, + bias=bias, + ).to(device) + + # Input shape: [B, T, H] + _input = torch.randn(B, T, H, device=device, dtype=dtype) + + # Target shape: [B, T] + target = torch.randint(V, (B, T), device=device, dtype=torch.long) + + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + # Add ignore_index tokens + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + # Add ref_x with the same shape as _input + ref_input = torch.randn(B, T, H, device=device, 
dtype=dtype) + + def fwd(): + if provider == "liger": + return liger_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + ) + elif provider == "huggingface": + return torch_kto_loss( + x=_input, + ref_x=ref_input, + y=target, + preference_labels=preference_labels, + kl=kl, + ) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[_input], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "kto_loss", + "x_name": "B", + "x_label": "Batch Size (B)", + "x_values": [2**i for i in range(1, 6)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "T": 512, + "H": 1024, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": True, + "beta": 0.1, + "ignore_index": 42, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_kto_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_kto_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/src/liger_kernel/chunked_loss/README.md b/src/liger_kernel/chunked_loss/README.md index 15ab24543..1dd7037f2 100644 --- a/src/liger_kernel/chunked_loss/README.md +++ b/src/liger_kernel/chunked_loss/README.md @@ -1,6 +1,6 @@ # Liger FlexChunkLoss: Alignment and Distillation loss -Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases. +Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases. 
### User interface diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 238bdded9..87f3887b5 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,4 +1,5 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py index 5a51d3f72..a10398400 100644 --- a/src/liger_kernel/chunked_loss/functional.py +++ b/src/liger_kernel/chunked_loss/functional.py @@ -1,5 +1,6 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction @@ -7,3 +8,4 @@ liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply +liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py b/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py new file mode 100644 index 000000000..7dc3466c9 --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py @@ -0,0 +1,246 @@ +from abc import abstractmethod +from functools import partial + +import torch + +from torch.nn import functional as F + + +class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function): + @abstractmethod + def preference_loss_fn(*args, **kwargs): + """ + To be extended by subclasses. + """ + raise NotImplementedError("Preference loss function must be implemented.") + + @staticmethod + def forward( + ctx, + _input, + weight, + target, + preference_labels, + bias=None, + loss_fn=None, + chunk_size=1, + ignore_index=-100, + compiled=True, + use_ref_model=False, + ref_input=None, + ref_weight=None, + ref_bias=None, + **loss_kwargs, + ): + """ + Base class for fused linear layer with unpaired preference loss like KTO + Expects _input to be stacked with chosen and rejected inputs on the batch dimension. + + The mental model is: + + forward() + ├── Loop over chunks + └── compute_loss() + ├── chunk_forward() # Compute logits and log probs + └── prefer_loss() # Calculate preference loss + + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs). + ignore_index (int): Index to ignore for loss computation. + beta (float): Weight for the preference loss. + compiled (bool): Whether to use torch compile for chunk accumulation. 
+ use_ref_model (bool): Whether to use a reference model for the alignment loss. + preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples. + Shape: (batch_size,). + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + loss_kwargs (dict): Other possible arguments that a loss function might need + """ + # TODO: Tune CHUNK_SIZE to fully utilize the GPU + CHUNK_SIZE = chunk_size + + # Gradients to be accumulated + grad_inputs = [] + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) if bias is not None else None + + # Loss to be accumulated + loss_acc = torch.zeros((), device=_input.device) + + compute_loss = partial( + LigerFusedLinearUnpairedPreferenceBase._compute_loss, + preference_loss_fn=loss_fn, + full_target=target, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, + **loss_kwargs, + ) + + def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk): + """ + Fused forward and backward pass for a chunk of input and target. + """ + argnums = (0, 1, 4) if bias is not None else (0, 1) + return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)( + input_chunk, + weight, + target_chunk, + preference_labels_chunk, + bias, + ref_input_chunk=ref_input_chunk, + ) + + def accumulate_chunk( + input_chunk, + target_chunk, + preference_labels_chunk=None, + ref_input_chunk=None, + ): + (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd( + input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk + ) + if bias is not None: + grad_bias.add_(chunk_grad_bias[0]) # accumulate bias gradient + + # Accumulate gradients + grad_weight.add_(chunk_grad_weight) + grad_inputs.append(chunk_grad_input) + + # Accumulate loss + loss_acc.add_(chunk_loss) + + if compiled: + fused_fwd_bwd = torch.compile(fused_fwd_bwd) + + # When not paired, use labels to separate chosen and rejected + assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss" + + chunks = max(1, _input.shape[0] // CHUNK_SIZE) + _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) + _target_chunks = torch.chunk(target, chunks=chunks, dim=0) + _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0) + + if use_ref_model: + _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) + + for ( + input_chunk, + target_chunk, + ref_input_chunk, + preference_labels_chunk, + ) in zip( + _input_chunks, + _target_chunks, + (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)), + _preference_labels_chunks, + ): + # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation + torch._dynamo.mark_dynamic(input_chunk, 1) + torch._dynamo.mark_dynamic(target_chunk, 1) + torch._dynamo.mark_dynamic(target, 1) + torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None + torch._dynamo.mark_dynamic(preference_labels_chunk, 1) + + # accumulate loss, gradients, and metrics + accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk) + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + return loss_acc + + @staticmethod + def backward(ctx, *grad_output): + grad_input, grad_weight, grad_bias = 
ctx.saved_tensors + if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)): + grad_input = grad_input * grad_output[0][0] + grad_weight = grad_weight * grad_output[0][0] + grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None + + return grad_input, grad_weight, None, None, grad_bias + + @staticmethod + def chunk_forward( + input_chunk, + weight, + target_chunk, + bias=None, + ignore_index=-100, + ): + logits_chunk = input_chunk @ weight.t() + if bias is not None: + logits_chunk = logits_chunk + bias + log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + + loss_mask_chunk = target_chunk != ignore_index + label_chunk = torch.where(loss_mask_chunk, target_chunk, 0) + + per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) + average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1) + + return average_log_prob_chunk + + @staticmethod + def _compute_loss( + input_chunk, + weight, + target_chunk, + preference_labels_chunk, + bias=None, + preference_loss_fn=None, + full_target=None, + ignore_index=-100, + use_ref_model=False, + ref_input_chunk=None, + ref_weight=None, + ref_bias=None, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. + Args: + preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). + ignore_index (int): Index to ignore for loss computation. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + loss_kwargs (dict): Additional arguments for the loss function. 
+ """ + average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, + ignore_index=ignore_index, + ) + + if use_ref_model: + with torch.no_grad(): + ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward( + ref_input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + ) + loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk + + preference_loss_chunk = preference_loss_fn( + average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs + ) + + return preference_loss_chunk diff --git a/src/liger_kernel/chunked_loss/kto_loss.py b/src/liger_kernel/chunked_loss/kto_loss.py new file mode 100644 index 000000000..596baf82f --- /dev/null +++ b/src/liger_kernel/chunked_loss/kto_loss.py @@ -0,0 +1,172 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase + + +class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase): + @staticmethod + def preference_loss_fn( + average_log_prob_chunk, + preference_labels_chunk, + full_target, + ref_average_log_prob_chunk=None, + beta=0.1, + kl=None, + ): + """ + Implements the Kahneman-Tversky Optimization (KTO) loss function. + Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization" + https://arxiv.org/abs/2402.01306 + + KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory) + from behavioral economics, which models how humans make decisions under uncertainty. + The loss function is asymmetric, treating gains and losses differently, similar to + human decision-making patterns. + + Formula: + When y is chosen: + L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y)) + When y is rejected: + L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)])) + + Where: + - σ: Sigmoid function + - β: Temperature parameter controlling the strength of the preference signal + - π(x): Policy (current model) + - π₀(x): Reference policy (reference model) + - KL(π||π₀)_y: KL divergence estimated using the rejected response y + + The loss encourages the model to: + 1. Assign higher probability to chosen responses + 2. Assign lower probability to rejected responses + 3. Maintain reasonable distance from the reference model + + Args: + chosen_logps: Log probabilities of chosen tokens (batch_size,) + rejected_logps: Log probabilities of rejected tokens (batch_size,) + full_target: Non chunked full target tensor + ref_chosen_logps: Reference log probs of chosen tokens (batch_size,) + ref_rejected_logps: Reference log probs of rejected tokens (batch_size,) + beta: Weight for the direct preference loss + kl: KL divergence between the policy model and the reference model for the chosen responses. 
Shape: (batch_size,) + Returns: + Tuple of (loss, chosen_rewards, rejected_rewards): + - loss: The KTO loss value + - chosen_rewards: Reward signals for chosen responses (detached) + - rejected_rewards: Reward signals for rejected responses (detached) + """ + logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk + multiplier_chunk = torch.where(preference_labels_chunk, 1, -1) + if kl is not None: + losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk) + else: + losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk) + + return losses.sum() / (full_target.shape[0]) + + @staticmethod + def forward( + ctx, + _input, + weight, + target, + preference_labels, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + kl=None, + ignore_index=-100, + beta=0.1, + compiled=True, + use_ref_model=True, + ): + return LigerFusedLinearUnpairedPreferenceBase.forward( + ctx=ctx, + _input=_input, + weight=weight, + target=target, + preference_labels=preference_labels, + bias=bias, + loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn, + ignore_index=ignore_index, + beta=beta, + compiled=compiled, + use_ref_model=use_ref_model, + ref_input=ref_input, + ref_weight=ref_weight, + ref_bias=ref_bias, + kl=kl, + ) + + @staticmethod + def backward(ctx, *grad_output): + grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5] + return ( + *grads, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class LigerFusedLinearKTOLoss(torch.nn.Module): + """ + Fused linear layer with Kahneman-Tversky Optimization (KTO) loss. + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + compiled: bool = True, + use_ref_model: bool = False, + ): + """ + Args: + ignore_index (int): Index to ignore in the loss calculation + beta (float): Temperature parameter for the KTO loss + compiled (bool): Whether to use compiled operations + use_ref_model (bool): Whether to use a reference model for the DPO loss. + """ + super().__init__() + self.ignore_index = ignore_index + self.beta = beta + self.compiled = compiled + self.use_ref_model = use_ref_model + + def forward( + self, + _input, + lin_weight, + target, + bias=None, + preference_labels=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + kl=None, + ): + return LigerFusedLinearKTOFunction.apply( + _input, + lin_weight, + target, + preference_labels, + bias, + ref_input, + ref_weight, + ref_bias, + kl, + self.ignore_index, + self.beta, + self.compiled, + self.use_ref_model, + ) diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index ab18a0f24..ec8dc3e1f 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -18,9 +18,9 @@ class HFDPOLoss(HFAlignmentLoss): """ - Implementation of the Odds Ratio Preference Optimization (ORPO) loss, + Implementation of the Direct Preference Optimization (DPO) loss, adapted from Hugging Face's implementation. 
- Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py """ def __init__( diff --git a/test/chunked_loss/test_kto_loss.py b/test/chunked_loss/test_kto_loss.py new file mode 100644 index 000000000..5edf8e45e --- /dev/null +++ b/test/chunked_loss/test_kto_loss.py @@ -0,0 +1,353 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss +from liger_kernel.chunked_loss.functional import liger_fused_linear_kto +from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction +from liger_kernel.utils import infer_device +from test.utils import HFAlignmentLoss +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class HFKTOLoss(HFAlignmentLoss): + """ + Implementation of the Kahneman-Tversky Optimization (KTO) loss, + adapted from Hugging Face's implementation. + Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py + """ + + def __init__( + self, + ignore_index: int = -100, + beta: float = 0.1, + use_ref_model: bool = True, + ): + super().__init__( + beta=beta, + ignore_index=ignore_index, + use_ref_model=use_ref_model, + unpaired=True, + compute_nll_loss=False, + ) + + def alignment_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + kl: torch.FloatTensor = None, + ): + """Compute KTO loss for a batch of policy log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + Returns: + The losses tensor contains the KTO loss for each example in the batch. 
+ """ + if kl is None: + kl = torch.zeros(1).to(policy_chosen_logps.device) + + # Chosen losses + if policy_chosen_logps.shape[0] != 0 or ref_chosen_logps.shape[0] != 0: + chosen_logratios = policy_chosen_logps - ref_chosen_logps + # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + chosen_losses = torch.Tensor([]).to(policy_chosen_logps.device) + + # Rejected losses + if policy_rejected_logps.shape[0] != 0 or ref_rejected_logps.shape[0] != 0: + rejected_logratios = policy_rejected_logps - ref_rejected_logps + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + rejected_losses = torch.Tensor([]).to(policy_rejected_logps.device) + + losses = torch.cat( + (chosen_losses, rejected_losses), + 0, + ) + + return losses + + +class TorchLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.KTO_loss = HFKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ).get_batch_loss_metrics + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + weight=self.lin.weight, + _input=x, + target=y, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + preference_labels=preference_labels, + kl=kl, + ) + + +class LigerLMHeadKTO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + ref_bias: bool = False, + ignore_index: int = -100, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype) + self.KTO_loss = LigerFusedLinearKTOLoss( + ignore_index=ignore_index, + beta=beta, + use_ref_model=True, + ) + + def forward(self, x, ref_x, y, preference_labels, kl=None): + return self.KTO_loss( + _input=x, + lin_weight=self.lin.weight, + target=y, + preference_labels=preference_labels, + bias=self.lin.bias, + ref_input=ref_x, + ref_weight=self.ref_lin.weight, + ref_bias=self.ref_lin.bias, + kl=kl, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) +def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta): + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device, requires_grad=False) + + # Precomputed KL divergence between policy and reference distributions + kl = 
torch.randn(1, device=device, dtype=dtype) + + torch_lm_head_KTO = TorchLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + ignore_index=ignore_index, + beta=beta, + ) + liger_lm_head_KTO = LigerLMHeadKTO( + H=H, + V=V, + dtype=dtype, + bias=bias, + ref_bias=ref_bias, + ignore_index=ignore_index, + beta=beta, + ) + + torch_lm_head_KTO.lin.weight.data = liger_lm_head_KTO.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + torch_lm_head_KTO.ref_lin.weight.data = liger_lm_head_KTO.ref_lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + + if bias: + torch_lm_head_KTO.lin.bias.data = liger_lm_head_KTO.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + if ref_bias: + torch_lm_head_KTO.ref_lin.bias.data = liger_lm_head_KTO.ref_lin.bias.data = torch.randn( + V, device=device, dtype=dtype + ) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) * scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() + indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] + target.view(-1)[indices_to_assign] = ignore_index + + loss1 = torch_lm_head_KTO(x=input1, ref_x=ref_input, y=target, preference_labels=preference_labels, kl=kl) + loss2 = liger_lm_head_KTO(x=input2, ref_x=ref_input, y=target, preference_labels=preference_labels, kl=kl) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1, input2, atol=atol, rtol=rtol) + assert_verbose_allclose(torch_lm_head_KTO.lin.weight, liger_lm_head_KTO.lin.weight, atol=atol, rtol=rtol) + + if bias: + assert_verbose_allclose(torch_lm_head_KTO.lin.bias, liger_lm_head_KTO.lin.bias, atol=atol, rtol=rtol) + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_KTO.lin.weight.grad, + liger_lm_head_KTO.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_KTO.lin.bias.grad, + liger_lm_head_KTO.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): + # Preference labels shape: [B] + # Create binary preference labels (0 or 1) for each sequence in the batch + # Used to indicate preferred sequences (1) vs non-preferred sequences (0) + preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device) + + # Precomputed KL divergence between policy and reference distributions + kl = torch.randn(1, device=device, dtype=dtype) + + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False) 
* scalar + + target = torch.randint( + 0, + V, + ( + B, + T, + ), + device=device, + dtype=torch.long, + ) + + _weight = torch.randn(V, H, device=device, dtype=dtype) + weight1 = _weight.detach().clone().requires_grad_(True) + weight2 = _weight.detach().clone().requires_grad_(True) + + _ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) + ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) + + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None + bias1 = _bias.detach().clone().requires_grad_(True) if bias else None + bias2 = _bias.detach().clone().requires_grad_(True) if bias else None + + _ref_bias = torch.randn(V, device=device, dtype=dtype) if ref_bias else None + ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + + loss1 = LigerFusedLinearKTOFunction.apply( + input1, + weight1, + target, + preference_labels, + bias1, + ref_input, + ref_weight1, + ref_bias1, + kl, + ) + loss2 = liger_fused_linear_kto( + input2, + weight2, + target, + preference_labels, + bias2, + ref_input, + ref_weight2, + ref_bias2, + kl, + ) + + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol) + if bias: + assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 31294cc09..3fcb07b71 100644 --- a/test/utils.py +++ b/test/utils.py @@ -386,12 +386,15 @@ def __init__( beta: float = 0.1, ignore_index: int = -100, use_ref_model: bool = False, + unpaired: bool = False, compute_nll_loss: bool = True, + **kwargs, ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index self.use_ref_model = use_ref_model + self.unpaired = unpaired self.compute_nll_loss = compute_nll_loss @abstractmethod @@ -431,22 +434,33 @@ def get_batch_logps( def get_ref_logps( self, - _input: torch.FloatTensor, + ref_input: torch.FloatTensor, ref_weight: torch.FloatTensor, target: torch.LongTensor, ref_bias: torch.FloatTensor, average_log_prob: bool = True, + preference_labels: torch.Tensor = None, ): """Compute the log probabilities of the given labels under the given reference model.""" - ref_logits = _input @ ref_weight.t() - if ref_bias is not None: - ref_logits = ref_logits + ref_bias - ref_all_logps = self.get_batch_logps(ref_logits, target, average_log_prob=average_log_prob) - return ( - ref_all_logps[: _input.shape[0] // 2], - ref_all_logps[_input.shape[0] // 2 :], - ) + with torch.no_grad(): + ref_logits = ref_input @ ref_weight.t() + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + ref_all_logps = self.get_batch_logps(ref_logits, target, average_log_prob=average_log_prob) + + if self.unpaired and preference_labels is not None: + # Split based on preference labels + return ( + ref_all_logps[preference_labels], + ref_all_logps[~preference_labels], + ) + else: + # Original paired behavior - split in half + return ( + ref_all_logps[: ref_input.shape[0] // 2], + ref_all_logps[ref_input.shape[0] // 2 :], + ) def concatenated_forward( self, @@ -455,6 +469,7 @@ def concatenated_forward( target: torch.LongTensor, bias: torch.FloatTensor | None = None, average_log_prob: bool = True, + preference_labels: torch.Tensor = None, nll_target: 
torch.LongTensor | None = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. @@ -489,11 +504,19 @@ def cross_entropy_loss(logits, labels): average_log_prob=average_log_prob, ) - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + if self.unpaired and preference_labels is not None: + # Split based on labels tensor + chosen_logps = all_logps[preference_labels] + rejected_logps = all_logps[~preference_labels] + chosen_logits = all_logits[preference_labels] + rejected_logits = all_logits[~preference_labels] + else: + # Original paired behavior - split in half + len_chosen = _input.shape[0] // 2 + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] return ( chosen_logps, @@ -513,11 +536,12 @@ def get_batch_loss_metrics( ref_weight: torch.FloatTensor = None, ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, + preference_labels: torch.Tensor = None, nll_target: torch.LongTensor = None, + **loss_kwargs, ): """Compute the loss metrics for the given batch of inputs for train or test.""" - - forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, nll_target) + forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, preference_labels, nll_target) ( policy_chosen_logps, policy_rejected_logps, @@ -526,10 +550,14 @@ def get_batch_loss_metrics( policy_nll_loss, ) = forward_output[:5] - loss_kwargs = {} if self.use_ref_model: ref_chosen_logps, ref_rejected_logps = self.get_ref_logps( - ref_input, ref_weight, target, ref_bias, average_log_prob + ref_input, + ref_weight, + target, + ref_bias, + average_log_prob, + preference_labels, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps @@ -538,16 +566,20 @@ def get_batch_loss_metrics( losses, *aggregated_aux_outputs = alignment_loss_outputs else: losses, aggregated_aux_outputs = alignment_loss_outputs, [] - # full loss + loss = policy_nll_loss * self.alpha + losses.mean() - return_vars = ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits.detach().mean(), - policy_rejected_logits.detach().mean(), - policy_nll_loss, - ) - return loss, (*return_vars, *aggregated_aux_outputs) + + if not self.unpaired: + return_vars = ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits.detach().mean(), + policy_rejected_logits.detach().mean(), + policy_nll_loss, + ) + return loss, (*return_vars, *aggregated_aux_outputs) + else: + return loss class HFDistillationLoss:
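
For reference, below is a minimal usage sketch of the `LigerFusedLinearKTOLoss` module that this patch introduces. It mirrors the `LigerLMHeadKTO` wrappers defined in the benchmark and test scripts above; the tensor sizes here are illustrative (smaller than the benchmark configuration of `T=512, H=1024, V=128256`), the precomputed `kl` tensor is random placeholder data, and a CUDA or other supported accelerator is assumed since the chunked loss uses `torch.compile` by default.

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
dtype = torch.bfloat16

# Illustrative sizes only: batch, sequence length, hidden size, vocab size.
B, T, H, V = 4, 128, 1024, 32000

# Policy LM head and a frozen reference LM head (weights shown here are random;
# in practice they come from the policy and reference models).
lin = torch.nn.Linear(H, V, bias=True, dtype=dtype, device=device)
ref_lin = torch.nn.Linear(H, V, bias=True, dtype=dtype, device=device)

kto_loss = LigerFusedLinearKTOLoss(ignore_index=-100, beta=0.1, use_ref_model=True)

# Last hidden states of the policy and reference models: [B, T, H].
x = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=True)
ref_x = torch.randn(B, T, H, device=device, dtype=dtype)

# Token targets: [B, T]; per-sequence chosen (True) / rejected (False) labels: [B].
target = torch.randint(V, (B, T), device=device, dtype=torch.long)
preference_labels = torch.randint(2, (B,), device=device, dtype=torch.bool)

# Precomputed KL estimate between policy and reference (placeholder value).
kl = torch.randn(1, device=device, dtype=dtype)

loss = kto_loss(
    _input=x,
    lin_weight=lin.weight,
    target=target,
    preference_labels=preference_labels,
    bias=lin.bias,
    ref_input=ref_x,
    ref_weight=ref_lin.weight,
    ref_bias=ref_lin.bias,
    kl=kl,
)
loss.backward()  # gradients flow to x, lin.weight, and lin.bias via the fused chunked backward
```

Because the loss fuses the final linear projection with the KTO loss and processes the batch in chunks, the full `[B, T, V]` logits tensor is never materialized, which is where the memory savings reported in `benchmark/data/all_benchmark_data.csv` come from.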