
Commit

add num_generations
kashif committed Feb 3, 2025
1 parent 212263b commit f5f3157
Showing 3 changed files with 96 additions and 71 deletions.
68 changes: 33 additions & 35 deletions src/liger_kernel/chunked_loss/fused_linear_rlhf.py
@@ -14,7 +14,7 @@ def forward(
rewards,
bias=None,
loss_fn=None,
chunk_size=1,
num_generations=1,
beta=0.1,
compiled=True,
use_ref_model=False,
@@ -34,38 +34,50 @@ def forward(
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
aggregated_metrics = []

# Create a partial function with fixed arguments
compute_loss = partial(
LigerFusedLinearRLHFBase._compute_chunk_loss,
rlhf_loss_fn=loss_fn,
beta=beta,
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
rewards=rewards,
rlhf_loss_fn=loss_fn,
)

def fused_fwd_bwd(input_chunk, attention_mask_chunk, ref_input_chunk):
def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
"""Fused forward and backward for a chunk."""
if bias is not None:
return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 4), has_aux=True)(
input_chunk, weight, attention_mask_chunk, ref_input_chunk, bias
input_chunk, # arg 0
weight, # arg 1
attention_mask_chunk, # arg 2
rewards_chunk, # arg 3
ref_input_chunk, # arg 4
bias, # arg 5
)
else:
return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
input_chunk, weight, attention_mask_chunk, ref_input_chunk
input_chunk, # arg 0
weight, # arg 1
attention_mask_chunk, # arg 2
rewards_chunk, # arg 3
ref_input_chunk, # arg 4
)

def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None):
nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics
def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
# nonlocal loss_acc, grad_weight, grad_inputs, grad_bias, aggregated_metrics

if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
input_chunk, attention_mask_chunk, ref_input_chunk
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
)
if chunk_grad_bias.shape != grad_bias.shape:
# Ensure we're summing to match the vocab size dimension
chunk_grad_bias = chunk_grad_bias.view(-1, grad_bias.shape[0]).sum(0)
grad_bias.add_(chunk_grad_bias)
else:
(chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
input_chunk, attention_mask_chunk, ref_input_chunk
input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
)

# Accumulate gradients and loss
@@ -92,21 +104,22 @@ def accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk=None):
accumulate_chunk = torch.compile(accumulate_chunk)

# Process input in chunks
chunks = max(1, _input.shape[0] // chunk_size)
chunks = max(1, _input.shape[0] // num_generations)
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
_rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
_ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks

for input_chunk, attention_mask_chunk, ref_input_chunk in zip(
_input_chunks, _attention_mask_chunks, _ref_input_chunks
for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
_input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
):
# Mark dynamic dimensions
torch._dynamo.mark_dynamic(input_chunk, 1)
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
if ref_input_chunk is not None:
torch._dynamo.mark_dynamic(ref_input_chunk, 1)

accumulate_chunk(input_chunk, attention_mask_chunk, ref_input_chunk)
accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)

# Combine gradients
grad_input = torch.cat(grad_inputs, dim=0)
@@ -129,44 +142,30 @@ def _compute_chunk_loss(
input_chunk,
weight,
attention_mask_chunk,
rewards_chunk,
ref_input_chunk=None,
bias=None,
rlhf_loss_fn=None,
beta=0.1,
use_ref_model=False,
ref_weight=None,
ref_bias=None,
rewards=None,
rlhf_loss_fn=None,
):
"""Compute loss for a single chunk."""
# Get policy log probabilities using chunk_forward
(
log_probs,
logits,
logits_mean,
) = LigerFusedLinearRLHFBase.chunk_forward(
input_chunk,
weight,
attention_mask_chunk,
bias=bias,
)
(log_probs, _, logits_mean) = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)

# Get reference log probabilities if needed
ref_log_probs = None
if use_ref_model and ref_input_chunk is not None:
with torch.no_grad():
ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(
ref_input_chunk,
ref_weight,
attention_mask_chunk,
bias=ref_bias,
)
ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)

# Compute chunk loss and metrics
chunk_loss, chunk_metrics = rlhf_loss_fn(
log_probs=log_probs,
attention_mask=attention_mask_chunk,
rewards=rewards,
rewards=rewards_chunk,
ref_log_probs=ref_log_probs,
beta=beta,
)
@@ -177,7 +176,6 @@ def _compute_chunk_loss(
def chunk_forward(
input_chunk,
weight,
attention_mask_chunk,
bias=None,
):
"""Forward pass computation for a single chunk."""
@@ -191,7 +189,7 @@ def chunk_forward(

# Reshape to [B, T, V] and compute log_probs
logits = logits.view(batch_size, seq_len, -1)
log_probs = F.log_softmax(logits, dim=-1)
log_probs = F.log_softmax(logits.float(), dim=-1)

# Calculate mean logits for monitoring
logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
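Note: the chunking above now groups the flattened batch by num_generations, so each chunk carries all completions (and rewards) of a single prompt. A minimal sketch of that grouping, not part of the commit, with made-up shapes and tensor names:

import torch

# Hypothetical sizes: 2 prompts, 4 generations per prompt, seq_len 3, hidden 8.
batch_size, num_generations, seq_len, hidden = 2, 4, 3, 8

# Inputs arrive flattened as [batch_size * num_generations, ...].
hidden_states = torch.randn(batch_size * num_generations, seq_len, hidden)
rewards = torch.randn(batch_size * num_generations)

# Mirrors `chunks = max(1, _input.shape[0] // num_generations)` above:
# one chunk per prompt group.
chunks = max(1, hidden_states.shape[0] // num_generations)
input_chunks = torch.chunk(hidden_states, chunks=chunks, dim=0)
reward_chunks = torch.chunk(rewards, chunks=chunks, dim=0)

for x_chunk, r_chunk in zip(input_chunks, reward_chunks):
    # Each chunk holds the num_generations completions of one prompt,
    # so per-chunk reward statistics are per-group statistics.
    assert x_chunk.shape[0] == num_generations
    assert r_chunk.shape[0] == num_generations
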
14 changes: 11 additions & 3 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -27,9 +27,11 @@ def rlhf_loss_fn(
else:
ref_token_logprobs = chosen_token_logprobs.detach()

# Compute advantages
mean_grouped_rewards = rewards.mean()
std_grouped_rewards = rewards.std()
# Compute advantages per batch entry in a grouped fashion
mean_grouped_rewards = rewards.mean() # [batch_size,]
std_grouped_rewards = rewards.std() # [batch_size,]

# Calculate advantages using the same epsilon as in GRPOTrainer
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)

# Compute policy gradient loss with importance sampling ratio
@@ -74,6 +76,7 @@ def forward(
beta=0.1,
compiled=True,
use_ref_model=True,
num_generations=1,
):
return LigerFusedLinearRLHFBase.forward(
ctx=ctx,
@@ -89,6 +92,7 @@ def forward(
beta=beta,
compiled=compiled,
use_ref_model=use_ref_model,
num_generations=num_generations,
)

@staticmethod
@@ -109,6 +113,7 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_beta
None, # grad_compiled
None, # grad_use_ref_model
None, # grad_num_generations
)


@@ -120,11 +125,13 @@ def __init__(
beta: float = 0.1,
compiled: bool = True,
use_ref_model: bool = True,
num_generations: int = 1,
):
super().__init__()
self.beta = beta
self.compiled = compiled
self.use_ref_model = use_ref_model
self.num_generations = num_generations

def forward(
self,
@@ -149,4 +156,5 @@ def forward(
self.beta,
self.compiled,
self.use_ref_model,
self.num_generations,
)
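
Note: because each chunk now contains exactly one prompt's generations, the scalar rewards.mean() / rewards.std() inside rlhf_loss_fn act as the grouped statistics that GRPO normalizes by. A minimal sketch of that equivalence, not part of the commit, using hypothetical reward values and the same 1e-8 epsilon as above:

import torch

num_generations = 4
# Hypothetical rewards for the generations of a single prompt.
rewards = torch.tensor([0.5, 1.0, -0.25, 0.75])

# Per-chunk normalization, as rlhf_loss_fn sees it (the chunk is one group).
adv_chunk = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

# Explicit grouped normalization over a [batch_size, num_generations] view.
grouped = rewards.view(-1, num_generations)
mean = grouped.mean(dim=1, keepdim=True).expand(-1, num_generations).reshape(-1)
std = grouped.std(dim=1, keepdim=True).expand(-1, num_generations).reshape(-1)
adv_grouped = (rewards - mean) / (std + 1e-8)

assert torch.allclose(adv_chunk, adv_grouped)
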
85 changes: 52 additions & 33 deletions test/chunked_loss/test_grpo_loss.py
@@ -21,24 +21,29 @@ def __init__(
dtype: torch.dtype,
bias: bool = False,
beta: float = 0.1,
num_generations: int = 4,
):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.beta = beta
self.num_generations = num_generations

def forward(
self,
x,
attention_mask,
rewards,
ref_input=None,
x, # Shape: [batch_size*num_generations, seq_len, hidden_size]
attention_mask, # Shape: [batch_size*num_generations, seq_len]
rewards, # Shape: [batch_size*num_generations,]
ref_input=None, # Shape: [batch_size*num_generations, seq_len, hidden_size]
ref_weight=None,
ref_bias=None,
):
# Forward pass through linear layer
batch_size, seq_len, hidden_size = x.shape
batch_size = x.shape[0] // self.num_generations # Get true batch size
seq_len = x.shape[1]
hidden_size = x.shape[2]

input_reshaped = x.view(-1, hidden_size)
logits = (input_reshaped @ self.lin.weight.t()).view(batch_size, seq_len, -1)
logits = (input_reshaped @ self.lin.weight.t()).view(batch_size * self.num_generations, seq_len, -1)
if self.lin.bias is not None:
logits = logits + self.lin.bias

@@ -53,36 +58,44 @@ def forward(
if ref_input is not None and ref_weight is not None:
with torch.no_grad():
ref_input_reshaped = ref_input.view(-1, ref_input.size(-1))
ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size, seq_len, -1)
ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size * self.num_generations, seq_len, -1)
if ref_bias is not None:
ref_logits = ref_logits + ref_bias
ref_log_probs = F.log_softmax(ref_logits, dim=-1)
ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
else:
ref_token_logprobs = chosen_token_logprobs.detach()

# Compute advantages (exactly as in GRPOTrainer)
mean_grouped_rewards = rewards.mean()
std_grouped_rewards = rewards.std()
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)

# Compute policy gradient loss with importance sampling ratio
ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
policy_loss = -ratio * advantages.unsqueeze(1)

# Compute KL penalty
# Compute KL divergence between model and reference model
kl_div = (
torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
)

# Combine losses
per_token_loss = policy_loss + self.beta * kl_div
# Compute advantages per batch entry in a grouped fashion
# rewards shape: [batch_size, num_generations]
mean_grouped_rewards = rewards.view(batch_size, self.num_generations).mean(
dim=1, keepdim=True
) # [batch_size, 1]
std_grouped_rewards = rewards.view(batch_size, self.num_generations).std(dim=1, keepdim=True) # [batch_size, 1]

# Expand means and stds to match the number of generations
mean_grouped_rewards = mean_grouped_rewards.expand(-1, self.num_generations).reshape(
-1
) # [batch_size * num_generations]
std_grouped_rewards = std_grouped_rewards.expand(-1, self.num_generations).reshape(
-1
) # [batch_size * num_generations]

# Calculate advantages using the same epsilon as in GRPOTrainer
rewards_flat = rewards.view(-1) # [batch_size * num_generations]
advantages = (rewards_flat - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# Compute policy gradient loss with importance sampling ratio
per_token_loss = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * kl_div)

# Apply masking and normalize
masked_loss = per_token_loss * attention_mask
seq_lengths = attention_mask.sum(dim=1, keepdim=True)
seq_lengths = torch.clamp(seq_lengths, min=1.0)
loss = (masked_loss.sum(dim=1) / seq_lengths.squeeze(-1)).mean()
loss = ((per_token_loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean()

# Compute metrics
metrics = (
@@ -103,11 +116,13 @@ def __init__(
dtype: torch.dtype,
bias: bool = False,
beta: float = 0.1,
num_generations: int = 4,
):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.grpo_loss = LigerFusedLinearGRPOFunction.apply
self.beta = beta
self.num_generations = num_generations

def forward(
self,
@@ -131,6 +146,7 @@ def forward(
self.beta, # beta
True, # compiled
ref_input is not None, # use_ref_model
self.num_generations, # num_generations
)


@@ -162,19 +178,22 @@ def test_correctness(
bias,
beta,
):
num_generations = 4 # Fixed number of generations for testing
torch_lm_head_grpo = TorchLMHeadGRPO(
H=H,
V=V,
dtype=dtype,
bias=bias,
beta=beta,
num_generations=num_generations,
)
liger_lm_head_grpo = LigerLMHeadGRPO(
H=H,
V=V,
dtype=dtype,
bias=bias,
beta=beta,
num_generations=num_generations,
)

# Initialize weights
@@ -184,22 +203,22 @@ def test_correctness(
if bias:
torch_lm_head_grpo.lin.bias.data = liger_lm_head_grpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype)

# Create inputs
_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
# Create inputs with shape [B*num_generations, T, H]
_input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

# Create attention mask with random padding
attention_mask = torch.ones(B, T, device=device)
num_elements_to_mask = torch.randint(1, B * T // 2, (1,)).item()
mask_indices = torch.randperm(B * T)[:num_elements_to_mask]
# Create attention mask with random padding [B*num_generations, T]
attention_mask = torch.ones(B * num_generations, T, device=device)
num_elements_to_mask = torch.randint(1, B * num_generations * T // 2, (1,)).item()
mask_indices = torch.randperm(B * num_generations * T)[:num_elements_to_mask]
attention_mask.view(-1)[mask_indices] = 0

# Create rewards with random values
rewards = torch.randn(B, device=device, dtype=dtype)
# Create rewards with shape [B, num_generations]
rewards = torch.randn(B * num_generations, device=device, dtype=dtype)

# Create reference inputs (optional)
ref_input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
# Create reference inputs (optional) with shape [B*num_generations, T, H]
ref_input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar
ref_weight = torch.randn(V, H, device=device, dtype=dtype)
ref_bias = torch.randn(V, device=device, dtype=dtype) if bias else None

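Note: for reference, the per-token objective computed by the PyTorch reference implementation above (and which the fused kernel path is expected to match, up to the epsilon in the advantage denominator) can be written out as follows. The notation is chosen for this note rather than taken from the source: \(\pi_\theta\) is the policy, \(\pi_{\mathrm{ref}}\) the frozen reference model, \(\beta\) the KL weight, \(m_{i,t}\) the attention mask, and \(\operatorname{sg}(\cdot)\) a stop-gradient.

\[
\hat{A}_i = \frac{r_i - \operatorname{mean}\!\bigl(r_{\mathrm{group}(i)}\bigr)}{\operatorname{std}\!\bigl(r_{\mathrm{group}(i)}\bigr) + \varepsilon},
\qquad
\widehat{\mathrm{KL}}_{i,t} = e^{\log\pi_{\mathrm{ref}}(y_{i,t}) - \log\pi_\theta(y_{i,t})} - \bigl(\log\pi_{\mathrm{ref}}(y_{i,t}) - \log\pi_\theta(y_{i,t})\bigr) - 1,
\]
\[
\mathcal{L}_{i,t} = -\Bigl[\, e^{\log\pi_\theta(y_{i,t}) - \operatorname{sg}(\log\pi_\theta(y_{i,t}))}\,\hat{A}_i \;-\; \beta\,\widehat{\mathrm{KL}}_{i,t} \Bigr],
\qquad
\mathcal{L} = \frac{1}{B \cdot G}\sum_i \frac{\sum_t m_{i,t}\,\mathcal{L}_{i,t}}{\sum_t m_{i,t}},
\]

where \(y_{i,t}\) is the chosen token of sequence \(i\) at position \(t\), and the sum over \(i\) runs over all \(B \cdot G\) sequences of the flattened batch (\(G\) = num_generations).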
