diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py
index caaa92e14..38933e4d8 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py
@@ -14,7 +14,7 @@ def forward(
         rewards,
         bias=None,
         loss_fn=None,
-        chunk_size=1,
+        num_generations=1,
         beta=0.1,
         compiled=True,
         use_ref_model=False,
@@ -34,38 +34,50 @@ def forward(
         grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
         aggregated_metrics = []

+        # Create a partial function with fixed arguments
         compute_loss = partial(
             LigerFusedLinearRLHFBase._compute_chunk_loss,
-            rlhf_loss_fn=loss_fn,
             beta=beta,
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
-            rewards=rewards,
+            rlhf_loss_fn=loss_fn,
         )

-        def fused_fwd_bwd(input_chunk, attention_mask_chunk, ref_input_chunk):
+        def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
             """Fused forward and backward for a chunk."""
             if bias is not None:
-                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 4), has_aux=True)(
-                    input_chunk, weight, attention_mask_chunk, ref_input_chunk, bias
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
+                    input_chunk,  # arg 0
+                    weight,  # arg 1
+                    attention_mask_chunk,  # arg 2
+                    rewards_chunk,  # arg 3
+                    ref_input_chunk,  # arg 4
+                    bias,  # arg 5
                 )
             else:
                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                    input_chunk, weight, attention_mask_chunk, ref_input_chunk
+                    input_chunk,  # arg 0
+                    weight,  # arg 1
+                    attention_mask_chunk,  # arg 2
+                    rewards_chunk,  # arg 3
+                    ref_input_chunk,  # arg 4
                 )

-        def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None):
-            nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics
+        def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
+            # nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics
             if bias is not None:
                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                    input_chunk, attention_mask_chunk, ref_input_chunk
+                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
                 )
+                if chunk_grad_bias.shape != grad_bias.shape:
+                    # Ensure we're summing to match the vocab size dimension
+                    chunk_grad_bias = chunk_grad_bias.view(-1, grad_bias.shape[0]).sum(0)
                 grad_bias.add_(chunk_grad_bias)
             else:
                 (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                    input_chunk, attention_mask_chunk, ref_input_chunk
+                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
                 )

             # Accumulate gradients and loss
@@ -92,13 +104,14 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None):
             accumulate_chunk = torch.compile(accumulate_chunk)

         # Process input in chunks
-        chunks = max(1, _input.shape[0] // chunk_size)
+        chunks = max(1, _input.shape[0] // num_generations)
         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+        _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
         _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks

-        for input_chunk, attention_mask_chunk, ref_input_chunk in zip(
-            _input_chunks, _attention_mask_chunks, _ref_input_chunks
+        for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
+            _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
         ):
             # Mark dynamic dimensions
             torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -106,7 +119,7 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None):
             if ref_input_chunk is not None:
                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)

-            accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk)
+            accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)

         # Combine gradients
         grad_input = torch.cat(grad_inputs, dim=0)
@@ -129,44 +142,30 @@ def _compute_chunk_loss(
         input_chunk,
         weight,
         attention_mask_chunk,
+        rewards_chunk,
         ref_input_chunk=None,
         bias=None,
-        rlhf_loss_fn=None,
         beta=0.1,
         use_ref_model=False,
         ref_weight=None,
         ref_bias=None,
-        rewards=None,
+        rlhf_loss_fn=None,
     ):
         """Compute loss for a single chunk."""
         # Get policy log probabilities using chunk_forward
-        (
-            log_probs,
-            logits,
-            logits_mean,
-        ) = LigerFusedLinearRLHFBase.chunk_forward(
-            input_chunk,
-            weight,
-            attention_mask_chunk,
-            bias=bias,
-        )
+        (log_probs, _, logits_mean) = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)

         # Get reference log probabilities if needed
         ref_log_probs = None
         if use_ref_model and ref_input_chunk is not None:
             with torch.no_grad():
-                ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(
-                    ref_input_chunk,
-                    ref_weight,
-                    attention_mask_chunk,
-                    bias=ref_bias,
-                )
+                ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)

         # Compute chunk loss and metrics
         chunk_loss, chunk_metrics = rlhf_loss_fn(
             log_probs=log_probs,
             attention_mask=attention_mask_chunk,
-            rewards=rewards,
+            rewards=rewards_chunk,
             ref_log_probs=ref_log_probs,
             beta=beta,
         )
@@ -177,7 +176,6 @@ def _compute_chunk_loss(
     def chunk_forward(
         input_chunk,
         weight,
-        attention_mask_chunk,
         bias=None,
     ):
         """Forward pass computation for a single chunk."""
@@ -191,7 +189,7 @@ def chunk_forward(
         # Reshape to [B, T, V] and compute log_probs
         logits = logits.view(batch_size, seq_len, -1)
-        log_probs = F.log_softmax(logits, dim=-1)
+        log_probs = F.log_softmax(logits.float(), dim=-1)

         # Calculate mean logits for monitoring
         logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py
index d181e2379..e0d9f4b22 100644
--- a/src/liger_kernel/chunked_loss/grpo_loss.py
+++ b/src/liger_kernel/chunked_loss/grpo_loss.py
@@ -27,9 +27,11 @@ def rlhf_loss_fn(
         else:
             ref_token_logprobs = chosen_token_logprobs.detach()

-        # Compute advantages
-        mean_grouped_rewards = rewards.mean()
-        std_grouped_rewards = rewards.std()
+        # Compute advantages per group: each chunk holds the num_generations completions of one prompt
+        mean_grouped_rewards = rewards.mean()  # scalar mean over the group
+        std_grouped_rewards = rewards.std()  # scalar std over the group
+
+        # Normalize rewards within the group to get advantages
         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)

         # Compute policy gradient loss with importance sampling ratio
@@ -74,6 +76,7 @@ def forward(
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        num_generations=1,
     ):
         return LigerFusedLinearRLHFBase.forward(
             ctx=ctx,
@@ -89,6 +92,7 @@ def forward(
             beta=beta,
             compiled=compiled,
             use_ref_model=use_ref_model,
+            num_generations=num_generations,
         )

     @staticmethod
@@ -109,6 +113,7 @@ def backward(ctx, grad_output, *grad_metrics):
             None,  # grad_beta
             None,  # grad_compiled
             None,  # grad_use_ref_model
+            None,  # grad_num_generations
         )


@@ -120,11 +125,13 @@ def __init__(
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = True,
+        num_generations: int = 1,
     ):
         super().__init__()
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.num_generations = num_generations

     def forward(
         self,
@@ -149,4 +156,5 @@ def forward(
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.num_generations,
         )
diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py
index 56c992edf..b8841bb16 100644
--- a/test/chunked_loss/test_grpo_loss.py
+++ b/test/chunked_loss/test_grpo_loss.py
@@ -21,24 +21,29 @@ def __init__(
         dtype: torch.dtype,
         bias: bool = False,
         beta: float = 0.1,
+        num_generations: int = 4,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.beta = beta
+        self.num_generations = num_generations

     def forward(
         self,
-        x,
-        attention_mask,
-        rewards,
-        ref_input=None,
+        x,  # Shape: [batch_size*num_generations, seq_len, hidden_size]
+        attention_mask,  # Shape: [batch_size*num_generations, seq_len]
+        rewards,  # Shape: [batch_size*num_generations,]
+        ref_input=None,  # Shape: [batch_size*num_generations, seq_len, hidden_size]
         ref_weight=None,
         ref_bias=None,
     ):
         # Forward pass through linear layer
-        batch_size, seq_len, hidden_size = x.shape
+        batch_size = x.shape[0] // self.num_generations  # Get true batch size
+        seq_len = x.shape[1]
+        hidden_size = x.shape[2]
+
         input_reshaped = x.view(-1, hidden_size)
-        logits = (input_reshaped @ self.lin.weight.t()).view(batch_size, seq_len, -1)
+        logits = (input_reshaped @ self.lin.weight.t()).view(batch_size * self.num_generations, seq_len, -1)
         if self.lin.bias is not None:
             logits = logits + self.lin.bias
@@ -53,7 +58,7 @@ def forward(
         if ref_input is not None and ref_weight is not None:
             with torch.no_grad():
                 ref_input_reshaped = ref_input.view(-1, ref_input.size(-1))
-                ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1)
+                ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size * self.num_generations, seq_len, -1)
                 if ref_bias is not None:
                     ref_logits = ref_logits + ref_bias
                 ref_log_probs = F.log_softmax(ref_logits, dim=-1)
@@ -61,28 +66,36 @@ def forward(
         else:
             ref_token_logprobs = chosen_token_logprobs.detach()

-        # Compute advantages (exactly as in GRPOTrainer)
-        mean_grouped_rewards = rewards.mean()
-        std_grouped_rewards = rewards.std()
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
-
-        # Compute policy gradient loss with importance sampling ratio
-        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
-        policy_loss = -ratio * advantages.unsqueeze(1)
-
-        # Compute KL penalty
+        # Compute KL divergence between model and reference model
         kl_div = (
             torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
         )

-        # Combine losses
-        per_token_loss = policy_loss + self.beta * kl_div
+        # Compute advantages per batch entry in a grouped fashion
+        # rewards shape: [batch_size, num_generations]
+        mean_grouped_rewards = rewards.view(batch_size, self.num_generations).mean(
+            dim=1, keepdim=True
+        )  # [batch_size, 1]
+        std_grouped_rewards = rewards.view(batch_size, self.num_generations).std(dim=1, keepdim=True)  # [batch_size, 1]
+
+        # Expand means and stds to match the number of generations
+        mean_grouped_rewards = mean_grouped_rewards.expand(-1, self.num_generations).reshape(
+            -1
+        )  # [batch_size * num_generations]
+        std_grouped_rewards = std_grouped_rewards.expand(-1, self.num_generations).reshape(
+            -1
+        )  # [batch_size * num_generations]
+
+        # Calculate advantages using the same epsilon as in GRPOTrainer
+        rewards_flat = rewards.view(-1)  # [batch_size * num_generations]
+        advantages = (rewards_flat - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+        # Compute policy gradient loss with importance sampling ratio
+        per_token_loss = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) * advantages.unsqueeze(1)
+        per_token_loss = -(per_token_loss - self.beta * kl_div)

         # Apply masking and normalize
-        masked_loss = per_token_loss * attention_mask
-        seq_lengths = attention_mask.sum(dim=1, keepdim=True)
-        seq_lengths = torch.clamp(seq_lengths, min=1.0)
-        loss = (masked_loss.sum(dim=1) / seq_lengths.squeeze(-1)).mean()
+        loss = ((per_token_loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean()

         # Compute metrics
         metrics = (
@@ -103,11 +116,13 @@ def __init__(
         dtype: torch.dtype,
         bias: bool = False,
         beta: float = 0.1,
+        num_generations: int = 4,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.grpo_loss = LigerFusedLinearGRPOFunction.apply
         self.beta = beta
+        self.num_generations = num_generations

     def forward(
         self,
@@ -131,6 +146,7 @@ def forward(
             self.beta,  # beta
             True,  # compiled
             ref_input is not None,  # use_ref_model
+            self.num_generations,  # num_generations
         )


@@ -162,12 +178,14 @@ def test_correctness(
     bias,
     beta,
 ):
+    num_generations = 4  # Fixed number of generations for testing
     torch_lm_head_grpo = TorchLMHeadGRPO(
         H=H,
         V=V,
         dtype=dtype,
         bias=bias,
         beta=beta,
+        num_generations=num_generations,
     )
     liger_lm_head_grpo = LigerLMHeadGRPO(
         H=H,
@@ -175,6 +193,7 @@ def test_correctness(
         dtype=dtype,
         bias=bias,
         beta=beta,
+        num_generations=num_generations,
     )

     # Initialize weights
@@ -184,22 +203,22 @@ def test_correctness(
     if bias:
         torch_lm_head_grpo.lin.bias.data = liger_lm_head_grpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype)

-    # Create inputs
-    _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+    # Create inputs with shape [B*num_generations, T, H]
+    _input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar
     input1 = _input.detach().clone().requires_grad_(True)
     input2 = _input.detach().clone().requires_grad_(True)

-    # Create attention mask with random padding
-    attention_mask = torch.ones(B, T, device=device)
-    num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item()
-    mask_indices = torch.randperm(B * T)[:num_elements_to_mask]
+    # Create attention mask with random padding [B*num_generations, T]
+    attention_mask = torch.ones(B * num_generations, T, device=device)
+    num_elements_to_mask = torch.randint(1, B * num_generations * T // 2, (1,)).item()
+    mask_indices = torch.randperm(B * num_generations * T)[:num_elements_to_mask]
     attention_mask.view(-1)[mask_indices] = 0

-    # Create rewards with random values
-    rewards = torch.randn(B, device=device, dtype=dtype)
+    # Create rewards with shape [B*num_generations,]
+    rewards = torch.randn(B * num_generations, device=device, dtype=dtype)

-    # Create reference inputs (optional)
-    ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
+    # Create reference inputs (optional) with shape [B*num_generations, T, H]
+    ref_input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar
     ref_weight = torch.randn(V, H, device=device, dtype=dtype)
     ref_bias = torch.randn(V, device=device, dtype=dtype) if bias else None
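For review context: the behavioral core of this patch is that rewards are chunked per group of `num_generations` completions (one prompt per chunk) and advantages are normalized within that group rather than across the whole batch. Below is a minimal standalone sketch of that grouped normalization, assuming a flat `[batch_size * num_generations]` reward layout with each prompt's completions stored contiguously; the helper name `grouped_advantages` is illustrative only, and the `1e-4` epsilon mirrors the test reference (the kernel path keeps `1e-8`).

```python
# Sketch only: grouped advantage normalization as exercised by the test above.
# Assumes rewards are flat [batch_size * num_generations] with each prompt's
# num_generations completions stored contiguously.
import torch


def grouped_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    grouped = rewards.view(-1, num_generations)        # [batch_size, num_generations]
    mean = grouped.mean(dim=1, keepdim=True)           # per-prompt mean
    std = grouped.std(dim=1, keepdim=True)             # per-prompt std
    return ((grouped - mean) / (std + eps)).view(-1)   # back to [batch_size * num_generations]


if __name__ == "__main__":
    torch.manual_seed(0)
    rewards = torch.randn(2 * 4)  # batch_size=2, num_generations=4
    adv = grouped_advantages(rewards, num_generations=4)
    # Each group of 4 is normalized independently: mean ~ 0, std ~ 1 per group.
    print(adv.view(2, 4).mean(dim=1), adv.view(2, 4).std(dim=1))
```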