From 6bc8f335b09a7aca917724f59c0d3dfa28024eca Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 31 Jan 2025 13:17:44 +0100 Subject: [PATCH 01/15] init --- src/liger_kernel/chunked_loss/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 4f76ab79d..6069ecbae 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -4,3 +4,4 @@ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401 \ No newline at end of file From da150c044dc49a3278e6c17641bd7ef5703f7ee1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 31 Jan 2025 22:54:13 +0100 Subject: [PATCH 02/15] initial LigerFusedLinearRLHFBase --- src/liger_kernel/chunked_loss/__init__.py | 2 +- .../chunked_loss/fused_linear_rlhf.py | 238 ++++++++++++++++++ src/liger_kernel/chunked_loss/grpo_loss.py | 199 +++++++++++++++ 3 files changed, 438 insertions(+), 1 deletion(-) create mode 100644 src/liger_kernel/chunked_loss/fused_linear_rlhf.py create mode 100644 src/liger_kernel/chunked_loss/grpo_loss.py diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 6069ecbae..06c49d6b3 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,7 +1,7 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 -from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401 \ No newline at end of file diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py new file mode 100644 index 000000000..f7d3f2a0b --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -0,0 +1,238 @@ +from abc import abstractmethod +from functools import partial + +import torch +import torch.nn.functional as F + + +class LigerFusedLinearRLHFBase(torch.autograd.Function): + @abstractmethod + def preference_loss_fn(*args, **kwargs): + """ + To be extended by subclasses. + """ + raise NotImplementedError("Preference loss function must be implemented.") + + @staticmethod + def forward( + ctx, + _input, + weight, + rewards, + attention_mask, + bias=None, + loss_fn=None, + chunk_size=1, + beta=0.1, + compiled=True, + use_ref_model=False, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + """ + Base class for fused linear layer with RLHF loss. + Expects _input to contain the policy model inputs. + + Args: + _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + rewards (torch.Tensor): Rewards tensor. 
Shape: (batch_size,). + attention_mask (torch.Tensor): Attention mask. Shape: (batch_size, seq_len). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input. + chunk_size (int): Size of chunks to process. + beta (float): Weight for KL penalty. + compiled (bool): Whether to use torch compile for chunk accumulation. + use_ref_model (bool): Whether to use a reference model. + ref_input (torch.Tensor): Reference model input tensor. + ref_weight (torch.Tensor): Reference model weight tensor. + ref_bias (torch.Tensor, optional): Reference model bias tensor. + """ + # TODO: Tune CHUNK_SIZE to fully utilize the GPU + CHUNK_SIZE = chunk_size + + # Gradients to be accumulated + grad_weight = torch.zeros_like(weight) + grad_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None + + # Loss to be accumulated + loss_acc = torch.zeros((), device=_input.device) + + compute_loss = partial( + LigerFusedLinearRLHFBase._compute_loss, + preference_loss_fn=loss_fn, + beta=beta, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, + rewards=rewards, + ) + + def fused_fwd_bwd(input_chunk, attention_mask_chunk, ref_input_chunk): + """ + Fused forward and backward pass for a chunk of input. + """ + if bias is not None: + return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 4), has_aux=True)( + input_chunk, + weight, + attention_mask_chunk, + bias, + ref_input_chunk=ref_input_chunk, + ) + else: + return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( + input_chunk, + weight, + attention_mask_chunk, + ref_input_chunk=ref_input_chunk, + ) + + def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None): + if bias is not None: + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, _) = fused_fwd_bwd( + input_chunk, attention_mask_chunk, ref_input_chunk + ) + grad_bias.add_(chunk_grad_bias) + else: + (chunk_grad_input, chunk_grad_weight), (chunk_loss, _) = fused_fwd_bwd( + input_chunk, attention_mask_chunk, ref_input_chunk + ) + + # Accumulate gradients + grad_weight.add_(chunk_grad_weight) + grad_inputs.append(chunk_grad_input) + + # Accumulate loss + loss_acc.add_(chunk_loss) + + if compiled: + fused_fwd_bwd = torch.compile(fused_fwd_bwd) + + chunks = max(1, _input.shape[0] // CHUNK_SIZE) + _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) + _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0) + + if use_ref_model: + _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) + + for input_chunk, attention_mask_chunk, ref_input_chunk in zip( + _input_chunks, + _attention_mask_chunks, + (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)), + strict=False, + ): + # Mark dynamic dimensions to prevent recompilation + torch._dynamo.mark_dynamic(input_chunk, 1) + torch._dynamo.mark_dynamic(attention_mask_chunk, 1) + torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None + + # Accumulate loss and gradients + accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk) + + # Combine gradients + grad_input = torch.cat(grad_inputs, dim=0) + + ctx.save_for_backward(grad_input, grad_weight, grad_bias) + return loss_acc, () + + @staticmethod + def backward(ctx, *grad_output): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)): + grad_input = grad_input * 
grad_output[0][0] + grad_weight = grad_weight * grad_output[0][0] + grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None + + return ( + grad_input, # grad_input + grad_weight, # grad_weight + None, # grad_rewards + None, # grad_attention_mask + grad_bias, # grad_bias + None, # grad_loss_fn + None, # grad_chunk_size + None, # grad_beta + None, # grad_compiled + None, # grad_use_ref_model + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + ) + + @staticmethod + def _compute_loss( + input_chunk, + weight, + rewards, + attention_mask_chunk, + bias=None, + preference_loss_fn=None, + beta=0.1, + use_ref_model=False, + ref_input_chunk=None, + ref_weight=None, + ref_bias=None, + ): + """ + Compute the total loss for a chunk of input, using an RLHF loss function. + Args: + input_chunk: Policy model hidden states (batch_size, seq_len, hidden_size) + weight: Linear layer weights (vocab_size, hidden_size) + attention_mask_chunk: Attention mask (batch_size, seq_len) + bias: Optional linear layer bias (vocab_size,) + preference_loss_fn: Loss function (e.g. GRPO loss) + beta: KL penalty weight + rewards: Rewards for advantage computation + use_ref_model: Whether to use reference model + ref_input_chunk: Reference model hidden states + ref_weight: Reference model weights + ref_bias: Reference model bias + """ + # Get policy logits and log probs + batch_size, seq_len, hidden_size = input_chunk.shape + input_reshaped = input_chunk.view(-1, hidden_size) + logits = (input_reshaped @ weight.t()).view(batch_size, seq_len, -1) + if bias is not None: + logits = logits + bias + log_probs = F.log_softmax(logits, dim=-1) + + # Get sequence-level log probs by taking max over vocab + seq_log_probs = log_probs.max(dim=-1).values + + # Get reference model log probs if needed + ref_seq_log_probs = None + if use_ref_model and ref_input_chunk is not None and ref_weight is not None: + with torch.no_grad(): + ref_input_reshaped = ref_input_chunk.view(-1, ref_input_chunk.size(-1)) + ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1) + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + ref_log_probs = F.log_softmax(ref_logits, dim=-1) + ref_seq_log_probs = ref_log_probs.max(dim=-1).values + + # Compute KL divergence if using reference model + kl_div = None + if use_ref_model and ref_seq_log_probs is not None: + kl_div = seq_log_probs - ref_seq_log_probs + + # Compute loss using the provided loss function + loss = preference_loss_fn( + seq_log_probs=seq_log_probs, + ref_seq_log_probs=ref_seq_log_probs, + attention_mask=attention_mask_chunk, + rewards=rewards, + beta=beta, + ) + + # Return metrics for logging + metrics = ( + seq_log_probs.mean(), # policy log probs mean + seq_log_probs.std(), # policy log probs std + logits.mean(), # policy logits mean + kl_div.mean() if kl_div is not None else torch.tensor(0.0, device=loss.device), # KL divergence mean + ) + + return loss, metrics diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py new file mode 100644 index 000000000..e3d54369b --- /dev/null +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -0,0 +1,199 @@ +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase + + +class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase): + @staticmethod + def preference_loss_fn( + logits, + attention_mask, + rewards, + ref_logits=None, + beta=0.1, + **kwargs, + ): + """ + 
GRPO Loss Function as implemented in GRPOTrainer. + + Args: + logits: Model logits (batch_size, seq_len, vocab_size) + attention_mask: Attention mask (batch_size, seq_len) + rewards: Rewards for each sequence (batch_size,) + ref_logits: Reference model logits (batch_size, seq_len, vocab_size) or None + beta: Weight for KL penalty + """ + # Get log probabilities for policy + log_probs = F.log_softmax(logits, dim=-1) + + # Get sequence-level log probs by taking max over vocab and summing over sequence + policy_seq_logps = (log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) + + # Get reference model log probabilities if provided + if ref_logits is not None: + with torch.no_grad(): + ref_log_probs = F.log_softmax(ref_logits, dim=-1) + ref_seq_logps = (ref_log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) + else: + ref_seq_logps = policy_seq_logps.detach() + + # Compute advantages + advantages = rewards - rewards.mean() + if advantages.std() > 0: + advantages = advantages / advantages.std() + + # Policy gradient loss + policy_loss = -(advantages * policy_seq_logps) + + # KL penalty + kl_div = policy_seq_logps - ref_seq_logps + + # Total loss + loss = policy_loss + beta * kl_div + + return loss.mean() + + @staticmethod + def forward( + ctx, + _input, + weight, + attention_mask, + rewards, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + beta=0.1, + compiled=True, + use_ref_model=True, + ): + """Forward pass for GRPO loss.""" + # Save tensors needed for backward + ctx.save_for_backward(_input, weight, attention_mask, bias) + ctx.beta = beta + ctx.rewards = rewards # Save rewards for use in backward pass + + # Get policy logits + batch_size, seq_len, hidden_size = _input.shape + input_reshaped = _input.view(-1, hidden_size) + policy_logits = (input_reshaped @ weight.t()).view(batch_size, seq_len, -1) + if bias is not None: + policy_logits = policy_logits + bias + + # Get reference logits if needed + ref_logits = None + if use_ref_model and ref_input is not None and ref_weight is not None: + ref_input_reshaped = ref_input.view(-1, ref_input.size(-1)) + ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1) + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + + # Compute loss + loss = LigerFusedLinearGRPOFunction.preference_loss_fn( + logits=policy_logits, + attention_mask=attention_mask, + rewards=rewards, + ref_logits=ref_logits, + beta=beta, + ) + + return loss, () + + @staticmethod + def backward(ctx, grad_output, *args): + _input, weight, attention_mask, bias = ctx.saved_tensors + beta = ctx.beta # Retrieve beta for KL scaling + rewards = ctx.rewards # Retrieve rewards for advantage computation + + # Initialize gradients + grad_input = grad_weight = grad_bias = None + + # Compute gradients using autograd + with torch.enable_grad(): + _input = _input.detach().requires_grad_() + weight = weight.detach().requires_grad_() + if bias is not None: + bias = bias.detach().requires_grad_() + + # Forward pass + batch_size, seq_len, hidden_size = _input.shape + input_reshaped = _input.view(-1, hidden_size) + logits = (input_reshaped @ weight.t()).view(batch_size, seq_len, -1) + if bias is not None: + logits = logits + bias + + # Compute log probabilities and sequence-level scores + log_probs = F.log_softmax(logits, dim=-1) + seq_log_probs = (log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) + + # Compute advantages + advantages = rewards - rewards.mean() + if advantages.std() > 0: + advantages = advantages / 
advantages.std() + + # Policy gradient loss with KL penalty + policy_loss = -(advantages * seq_log_probs) + kl_div = seq_log_probs - seq_log_probs.detach() # KL divergence from current policy + loss = policy_loss + beta * kl_div # Apply beta scaling to KL term + + # Backward pass + loss.backward(grad_output) + grad_input = _input.grad + grad_weight = weight.grad + grad_bias = bias.grad if bias is not None else None + + return ( + grad_input, + grad_weight, + None, # grad_attention_mask + None, # grad_rewards + grad_bias, + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + None, # grad_beta + None, # grad_compiled + None, # grad_use_ref_model + ) + + +class LigerFusedLinearGRPOLoss(torch.nn.Module): + """Fused linear layer with GRPO loss.""" + + def __init__( + self, + beta: float = 0.1, + compiled: bool = True, + use_ref_model: bool = True, + ): + super().__init__() + self.beta = beta + self.compiled = compiled + self.use_ref_model = use_ref_model + + def forward( + self, + lin_weight, + _input, + attention_mask, + rewards, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + return LigerFusedLinearGRPOFunction.apply( + _input, + lin_weight, + attention_mask, + rewards, + bias, + ref_input, + ref_weight, + ref_bias, + self.beta, + self.compiled, + self.use_ref_model, + ) From 8891c833d98fb2625b8a04435678f70a45b69a5e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 31 Jan 2025 23:10:57 +0100 Subject: [PATCH 03/15] add tests --- test/chunked_loss/test_grpo_loss.py | 217 ++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 test/chunked_loss/test_grpo_loss.py diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py new file mode 100644 index 000000000..e5e04c19e --- /dev/null +++ b/test/chunked_loss/test_grpo_loss.py @@ -0,0 +1,217 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction +from liger_kernel.utils import infer_device +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class TorchLMHeadGRPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.beta = beta + + def forward( + self, + x, + attention_mask, + rewards, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + # Get policy logits and log probs + batch_size, seq_len, hidden_size = x.shape + input_reshaped = x.view(-1, hidden_size) + logits = (input_reshaped @ self.lin.weight.t()).view(batch_size, seq_len, -1) + if self.lin.bias is not None: + logits = logits + self.lin.bias + log_probs = F.log_softmax(logits, dim=-1) + + # Get sequence-level log probs by taking max over vocab and summing over sequence + seq_log_probs = log_probs.max(dim=-1).values + policy_seq_logps = (seq_log_probs * attention_mask).sum(dim=-1) + + # Get reference model log probs if provided + if ref_input is not None and ref_weight is not None: + with torch.no_grad(): + ref_input_reshaped = ref_input.view(-1, ref_input.size(-1)) + ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1) + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + ref_log_probs = F.log_softmax(ref_logits, dim=-1) + ref_seq_log_probs = 
ref_log_probs.max(dim=-1).values + ref_seq_logps = (ref_seq_log_probs * attention_mask).sum(dim=-1) + else: + ref_seq_logps = policy_seq_logps.detach() + + # Compute advantages + advantages = rewards - rewards.mean() + if advantages.std() > 0: + advantages = advantages / advantages.std() + + # Policy gradient loss + policy_loss = -(advantages * policy_seq_logps) + + # KL penalty + kl_div = policy_seq_logps - ref_seq_logps + + # Total loss + loss = policy_loss + self.beta * kl_div + + # Return metrics for logging + metrics = ( + policy_seq_logps.mean(), # policy log probs mean + policy_seq_logps.std(), # policy log probs std + logits.mean(), # policy logits mean + kl_div.mean(), # KL divergence mean + ) + + return loss.mean(), metrics + + +class LigerLMHeadGRPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + beta: float = 0.1, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.grpo_loss = LigerFusedLinearGRPOFunction.apply + + def forward( + self, + x, + attention_mask, + rewards, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + return self.grpo_loss( + x, + self.lin.weight, + rewards, + attention_mask, + self.lin.bias, + loss_fn=LigerFusedLinearGRPOFunction.preference_loss_fn, + chunk_size=1, + beta=self.beta, + compiled=True, + use_ref_model=ref_input is not None, + ref_input=ref_input, + ref_weight=ref_weight, + ref_bias=ref_bias, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("beta", [0.1, 0.2]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + beta, +): + torch_lm_head_grpo = TorchLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + beta=beta, + ) + liger_lm_head_grpo = LigerLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + beta=beta, + ) + + # Initialize weights + torch_lm_head_grpo.lin.weight.data = liger_lm_head_grpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + if bias: + torch_lm_head_grpo.lin.bias.data = liger_lm_head_grpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + + # Create inputs + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + # Create attention mask with random padding + attention_mask = torch.ones(B, T, device=device) + num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item() + mask_indices = torch.randperm(B * T)[:num_elements_to_mask] + attention_mask.view(-1)[mask_indices] = 0 + + # Create rewards + rewards = torch.randn(B, device=device, dtype=dtype) + + # Forward pass + loss1, aux1 = torch_lm_head_grpo(input1, attention_mask, rewards) + loss2, aux2 = liger_lm_head_grpo(input2, attention_mask, rewards) + + # Check losses match + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + assert len(aux1) == len(aux2) + + # Backward pass + loss1.backward() + loss2.backward() + + # Check gradients match + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_grpo.lin.weight.grad, + liger_lm_head_grpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: 
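+        # Bias gradients exist only when the linear layers were created with bias=True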
+ assert_verbose_allclose( + torch_lm_head_grpo.lin.bias.grad, + liger_lm_head_grpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) From 4fce4126ba87ea6a7118d80de28845088b3e12a3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 31 Jan 2025 23:22:11 +0100 Subject: [PATCH 04/15] add tests --- test/chunked_loss/test_grpo_loss.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index e5e04c19e..c331dd20e 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -97,6 +97,7 @@ def __init__( super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.grpo_loss = LigerFusedLinearGRPOFunction.apply + self.beta = beta def forward( self, @@ -107,20 +108,19 @@ def forward( ref_weight=None, ref_bias=None, ): + # Pass only the arguments defined in LigerFusedLinearGRPOFunction.forward() return self.grpo_loss( - x, - self.lin.weight, - rewards, - attention_mask, - self.lin.bias, - loss_fn=LigerFusedLinearGRPOFunction.preference_loss_fn, - chunk_size=1, - beta=self.beta, - compiled=True, - use_ref_model=ref_input is not None, - ref_input=ref_input, - ref_weight=ref_weight, - ref_bias=ref_bias, + x, # _input + self.lin.weight, # weight + attention_mask, # attention_mask + rewards, # rewards + self.lin.bias, # bias + ref_input, # ref_input + ref_weight, # ref_weight + ref_bias, # ref_bias + self.beta, # beta + True, # compiled + ref_input is not None, # use_ref_model ) From 9befdcd8b27af86d8f2c30ded862c7ee6e562f46 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 31 Jan 2025 23:26:57 +0100 Subject: [PATCH 05/15] fix backward --- src/liger_kernel/chunked_loss/grpo_loss.py | 36 +++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index e3d54369b..241f5c8bb 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -90,6 +90,20 @@ def forward( if ref_bias is not None: ref_logits = ref_logits + ref_bias + # Get log probabilities + log_probs = F.log_softmax(policy_logits, dim=-1) + seq_log_probs = (log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) + + # Get reference log probabilities + if ref_logits is not None: + ref_log_probs = F.log_softmax(ref_logits, dim=-1) + ref_seq_logps = (ref_log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) + else: + ref_seq_logps = seq_log_probs.detach() + + # Compute KL divergence + kl_div = seq_log_probs - ref_seq_logps + # Compute loss loss = LigerFusedLinearGRPOFunction.preference_loss_fn( logits=policy_logits, @@ -99,10 +113,24 @@ def forward( beta=beta, ) - return loss, () + # Return metrics matching the PyTorch implementation + metrics = ( + seq_log_probs.mean(), # policy log probs mean + seq_log_probs.std(), # policy log probs std + policy_logits.mean(), # policy logits mean + kl_div.mean(), # KL divergence mean + ) + + return loss, metrics @staticmethod - def backward(ctx, grad_output, *args): + def backward(ctx, grad_output, *grad_metrics): + """Backward pass for GRPO loss. 
+ + Args: + grad_output: Gradient of the loss (scalar) + grad_metrics: Gradients of the metrics (not used in backward computation) + """ _input, weight, attention_mask, bias = ctx.saved_tensors beta = ctx.beta # Retrieve beta for KL scaling rewards = ctx.rewards # Retrieve rewards for advantage computation @@ -136,9 +164,9 @@ def backward(ctx, grad_output, *args): # Policy gradient loss with KL penalty policy_loss = -(advantages * seq_log_probs) kl_div = seq_log_probs - seq_log_probs.detach() # KL divergence from current policy - loss = policy_loss + beta * kl_div # Apply beta scaling to KL term + loss = (policy_loss + beta * kl_div).mean() # Take mean to get scalar loss - # Backward pass + # Backward pass with scalar gradient loss.backward(grad_output) grad_input = _input.grad grad_weight = weight.grad From 9ea4d084acb95cb844640293c263e74a0d0b7988 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 12:10:16 +0100 Subject: [PATCH 06/15] use the base class --- .../chunked_loss/fused_linear_rlhf.py | 279 +++++++++--------- src/liger_kernel/chunked_loss/grpo_loss.py | 185 ++++-------- test/chunked_loss/test_grpo_loss.py | 91 ++++-- 3 files changed, 251 insertions(+), 304 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py index f7d3f2a0b..caaa92e14 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -1,4 +1,3 @@ -from abc import abstractmethod from functools import partial import torch @@ -6,20 +5,13 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function): - @abstractmethod - def preference_loss_fn(*args, **kwargs): - """ - To be extended by subclasses. - """ - raise NotImplementedError("Preference loss function must be implemented.") - @staticmethod def forward( ctx, _input, weight, - rewards, attention_mask, + rewards, bias=None, loss_fn=None, chunk_size=1, @@ -30,39 +22,21 @@ def forward( ref_weight=None, ref_bias=None, ): - """ - Base class for fused linear layer with RLHF loss. - Expects _input to contain the policy model inputs. - - Args: - _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size). - weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - rewards (torch.Tensor): Rewards tensor. Shape: (batch_size,). - attention_mask (torch.Tensor): Attention mask. Shape: (batch_size, seq_len). - bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - loss_fn (callable): Loss function to compute the loss on a chunk of input. - chunk_size (int): Size of chunks to process. - beta (float): Weight for KL penalty. - compiled (bool): Whether to use torch compile for chunk accumulation. - use_ref_model (bool): Whether to use a reference model. - ref_input (torch.Tensor): Reference model input tensor. - ref_weight (torch.Tensor): Reference model weight tensor. - ref_bias (torch.Tensor, optional): Reference model bias tensor. 
- """ - # TODO: Tune CHUNK_SIZE to fully utilize the GPU - CHUNK_SIZE = chunk_size - - # Gradients to be accumulated - grad_weight = torch.zeros_like(weight) - grad_inputs = [] - grad_bias = torch.zeros_like(bias) if bias is not None else None + """Chunked forward pass for RLHF loss computation.""" + # Save for backward + ctx.beta = beta + ctx.rewards = rewards - # Loss to be accumulated + # Initialize accumulators loss_acc = torch.zeros((), device=_input.device) + grad_weight = torch.zeros_like(weight) # [V, H] + grad_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None # [V] + aggregated_metrics = [] compute_loss = partial( - LigerFusedLinearRLHFBase._compute_loss, - preference_loss_fn=loss_fn, + LigerFusedLinearRLHFBase._compute_chunk_loss, + rlhf_loss_fn=loss_fn, beta=beta, use_ref_model=use_ref_model, ref_weight=ref_weight, @@ -71,168 +45,181 @@ def forward( ) def fused_fwd_bwd(input_chunk, attention_mask_chunk, ref_input_chunk): - """ - Fused forward and backward pass for a chunk of input. - """ + """Fused forward and backward for a chunk.""" if bias is not None: return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 4), has_aux=True)( - input_chunk, - weight, - attention_mask_chunk, - bias, - ref_input_chunk=ref_input_chunk, + input_chunk, weight, attention_mask_chunk, ref_input_chunk, bias ) else: return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( - input_chunk, - weight, - attention_mask_chunk, - ref_input_chunk=ref_input_chunk, + input_chunk, weight, attention_mask_chunk, ref_input_chunk ) def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None): + nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics + if bias is not None: - (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, _) = fused_fwd_bwd( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd( input_chunk, attention_mask_chunk, ref_input_chunk ) grad_bias.add_(chunk_grad_bias) else: - (chunk_grad_input, chunk_grad_weight), (chunk_loss, _) = fused_fwd_bwd( + (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd( input_chunk, attention_mask_chunk, ref_input_chunk ) - # Accumulate gradients + # Accumulate gradients and loss grad_weight.add_(chunk_grad_weight) grad_inputs.append(chunk_grad_input) - - # Accumulate loss loss_acc.add_(chunk_loss) + # Initialize storage for metrics on first chunk + if len(aggregated_metrics) == 0: + for metric in chunk_metrics: + if metric.ndim == 0: + aggregated_metrics.append(torch.zeros((), device=metric.device)) + else: + aggregated_metrics.append([]) + + # Accumulate metrics + for i, metric in enumerate(chunk_metrics): + if metric.ndim == 0: + aggregated_metrics[i].add_(metric) + else: + aggregated_metrics[i].append(metric) + if compiled: - fused_fwd_bwd = torch.compile(fused_fwd_bwd) + accumulate_chunk = torch.compile(accumulate_chunk) - chunks = max(1, _input.shape[0] // CHUNK_SIZE) + # Process input in chunks + chunks = max(1, _input.shape[0] // chunk_size) _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0) - - if use_ref_model: - _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) + _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks for input_chunk, attention_mask_chunk, ref_input_chunk in zip( - _input_chunks, - _attention_mask_chunks, - 
(_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)), - strict=False, + _input_chunks, _attention_mask_chunks, _ref_input_chunks ): - # Mark dynamic dimensions to prevent recompilation + # Mark dynamic dimensions torch._dynamo.mark_dynamic(input_chunk, 1) torch._dynamo.mark_dynamic(attention_mask_chunk, 1) - torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None + if ref_input_chunk is not None: + torch._dynamo.mark_dynamic(ref_input_chunk, 1) - # Accumulate loss and gradients accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk) # Combine gradients grad_input = torch.cat(grad_inputs, dim=0) + # Save for backward ctx.save_for_backward(grad_input, grad_weight, grad_bias) - return loss_acc, () - @staticmethod - def backward(ctx, *grad_output): - grad_input, grad_weight, grad_bias = ctx.saved_tensors - if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)): - grad_input = grad_input * grad_output[0][0] - grad_weight = grad_weight * grad_output[0][0] - grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None + # Finalize metrics + final_metrics = [] + for metric in aggregated_metrics: + if isinstance(metric, list): + final_metrics.append(torch.cat(metric, dim=0)) + else: + final_metrics.append(metric / chunks) - return ( - grad_input, # grad_input - grad_weight, # grad_weight - None, # grad_rewards - None, # grad_attention_mask - grad_bias, # grad_bias - None, # grad_loss_fn - None, # grad_chunk_size - None, # grad_beta - None, # grad_compiled - None, # grad_use_ref_model - None, # grad_ref_input - None, # grad_ref_weight - None, # grad_ref_bias - ) + return loss_acc, tuple(final_metrics) @staticmethod - def _compute_loss( + def _compute_chunk_loss( input_chunk, weight, - rewards, attention_mask_chunk, + ref_input_chunk=None, bias=None, - preference_loss_fn=None, + rlhf_loss_fn=None, beta=0.1, use_ref_model=False, - ref_input_chunk=None, ref_weight=None, ref_bias=None, + rewards=None, ): - """ - Compute the total loss for a chunk of input, using an RLHF loss function. - Args: - input_chunk: Policy model hidden states (batch_size, seq_len, hidden_size) - weight: Linear layer weights (vocab_size, hidden_size) - attention_mask_chunk: Attention mask (batch_size, seq_len) - bias: Optional linear layer bias (vocab_size,) - preference_loss_fn: Loss function (e.g. 
GRPO loss) - beta: KL penalty weight - rewards: Rewards for advantage computation - use_ref_model: Whether to use reference model - ref_input_chunk: Reference model hidden states - ref_weight: Reference model weights - ref_bias: Reference model bias - """ - # Get policy logits and log probs - batch_size, seq_len, hidden_size = input_chunk.shape - input_reshaped = input_chunk.view(-1, hidden_size) - logits = (input_reshaped @ weight.t()).view(batch_size, seq_len, -1) - if bias is not None: - logits = logits + bias - log_probs = F.log_softmax(logits, dim=-1) - - # Get sequence-level log probs by taking max over vocab - seq_log_probs = log_probs.max(dim=-1).values + """Compute loss for a single chunk.""" + # Get policy log probabilities using chunk_forward + ( + log_probs, + logits, + logits_mean, + ) = LigerFusedLinearRLHFBase.chunk_forward( + input_chunk, + weight, + attention_mask_chunk, + bias=bias, + ) - # Get reference model log probs if needed - ref_seq_log_probs = None - if use_ref_model and ref_input_chunk is not None and ref_weight is not None: + # Get reference log probabilities if needed + ref_log_probs = None + if use_ref_model and ref_input_chunk is not None: with torch.no_grad(): - ref_input_reshaped = ref_input_chunk.view(-1, ref_input_chunk.size(-1)) - ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1) - if ref_bias is not None: - ref_logits = ref_logits + ref_bias - ref_log_probs = F.log_softmax(ref_logits, dim=-1) - ref_seq_log_probs = ref_log_probs.max(dim=-1).values - - # Compute KL divergence if using reference model - kl_div = None - if use_ref_model and ref_seq_log_probs is not None: - kl_div = seq_log_probs - ref_seq_log_probs - - # Compute loss using the provided loss function - loss = preference_loss_fn( - seq_log_probs=seq_log_probs, - ref_seq_log_probs=ref_seq_log_probs, + ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward( + ref_input_chunk, + ref_weight, + attention_mask_chunk, + bias=ref_bias, + ) + + # Compute chunk loss and metrics + chunk_loss, chunk_metrics = rlhf_loss_fn( + log_probs=log_probs, attention_mask=attention_mask_chunk, rewards=rewards, + ref_log_probs=ref_log_probs, beta=beta, ) - # Return metrics for logging - metrics = ( - seq_log_probs.mean(), # policy log probs mean - seq_log_probs.std(), # policy log probs std - logits.mean(), # policy logits mean - kl_div.mean() if kl_div is not None else torch.tensor(0.0, device=loss.device), # KL divergence mean - ) + return chunk_loss, (logits_mean, *chunk_metrics) + + @staticmethod + def chunk_forward( + input_chunk, + weight, + attention_mask_chunk, + bias=None, + ): + """Forward pass computation for a single chunk.""" + batch_size, seq_len, hidden_size = input_chunk.shape + input_reshaped = input_chunk.view(-1, hidden_size) # [B*T, H] + + # Linear layer: [B*T, H] @ [V, H].T -> [B*T, V] + logits = F.linear(input_reshaped, weight) # weight shape is [V, H] + if bias is not None: + logits = logits + bias.view(1, -1) + + # Reshape to [B, T, V] and compute log_probs + logits = logits.view(batch_size, seq_len, -1) + log_probs = F.log_softmax(logits, dim=-1) + + # Calculate mean logits for monitoring + logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0]) - return loss, metrics + return log_probs, logits, logits_mean + + @staticmethod + def backward(ctx, grad_output, *grad_metrics): + """Backward pass for RLHF loss.""" + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if grad_output != 1.0: + grad_input = grad_input * grad_output + grad_weight 
= grad_weight * grad_output + if grad_bias is not None: + grad_bias = grad_bias * grad_output + + return ( + grad_input, + grad_weight, + None, # grad_attention_mask + None, # grad_rewards + grad_bias, + None, # grad_loss_fn + None, # grad_chunk_size + None, # grad_beta + None, # grad_compiled + None, # grad_use_ref_model + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + ) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index 241f5c8bb..d181e2379 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -1,58 +1,64 @@ import torch -import torch.nn.functional as F from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase): @staticmethod - def preference_loss_fn( - logits, + def rlhf_loss_fn( + log_probs, attention_mask, rewards, - ref_logits=None, + ref_log_probs=None, beta=0.1, **kwargs, ): - """ - GRPO Loss Function as implemented in GRPOTrainer. - - Args: - logits: Model logits (batch_size, seq_len, vocab_size) - attention_mask: Attention mask (batch_size, seq_len) - rewards: Rewards for each sequence (batch_size,) - ref_logits: Reference model logits (batch_size, seq_len, vocab_size) or None - beta: Weight for KL penalty - """ - # Get log probabilities for policy - log_probs = F.log_softmax(logits, dim=-1) - - # Get sequence-level log probs by taking max over vocab and summing over sequence - policy_seq_logps = (log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) - - # Get reference model log probabilities if provided - if ref_logits is not None: + """GRPO Loss Function matching GRPOTrainer implementation.""" + # Get chosen token probabilities + chosen_tokens = log_probs.argmax(dim=-1) # (batch_size, seq_len) + chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze( + -1 + ) # (batch_size, seq_len) + + # Get reference model probabilities + if ref_log_probs is not None: with torch.no_grad(): - ref_log_probs = F.log_softmax(ref_logits, dim=-1) - ref_seq_logps = (ref_log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) + ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1) else: - ref_seq_logps = policy_seq_logps.detach() + ref_token_logprobs = chosen_token_logprobs.detach() # Compute advantages - advantages = rewards - rewards.mean() - if advantages.std() > 0: - advantages = advantages / advantages.std() + mean_grouped_rewards = rewards.mean() + std_grouped_rewards = rewards.std() + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - # Policy gradient loss - policy_loss = -(advantages * policy_seq_logps) + # Compute policy gradient loss with importance sampling ratio + ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) + policy_loss = -ratio * advantages.unsqueeze(1) - # KL penalty - kl_div = policy_seq_logps - ref_seq_logps + # Compute KL penalty + kl_div = ( + torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0 + ) - # Total loss - loss = policy_loss + beta * kl_div + # Combine losses + per_token_loss = policy_loss + beta * kl_div - return loss.mean() + # Apply masking and normalize + masked_loss = per_token_loss * attention_mask + seq_lengths = attention_mask.sum(dim=1, keepdim=True) + seq_lengths = torch.clamp(seq_lengths, min=1.0) + loss = (masked_loss.sum(dim=1) / 
seq_lengths.squeeze(-1)).mean() + + # Calculate metrics + metrics = ( + chosen_token_logprobs.mean(), # mean log prob + chosen_token_logprobs.std(), # std log prob + log_probs.mean(), # mean all log probs + (kl_div * attention_mask).sum(1).mean() / attention_mask.sum(1).mean(), # mean KL div + ) + + return loss, metrics @staticmethod def forward( @@ -69,60 +75,22 @@ def forward( compiled=True, use_ref_model=True, ): - """Forward pass for GRPO loss.""" - # Save tensors needed for backward - ctx.save_for_backward(_input, weight, attention_mask, bias) - ctx.beta = beta - ctx.rewards = rewards # Save rewards for use in backward pass - - # Get policy logits - batch_size, seq_len, hidden_size = _input.shape - input_reshaped = _input.view(-1, hidden_size) - policy_logits = (input_reshaped @ weight.t()).view(batch_size, seq_len, -1) - if bias is not None: - policy_logits = policy_logits + bias - - # Get reference logits if needed - ref_logits = None - if use_ref_model and ref_input is not None and ref_weight is not None: - ref_input_reshaped = ref_input.view(-1, ref_input.size(-1)) - ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1) - if ref_bias is not None: - ref_logits = ref_logits + ref_bias - - # Get log probabilities - log_probs = F.log_softmax(policy_logits, dim=-1) - seq_log_probs = (log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) - - # Get reference log probabilities - if ref_logits is not None: - ref_log_probs = F.log_softmax(ref_logits, dim=-1) - ref_seq_logps = (ref_log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) - else: - ref_seq_logps = seq_log_probs.detach() - - # Compute KL divergence - kl_div = seq_log_probs - ref_seq_logps - - # Compute loss - loss = LigerFusedLinearGRPOFunction.preference_loss_fn( - logits=policy_logits, + return LigerFusedLinearRLHFBase.forward( + ctx=ctx, + _input=_input, + weight=weight, attention_mask=attention_mask, + loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn, rewards=rewards, - ref_logits=ref_logits, + bias=bias, + ref_input=ref_input, + ref_weight=ref_weight, + ref_bias=ref_bias, beta=beta, + compiled=compiled, + use_ref_model=use_ref_model, ) - # Return metrics matching the PyTorch implementation - metrics = ( - seq_log_probs.mean(), # policy log probs mean - seq_log_probs.std(), # policy log probs std - policy_logits.mean(), # policy logits mean - kl_div.mean(), # KL divergence mean - ) - - return loss, metrics - @staticmethod def backward(ctx, grad_output, *grad_metrics): """Backward pass for GRPO loss. 
@@ -131,53 +99,10 @@ def backward(ctx, grad_output, *grad_metrics): grad_output: Gradient of the loss (scalar) grad_metrics: Gradients of the metrics (not used in backward computation) """ - _input, weight, attention_mask, bias = ctx.saved_tensors - beta = ctx.beta # Retrieve beta for KL scaling - rewards = ctx.rewards # Retrieve rewards for advantage computation - - # Initialize gradients - grad_input = grad_weight = grad_bias = None - - # Compute gradients using autograd - with torch.enable_grad(): - _input = _input.detach().requires_grad_() - weight = weight.detach().requires_grad_() - if bias is not None: - bias = bias.detach().requires_grad_() - - # Forward pass - batch_size, seq_len, hidden_size = _input.shape - input_reshaped = _input.view(-1, hidden_size) - logits = (input_reshaped @ weight.t()).view(batch_size, seq_len, -1) - if bias is not None: - logits = logits + bias - - # Compute log probabilities and sequence-level scores - log_probs = F.log_softmax(logits, dim=-1) - seq_log_probs = (log_probs.max(dim=-1).values * attention_mask).sum(dim=-1) - - # Compute advantages - advantages = rewards - rewards.mean() - if advantages.std() > 0: - advantages = advantages / advantages.std() - - # Policy gradient loss with KL penalty - policy_loss = -(advantages * seq_log_probs) - kl_div = seq_log_probs - seq_log_probs.detach() # KL divergence from current policy - loss = (policy_loss + beta * kl_div).mean() # Take mean to get scalar loss - - # Backward pass with scalar gradient - loss.backward(grad_output) - grad_input = _input.grad - grad_weight = weight.grad - grad_bias = bias.grad if bias is not None else None - + grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output) return ( - grad_input, - grad_weight, - None, # grad_attention_mask - None, # grad_rewards - grad_bias, + *grads[:4], # grad_input, grad_weight, grad_attention_mask, grad_rewards + None, # grad_bias None, # grad_ref_input None, # grad_ref_weight None, # grad_ref_bias diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index c331dd20e..56c992edf 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -35,19 +35,21 @@ def forward( ref_weight=None, ref_bias=None, ): - # Get policy logits and log probs + # Forward pass through linear layer batch_size, seq_len, hidden_size = x.shape input_reshaped = x.view(-1, hidden_size) logits = (input_reshaped @ self.lin.weight.t()).view(batch_size, seq_len, -1) if self.lin.bias is not None: logits = logits + self.lin.bias + + # Get log probabilities log_probs = F.log_softmax(logits, dim=-1) - # Get sequence-level log probs by taking max over vocab and summing over sequence - seq_log_probs = log_probs.max(dim=-1).values - policy_seq_logps = (seq_log_probs * attention_mask).sum(dim=-1) + # Get chosen token probabilities + chosen_tokens = log_probs.argmax(dim=-1) + chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1) - # Get reference model log probs if provided + # Get reference model probabilities if ref_input is not None and ref_weight is not None: with torch.no_grad(): ref_input_reshaped = ref_input.view(-1, ref_input.size(-1)) @@ -55,34 +57,42 @@ def forward( if ref_bias is not None: ref_logits = ref_logits + ref_bias ref_log_probs = F.log_softmax(ref_logits, dim=-1) - ref_seq_log_probs = ref_log_probs.max(dim=-1).values - ref_seq_logps = (ref_seq_log_probs * attention_mask).sum(dim=-1) + ref_token_logprobs = ref_log_probs.gather(dim=-1, 
index=chosen_tokens.unsqueeze(-1)).squeeze(-1) else: - ref_seq_logps = policy_seq_logps.detach() + ref_token_logprobs = chosen_token_logprobs.detach() + + # Compute advantages (exactly as in GRPOTrainer) + mean_grouped_rewards = rewards.mean() + std_grouped_rewards = rewards.std() + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - # Compute advantages - advantages = rewards - rewards.mean() - if advantages.std() > 0: - advantages = advantages / advantages.std() + # Compute policy gradient loss with importance sampling ratio + ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) + policy_loss = -ratio * advantages.unsqueeze(1) - # Policy gradient loss - policy_loss = -(advantages * policy_seq_logps) + # Compute KL penalty + kl_div = ( + torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0 + ) - # KL penalty - kl_div = policy_seq_logps - ref_seq_logps + # Combine losses + per_token_loss = policy_loss + self.beta * kl_div - # Total loss - loss = policy_loss + self.beta * kl_div + # Apply masking and normalize + masked_loss = per_token_loss * attention_mask + seq_lengths = attention_mask.sum(dim=1, keepdim=True) + seq_lengths = torch.clamp(seq_lengths, min=1.0) + loss = (masked_loss.sum(dim=1) / seq_lengths.squeeze(-1)).mean() - # Return metrics for logging + # Compute metrics metrics = ( - policy_seq_logps.mean(), # policy log probs mean - policy_seq_logps.std(), # policy log probs std - logits.mean(), # policy logits mean - kl_div.mean(), # KL divergence mean + chosen_token_logprobs.mean(), + chosen_token_logprobs.std(), + logits.mean(), + (kl_div * attention_mask).sum(1).mean() / attention_mask.sum(1).mean(), ) - return loss.mean(), metrics + return loss, metrics class LigerLMHeadGRPO(torch.nn.Module): @@ -185,16 +195,29 @@ def test_correctness( mask_indices = torch.randperm(B * T)[:num_elements_to_mask] attention_mask.view(-1)[mask_indices] = 0 - # Create rewards + # Create rewards with random values rewards = torch.randn(B, device=device, dtype=dtype) - # Forward pass - loss1, aux1 = torch_lm_head_grpo(input1, attention_mask, rewards) - loss2, aux2 = liger_lm_head_grpo(input2, attention_mask, rewards) + # Create reference inputs (optional) + ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_bias = torch.randn(V, device=device, dtype=dtype) if bias else None + + # Forward pass with reference model + loss1, aux1 = torch_lm_head_grpo( + input1, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias + ) + loss2, aux2 = liger_lm_head_grpo( + input2, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias + ) # Check losses match assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + # Check metrics match assert len(aux1) == len(aux2) + for metric1, metric2 in zip(aux1, aux2): + assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol) # Backward pass loss1.backward() @@ -215,3 +238,15 @@ def test_correctness( atol=atol, rtol=rtol, ) + + # Test without reference model + loss1, aux1 = torch_lm_head_grpo(input1, attention_mask, rewards) + loss2, aux2 = liger_lm_head_grpo(input2, attention_mask, rewards) + + # Check losses match (without reference model) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + # Check metrics match (without reference model) + assert len(aux1) == len(aux2) + for metric1, metric2 in zip(aux1, 
aux2): + assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol) From 1dfb03ac506a8fc3281ea83eb2c6b4613ab80de8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 16:25:39 +0100 Subject: [PATCH 07/15] add num_generations --- .../chunked_loss/fused_linear_rlhf.py | 68 +++++++-------- src/liger_kernel/chunked_loss/grpo_loss.py | 14 ++- test/chunked_loss/test_grpo_loss.py | 85 ++++++++++++------- 3 files changed, 96 insertions(+), 71 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py index caaa92e14..38933e4d8 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -14,7 +14,7 @@ def forward( rewards, bias=None, loss_fn=None, - chunk_size=1, + num_generations=1, beta=0.1, compiled=True, use_ref_model=False, @@ -34,38 +34,50 @@ def forward( grad_bias = torch.zeros_like(bias) if bias is not None else None # [V] aggregated_metrics = [] + # Create a partial function with fixed arguments compute_loss = partial( LigerFusedLinearRLHFBase._compute_chunk_loss, - rlhf_loss_fn=loss_fn, beta=beta, use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, - rewards=rewards, + rlhf_loss_fn=loss_fn, ) - def fused_fwd_bwd(input_chunk, attention_mask_chunk, ref_input_chunk): + def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk): """Fused forward and backward for a chunk.""" if bias is not None: return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 4), has_aux=True)( - input_chunk, weight, attention_mask_chunk, ref_input_chunk, bias + input_chunk, # arg 0 + weight, # arg 1 + attention_mask_chunk, # arg 2 + rewards_chunk, # arg 3 + ref_input_chunk, # arg 4 + bias, # arg 5 ) else: return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( - input_chunk, weight, attention_mask_chunk, ref_input_chunk + input_chunk, # arg 0 + weight, # arg 1 + attention_mask_chunk, # arg 2 + rewards_chunk, # arg 3 + ref_input_chunk, # arg 4 ) - def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None): - nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics + def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None): + # nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics if bias is not None: (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd( - input_chunk, attention_mask_chunk, ref_input_chunk + input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk ) + if chunk_grad_bias.shape != grad_bias.shape: + # Ensure we're summing to match the vocab size dimension + chunk_grad_bias = chunk_grad_bias.view(-1, grad_bias.shape[0]).sum(0) grad_bias.add_(chunk_grad_bias) else: (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd( - input_chunk, attention_mask_chunk, ref_input_chunk + input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk ) # Accumulate gradients and loss @@ -92,13 +104,14 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None): accumulate_chunk = torch.compile(accumulate_chunk) # Process input in chunks - chunks = max(1, _input.shape[0] // chunk_size) + chunks = max(1, _input.shape[0] // num_generations) _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0) + _rewards_chunks = 
torch.chunk(rewards, chunks=chunks, dim=0) _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks - for input_chunk, attention_mask_chunk, ref_input_chunk in zip( - _input_chunks, _attention_mask_chunks, _ref_input_chunks + for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip( + _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks ): # Mark dynamic dimensions torch._dynamo.mark_dynamic(input_chunk, 1) @@ -106,7 +119,7 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None): if ref_input_chunk is not None: torch._dynamo.mark_dynamic(ref_input_chunk, 1) - accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk) + accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk) # Combine gradients grad_input = torch.cat(grad_inputs, dim=0) @@ -129,44 +142,30 @@ def _compute_chunk_loss( input_chunk, weight, attention_mask_chunk, + rewards_chunk, ref_input_chunk=None, bias=None, - rlhf_loss_fn=None, beta=0.1, use_ref_model=False, ref_weight=None, ref_bias=None, - rewards=None, + rlhf_loss_fn=None, ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward - ( - log_probs, - logits, - logits_mean, - ) = LigerFusedLinearRLHFBase.chunk_forward( - input_chunk, - weight, - attention_mask_chunk, - bias=bias, - ) + (log_probs, _, logits_mean) = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) # Get reference log probabilities if needed ref_log_probs = None if use_ref_model and ref_input_chunk is not None: with torch.no_grad(): - ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward( - ref_input_chunk, - ref_weight, - attention_mask_chunk, - bias=ref_bias, - ) + ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias) # Compute chunk loss and metrics chunk_loss, chunk_metrics = rlhf_loss_fn( log_probs=log_probs, attention_mask=attention_mask_chunk, - rewards=rewards, + rewards=rewards_chunk, ref_log_probs=ref_log_probs, beta=beta, ) @@ -177,7 +176,6 @@ def _compute_chunk_loss( def chunk_forward( input_chunk, weight, - attention_mask_chunk, bias=None, ): """Forward pass computation for a single chunk.""" @@ -191,7 +189,7 @@ def chunk_forward( # Reshape to [B, T, V] and compute log_probs logits = logits.view(batch_size, seq_len, -1) - log_probs = F.log_softmax(logits, dim=-1) + log_probs = F.log_softmax(logits.float(), dim=-1) # Calculate mean logits for monitoring logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0]) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index d181e2379..e0d9f4b22 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -27,9 +27,11 @@ def rlhf_loss_fn( else: ref_token_logprobs = chosen_token_logprobs.detach() - # Compute advantages - mean_grouped_rewards = rewards.mean() - std_grouped_rewards = rewards.std() + # Compute advantages per batch entry in a grouped fashion + mean_grouped_rewards = rewards.mean() # [batch_size,] + std_grouped_rewards = rewards.std() # [batch_size,] + + # Calculate advantages using the same epsilon as in GRPOTrainer advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) # Compute policy gradient loss with importance sampling ratio @@ -74,6 +76,7 @@ def forward( beta=0.1, compiled=True, use_ref_model=True, + num_generations=1, ): return 
LigerFusedLinearRLHFBase.forward( ctx=ctx, @@ -89,6 +92,7 @@ def forward( beta=beta, compiled=compiled, use_ref_model=use_ref_model, + num_generations=num_generations, ) @staticmethod @@ -109,6 +113,7 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_beta None, # grad_compiled None, # grad_use_ref_model + None, # grad_num_generations ) @@ -120,11 +125,13 @@ def __init__( beta: float = 0.1, compiled: bool = True, use_ref_model: bool = True, + num_generations: int = 1, ): super().__init__() self.beta = beta self.compiled = compiled self.use_ref_model = use_ref_model + self.num_generations = num_generations def forward( self, @@ -149,4 +156,5 @@ def forward( self.beta, self.compiled, self.use_ref_model, + self.num_generations, ) diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index 56c992edf..b8841bb16 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -21,24 +21,29 @@ def __init__( dtype: torch.dtype, bias: bool = False, beta: float = 0.1, + num_generations: int = 4, ): super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.beta = beta + self.num_generations = num_generations def forward( self, - x, - attention_mask, - rewards, - ref_input=None, + x, # Shape: [batch_size*num_generations, seq_len, hidden_size] + attention_mask, # Shape: [batch_size*num_generations, seq_len] + rewards, # Shape: [batch_size*num_generations,] + ref_input=None, # Shape: [batch_size*num_generations, seq_len, hidden_size] ref_weight=None, ref_bias=None, ): # Forward pass through linear layer - batch_size, seq_len, hidden_size = x.shape + batch_size = x.shape[0] // self.num_generations # Get true batch size + seq_len = x.shape[1] + hidden_size = x.shape[2] + input_reshaped = x.view(-1, hidden_size) - logits = (input_reshaped @ self.lin.weight.t()).view(batch_size, seq_len, -1) + logits = (input_reshaped @ self.lin.weight.t()).view(batch_size * self.num_generations, seq_len, -1) if self.lin.bias is not None: logits = logits + self.lin.bias @@ -53,7 +58,7 @@ def forward( if ref_input is not None and ref_weight is not None: with torch.no_grad(): ref_input_reshaped = ref_input.view(-1, ref_input.size(-1)) - ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1) + ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size * self.num_generations, seq_len, -1) if ref_bias is not None: ref_logits = ref_logits + ref_bias ref_log_probs = F.log_softmax(ref_logits, dim=-1) @@ -61,28 +66,36 @@ def forward( else: ref_token_logprobs = chosen_token_logprobs.detach() - # Compute advantages (exactly as in GRPOTrainer) - mean_grouped_rewards = rewards.mean() - std_grouped_rewards = rewards.std() - advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - - # Compute policy gradient loss with importance sampling ratio - ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) - policy_loss = -ratio * advantages.unsqueeze(1) - - # Compute KL penalty + # Compute KL divergence between model and reference model kl_div = ( torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0 ) - # Combine losses - per_token_loss = policy_loss + self.beta * kl_div + # Compute advantages per batch entry in a grouped fashion + # rewards shape: [batch_size, num_generations] + mean_grouped_rewards = rewards.view(batch_size, self.num_generations).mean( + dim=1, keepdim=True + ) # [batch_size, 1] 
+ std_grouped_rewards = rewards.view(batch_size, self.num_generations).std(dim=1, keepdim=True) # [batch_size, 1] + + # Expand means and stds to match the number of generations + mean_grouped_rewards = mean_grouped_rewards.expand(-1, self.num_generations).reshape( + -1 + ) # [batch_size * num_generations] + std_grouped_rewards = std_grouped_rewards.expand(-1, self.num_generations).reshape( + -1 + ) # [batch_size * num_generations] + + # Calculate advantages using the same epsilon as in GRPOTrainer + rewards_flat = rewards.view(-1) # [batch_size * num_generations] + advantages = (rewards_flat - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # Compute policy gradient loss with importance sampling ratio + per_token_loss = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * kl_div) # Apply masking and normalize - masked_loss = per_token_loss * attention_mask - seq_lengths = attention_mask.sum(dim=1, keepdim=True) - seq_lengths = torch.clamp(seq_lengths, min=1.0) - loss = (masked_loss.sum(dim=1) / seq_lengths.squeeze(-1)).mean() + loss = ((per_token_loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean() # Compute metrics metrics = ( @@ -103,11 +116,13 @@ def __init__( dtype: torch.dtype, bias: bool = False, beta: float = 0.1, + num_generations: int = 4, ): super().__init__() self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.grpo_loss = LigerFusedLinearGRPOFunction.apply self.beta = beta + self.num_generations = num_generations def forward( self, @@ -131,6 +146,7 @@ def forward( self.beta, # beta True, # compiled ref_input is not None, # use_ref_model + self.num_generations, # num_generations ) @@ -162,12 +178,14 @@ def test_correctness( bias, beta, ): + num_generations = 4 # Fixed number of generations for testing torch_lm_head_grpo = TorchLMHeadGRPO( H=H, V=V, dtype=dtype, bias=bias, beta=beta, + num_generations=num_generations, ) liger_lm_head_grpo = LigerLMHeadGRPO( H=H, @@ -175,6 +193,7 @@ def test_correctness( dtype=dtype, bias=bias, beta=beta, + num_generations=num_generations, ) # Initialize weights @@ -184,22 +203,22 @@ def test_correctness( if bias: torch_lm_head_grpo.lin.bias.data = liger_lm_head_grpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) - # Create inputs - _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar + # Create inputs with shape [B*num_generations, T, H] + _input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) - # Create attention mask with random padding - attention_mask = torch.ones(B, T, device=device) - num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item() - mask_indices = torch.randperm(B * T)[:num_elements_to_mask] + # Create attention mask with random padding [B*num_generations, T] + attention_mask = torch.ones(B * num_generations, T, device=device) + num_elements_to_mask = torch.randint(1, B * num_generations * T // 2, (1,)).item() + mask_indices = torch.randperm(B * num_generations * T)[:num_elements_to_mask] attention_mask.view(-1)[mask_indices] = 0 - # Create rewards with random values - rewards = torch.randn(B, device=device, dtype=dtype) + # Create rewards with shape [B, num_generations] + rewards = torch.randn(B * num_generations, device=device, dtype=dtype) - # Create reference inputs (optional) - ref_input = torch.randn(B, T, H, 
device=device, dtype=dtype) * scalar + # Create reference inputs (optional) with shape [B*num_generations, T, H] + ref_input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar ref_weight = torch.randn(V, H, device=device, dtype=dtype) ref_bias = torch.randn(V, device=device, dtype=dtype) if bias else None From db962c9f1ce75cc6fbd12fb5adacecb36db36520 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 18:52:13 +0100 Subject: [PATCH 08/15] fix loss --- .../chunked_loss/fused_linear_rlhf.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py index 38933e4d8..42569beb1 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -47,7 +47,7 @@ def forward( def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk): """Fused forward and backward for a chunk.""" if bias is not None: - return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 4), has_aux=True)( + return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)( input_chunk, # arg 0 weight, # arg 1 attention_mask_chunk, # arg 2 @@ -65,15 +65,10 @@ def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_ch ) def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None): - # nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics - if bias is not None: (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd( input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk ) - if chunk_grad_bias.shape != grad_bias.shape: - # Ensure we're summing to match the vocab size dimension - chunk_grad_bias = chunk_grad_bias.view(-1, grad_bias.shape[0]).sum(0) grad_bias.add_(chunk_grad_bias) else: (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd( @@ -153,7 +148,7 @@ def _compute_chunk_loss( ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward - (log_probs, _, logits_mean) = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) + log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) # Get reference log probabilities if needed ref_log_probs = None @@ -161,7 +156,7 @@ def _compute_chunk_loss( with torch.no_grad(): ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias) - # Compute chunk loss and metrics + # Compute chunk loss and metrics using the provided loss function chunk_loss, chunk_metrics = rlhf_loss_fn( log_probs=log_probs, attention_mask=attention_mask_chunk, @@ -173,16 +168,12 @@ def _compute_chunk_loss( return chunk_loss, (logits_mean, *chunk_metrics) @staticmethod - def chunk_forward( - input_chunk, - weight, - bias=None, - ): + def chunk_forward(input_chunk, weight, bias=None): """Forward pass computation for a single chunk.""" batch_size, seq_len, hidden_size = input_chunk.shape input_reshaped = input_chunk.view(-1, hidden_size) # [B*T, H] - # Linear layer: [B*T, H] @ [V, H].T -> [B*T, V] + # Linear layer: [B*T, H] @ [H, V] -> [B*T, V] logits = F.linear(input_reshaped, weight) # weight shape is [V, H] if bias is not None: logits = logits + bias.view(1, -1) From c66261eb9196d47c544c64c1289137ea07de1f16 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: 
Mon, 3 Feb 2025 19:23:38 +0100 Subject: [PATCH 09/15] aux match the HF loss aux metrics --- src/liger_kernel/chunked_loss/fused_linear_rlhf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py index 42569beb1..25510c91b 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -148,7 +148,7 @@ def _compute_chunk_loss( ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward - log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) + log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) # Get reference log probabilities if needed ref_log_probs = None @@ -165,7 +165,7 @@ def _compute_chunk_loss( beta=beta, ) - return chunk_loss, (logits_mean, *chunk_metrics) + return chunk_loss, chunk_metrics @staticmethod def chunk_forward(input_chunk, weight, bias=None): From dbf2fd114e24b7628e85b3f2d4494994c7e3a2bc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 19:32:40 +0100 Subject: [PATCH 10/15] return the same metrics --- src/liger_kernel/chunked_loss/fused_linear_rlhf.py | 4 ++-- test/chunked_loss/test_grpo_loss.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py index 25510c91b..42569beb1 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -148,7 +148,7 @@ def _compute_chunk_loss( ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward - log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) + log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) # Get reference log probabilities if needed ref_log_probs = None @@ -165,7 +165,7 @@ def _compute_chunk_loss( beta=beta, ) - return chunk_loss, chunk_metrics + return chunk_loss, (logits_mean, *chunk_metrics) @staticmethod def chunk_forward(input_chunk, weight, bias=None): diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index b8841bb16..08e6fc90e 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -99,9 +99,10 @@ def forward( # Compute metrics metrics = ( + logits.mean(), chosen_token_logprobs.mean(), chosen_token_logprobs.std(), - logits.mean(), + log_probs.mean(), (kl_div * attention_mask).sum(1).mean() / attention_mask.sum(1).mean(), ) @@ -178,7 +179,7 @@ def test_correctness( bias, beta, ): - num_generations = 4 # Fixed number of generations for testing + num_generations = 1 # Fixed number of generations for testing torch_lm_head_grpo = TorchLMHeadGRPO( H=H, V=V, From e805a8bafcaa1cd2063df0cc200b11d70df1bb89 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 21:06:55 +0100 Subject: [PATCH 11/15] fix for num_gen = 1 --- src/liger_kernel/chunked_loss/grpo_loss.py | 3 +-- test/chunked_loss/test_grpo_loss.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index e0d9f4b22..e67157cbb 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -105,8 +105,7 @@ def backward(ctx, 
grad_output, *grad_metrics): """ grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output) return ( - *grads[:4], # grad_input, grad_weight, grad_attention_mask, grad_rewards - None, # grad_bias + *grads[:5], # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias None, # grad_ref_input None, # grad_ref_weight None, # grad_ref_bias diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index 08e6fc90e..b38d85651 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -103,7 +103,7 @@ def forward( chosen_token_logprobs.mean(), chosen_token_logprobs.std(), log_probs.mean(), - (kl_div * attention_mask).sum(1).mean() / attention_mask.sum(1).mean(), + ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(), ) return loss, metrics @@ -162,7 +162,7 @@ def forward( "scalar, dtype, atol, rtol", [ (1.0, torch.bfloat16, 5e-2, 5e-2), - (1.0, torch.float32, 1e-5, 5e-4), + (1.0, torch.float32, 5e-3, 5e-3), ], ) @pytest.mark.parametrize("bias", [True, False]) From 3f09cee275e661de78f64ce911af7c0ca5616497 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 21:14:48 +0100 Subject: [PATCH 12/15] scale the loss --- src/liger_kernel/chunked_loss/grpo_loss.py | 2 +- test/chunked_loss/test_grpo_loss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index e67157cbb..7236d31fb 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -57,7 +57,7 @@ def rlhf_loss_fn( chosen_token_logprobs.mean(), # mean log prob chosen_token_logprobs.std(), # std log prob log_probs.mean(), # mean all log probs - (kl_div * attention_mask).sum(1).mean() / attention_mask.sum(1).mean(), # mean KL div + ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(), # mean KL div ) return loss, metrics diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index b38d85651..ae7e438ff 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -179,7 +179,7 @@ def test_correctness( bias, beta, ): - num_generations = 1 # Fixed number of generations for testing + num_generations = 4 # Fixed number of generations for testing torch_lm_head_grpo = TorchLMHeadGRPO( H=H, V=V, From 51106ee96af8e1e2815cea346c5d592764d51d61 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 3 Feb 2025 21:16:33 +0100 Subject: [PATCH 13/15] scale the loss --- src/liger_kernel/chunked_loss/fused_linear_rlhf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py index 42569beb1..8976d9dbd 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -116,6 +116,9 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk) + # Scale accumulated loss by number of chunks since we're averaging + loss_acc = loss_acc / chunks + # Combine gradients grad_input = torch.cat(grad_inputs, dim=0) From f18847855770346336580ce40f5a473b8b71451c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Feb 2025 09:10:57 +0100 Subject: [PATCH 14/15] use same epsilon as TRL --- src/liger_kernel/chunked_loss/grpo_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index 7236d31fb..593baf613 100644 --- a/src/liger_kernel/chunked_loss/grpo_loss.py +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -32,7 +32,7 @@ def rlhf_loss_fn( std_grouped_rewards = rewards.std() # [batch_size,] # Calculate advantages using the same epsilon as in GRPOTrainer - advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) # Compute policy gradient loss with importance sampling ratio ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) From 3a3f928e002cabdb8127faee4788b1aa157756a3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 4 Feb 2025 13:24:21 +0100 Subject: [PATCH 15/15] relax test thresholds --- test/chunked_loss/test_grpo_loss.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py index ae7e438ff..afa4b495b 100644 --- a/test/chunked_loss/test_grpo_loss.py +++ b/test/chunked_loss/test_grpo_loss.py @@ -162,11 +162,12 @@ def forward( "scalar, dtype, atol, rtol", [ (1.0, torch.bfloat16, 5e-2, 5e-2), - (1.0, torch.float32, 5e-3, 5e-3), + (1.0, torch.float32, 5e-2, 5e-2), ], ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("beta", [0.1, 0.2]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("beta", [0.1, 0.9]) def test_correctness( B, T, @@ -177,6 +178,7 @@ def test_correctness( atol, rtol, bias, + ref_bias, beta, ): num_generations = 4 # Fixed number of generations for testing @@ -216,19 +218,19 @@ def test_correctness( attention_mask.view(-1)[mask_indices] = 0 # Create rewards with shape [B, num_generations] - rewards = torch.randn(B * num_generations, device=device, dtype=dtype) + rewards = torch.rand(B * num_generations, device=device, dtype=dtype) # Create reference inputs (optional) with shape [B*num_generations, T, H] ref_input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar ref_weight = torch.randn(V, H, device=device, dtype=dtype) - ref_bias = torch.randn(V, device=device, dtype=dtype) if bias else None + ref_bias_weight = torch.randn(V, device=device, dtype=dtype) if ref_bias else None # Forward pass with reference model loss1, aux1 = torch_lm_head_grpo( - input1, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias + input1, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias_weight ) loss2, aux2 = liger_lm_head_grpo( - input2, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias + input2, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias_weight ) # Check losses match
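
For reference, the objective these patches converge on — group-normalised advantages with the 1e-4 epsilon from TRL's GRPOTrainer, the exp(q-p) - (q-p) - 1 KL penalty against the reference model, and a per-sequence masked mean — can be written as a small eager-mode helper. The sketch below is illustrative only and is not part of the library: the name grpo_reference_loss and the use of input_ids to gather the sampled-token log-probabilities are assumptions made for the example, since the handling of the target ids lives elsewhere in the test code.

    import torch
    import torch.nn.functional as F

    def grpo_reference_loss(logits, ref_logits, input_ids, attention_mask, rewards,
                            num_generations, beta=0.1, eps=1e-4):
        # Hypothetical standalone helper, not the library API.
        # logits / ref_logits: [B*G, T, V]; input_ids, attention_mask: [B*G, T]; rewards: [B*G]
        log_probs = F.log_softmax(logits.float(), dim=-1)
        token_logp = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)  # [B*G, T]

        with torch.no_grad():
            ref_log_probs = F.log_softmax(ref_logits.float(), dim=-1)
            ref_token_logp = ref_log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)

        # Per-token KL penalty in the same form the patches use:
        # exp(ref - pi) - (ref - pi) - 1
        kl = torch.exp(ref_token_logp - token_logp) - (ref_token_logp - token_logp) - 1.0

        # Group-normalised advantages: mean/std over the G generations of each prompt
        grouped = rewards.view(-1, num_generations)                      # [B, G]
        mean = grouped.mean(dim=1, keepdim=True).expand_as(grouped).reshape(-1)
        std = grouped.std(dim=1, keepdim=True).expand_as(grouped).reshape(-1)
        advantages = (rewards - mean) / (std + eps)                      # [B*G]

        # Policy-gradient term; the ratio equals 1 in value but keeps the gradient path
        ratio = torch.exp(token_logp - token_logp.detach())
        per_token = -(ratio * advantages.unsqueeze(1) - beta * kl)       # [B*G, T]

        # Masked mean per sequence, then mean over the batch
        return ((per_token * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean()

Compared with this eager version, the fused path in these patches never materialises the full [B*G, T, V] logits at once: LigerFusedLinearRLHFBase splits the flattened batch into one chunk per prompt group (num_generations rows each), runs the lm_head projection and the loss inside each chunk's fused forward/backward, and divides the accumulated loss by the number of chunks so the result matches the batch-level average above.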