
Commit

🔀 Rename get_batch_sample and add num_items_in_batch to `compute_loss` (#2246)
qgallouedec authored Oct 18, 2024
1 parent b9aa965 commit 31b7820
Showing 10 changed files with 30 additions and 16 deletions.
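Editorial note (not part of the commit): recent transformers releases forward a num_items_in_batch argument into Trainer.compute_loss and Trainer.training_step so the loss can be normalized correctly under gradient accumulation, and they also define a Trainer.get_batch_samples helper of their own. The TRL trainers below therefore accept the new argument and rename their generation helpers (to generate_from_model / generate_from_model_and_ref), presumably to avoid shadowing the base-class method. A minimal sketch of the compute_loss signature a compatible subclass needs, assuming a transformers version that passes the extra argument (MyTrainer is a hypothetical name, not a TRL class):

from typing import Any, Dict, Optional, Union

import torch
from transformers import Trainer


class MyTrainer(Trainer):
    def compute_loss(
        self,
        model,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,  # forwarded by newer Trainer versions
    ):
        # Standard forward pass; num_items_in_batch is accepted so the call from
        # the base Trainer does not fail, even if the subclass does not use it.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss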
5 changes: 3 additions & 2 deletions trl/trainer/bco_trainer.py
@@ -1260,6 +1260,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -1290,7 +1291,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return SequentialSampler(self.train_dataset)

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1407,7 +1408,7 @@ def evaluation_loop(
"prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
}
- policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

self.log(
{
5 changes: 3 additions & 2 deletions trl/trainer/cpo_trainer.py
@@ -828,6 +828,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -847,7 +848,7 @@ def compute_loss(
return (loss, metrics)
return loss

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -938,7 +939,7 @@ def evaluation_loop(
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

- policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)

self.log(
{
5 changes: 3 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -1547,6 +1547,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
with compute_loss_context_manager:
@@ -1561,7 +1562,7 @@ def compute_loss(
return (loss, metrics)
return loss

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1672,7 +1673,7 @@ def evaluation_loop(
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

- policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch)

self.log(
{
8 changes: 5 additions & 3 deletions trl/trainer/gkd_trainer.py
@@ -215,7 +215,7 @@ def generalized_jsd_loss(
else:
return jsd

- def compute_loss(self, model, inputs, return_outputs=False):
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# compute student output
outputs_student = model(
input_ids=inputs["input_ids"],
@@ -273,7 +273,9 @@ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=No

return generated_tokens, new_attention_mask, new_labels

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
"""
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
@@ -298,7 +300,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels

- loss = super().training_step(model, inputs)
+ loss = super().training_step(model, inputs, num_items_in_batch)
return loss

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
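The same pattern repeats in the nash_md, online_dpo, and xpo trainers further down: training_step now accepts num_items_in_batch and forwards it to the parent implementation. A hypothetical minimal sketch (MyStepTrainer is not a TRL class), assuming the transformers Trainer.training_step(model, inputs, num_items_in_batch) signature:

from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import Trainer


class MyStepTrainer(Trainer):
    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        # Modify inputs here if needed (e.g. on-policy generation, as GKDTrainer does),
        # then delegate to the parent, passing num_items_in_batch through unchanged.
        return super().training_step(model, inputs, num_items_in_batch)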
5 changes: 3 additions & 2 deletions trl/trainer/kto_trainer.py
@@ -1234,6 +1234,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -1264,7 +1265,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return SequentialSampler(self.train_dataset)

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1383,7 +1384,7 @@ def evaluation_loop(
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
}
- policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

self.log(
{
4 changes: 3 additions & 1 deletion trl/trainer/nash_md_trainer.py
@@ -328,7 +328,9 @@ def gather_mean(tensor):
self.stats["beta"].append(self.beta)
self.stats["mixture_coef"].append(self.mixture_coef)

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input
4 changes: 3 additions & 1 deletion trl/trainer/online_dpo_trainer.py
@@ -366,7 +366,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None

return self.accelerator.prepare(eval_dataloader)

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input.
5 changes: 3 additions & 2 deletions trl/trainer/orpo_trainer.py
@@ -844,6 +844,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
@@ -866,7 +867,7 @@ def compute_loss(
return (loss, metrics)
return loss

- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+ def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -957,7 +958,7 @@ def evaluation_loop(
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)

- policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)

self.log(
{
1 change: 1 addition & 0 deletions trl/trainer/reward_trainer.py
@@ -266,6 +266,7 @@ def compute_loss(
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
+ num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_reward_data_collator:
warnings.warn(
4 changes: 3 additions & 1 deletion trl/trainer/xpo_trainer.py
@@ -377,7 +377,9 @@ def gather_mean(tensor):
self.stats["alpha"].append(self.alpha)
self.stats["beta"].append(self.beta)

- def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ def training_step(
+     self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+ ) -> torch.Tensor:
model.train()

# Apply chat template and tokenize the input
