Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 Optimize GRPO VRAM Usage by Computing Prompt Tokens Just Once #2669

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad
from .utils import compute_logps_with_prompt_cache, generate_model_card, get_comet_experiment_url, pad


if is_peft_available():
Expand Down Expand Up @@ -418,37 +418,37 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
with torch.no_grad(), unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
)

prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, num_logits_to_keep):
# We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)

with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, num_logits_to_keep)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
# Current policy logprobs (with grad)
per_token_logps = compute_logps_with_prompt_cache(
model=self.accelerator.unwrap_model(model),
andyl98 marked this conversation as resolved.
Show resolved Hide resolved
prompt_inputs=prompt_inputs,
completion_ids=completion_ids,
requires_grad_for_completion=True,
)

# Reference model logprobs (no grad)
if self.ref_model is not None:
ref_per_token_logps = compute_logps_with_prompt_cache(
model=self.ref_model,
prompt_inputs=prompt_inputs,
completion_ids=completion_ids,
requires_grad_for_completion=False,
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = compute_logps_with_prompt_cache(
model=model,
prompt_inputs=prompt_inputs,
completion_ids=completion_ids,
requires_grad_for_completion=False,
)

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
Expand Down
64 changes: 64 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,3 +1647,67 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
return mask
else:
return mask, *tensors


def compute_logps_with_prompt_cache(
model: torch.nn.Module,
prompt_inputs: dict,
completion_ids: torch.LongTensor,
requires_grad_for_completion: bool = True,
):
"""
The method will compute the log probabilities of the completion tokens by using the prompt cache.
1) Forward pass on the prompt with torch.no_grad() to get `past_key_values`.
2) Forward pass (with or without grad) on the completion tokens using that cache.
3) Compute per-token log probabilities for the completion.

Args:
model (nn.Module): A causal LM (transformers.AutoModelForCausalLM) or similar.
prompt_inputs (dict): The dict of prompt tensors, e.g. {"input_ids", "attention_mask", ...}.
completion_ids (torch.LongTensor): Shape [B*G, completion_len].
requires_grad_for_completion (bool): Whether to enable gradient for the completion pass.

Returns:
per_token_logps (torch.FloatTensor): shape [B*G, completion_len],
where per_token_logps[i, t] is the logprob of ith completion's t-th completion token,
given all preceding tokens in the prompt + the partial completion up to t-1.
"""

# Get the batch size (B), number of completions (G), and completion length (C)
B = prompt_inputs["input_ids"].size(0)
G = completion_ids.size(0) // B
C = completion_ids.size(1)

# Forward pass over prompt tokens
with torch.no_grad():
prompt_out = model(**prompt_inputs, use_cache=True, num_logits_to_keep=1)

# Only keep the last prompt logit, immediately convert to log probabilities and expand to B*G
prompt_last_logps = prompt_out.logits.log_softmax(dim=-1).repeat_interleave(G, dim=0)

# Gather the these log probs as they relates to the first completion token
first_completion_token_logps = torch.gather(
prompt_last_logps, dim=-1, index=completion_ids[:, :1].unsqueeze(-1)
).squeeze(-1)

# Interleave the past key values for the G times
prompt_out.past_key_values.batch_repeat_interleave(G)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have the full context, but if this shares memory per repeat (as an expand would) then perfect!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so!


# Forward pass over completion tokens (with or without grad)
with torch.set_grad_enabled(requires_grad_for_completion):
completion_out = model(
input_ids=completion_ids,
past_key_values=prompt_out.past_key_values,
logits_to_keep=C - 1, # keep all but the last logit
use_cache=False,
)

# Convert completions logits to logprobs
completion_token_logps = torch.gather(
completion_out.logits.log_softmax(dim=-1), dim=-1, index=completion_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)

# Concat with the first_completion_token_logps
per_token_logps = torch.cat([first_completion_token_logps, completion_token_logps], dim=1)

return per_token_logps