From 31b7820aade7cefa06e3d4160dcff8a602a14850 Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 18 Oct 2024 21:02:24 +0200
Subject: [PATCH] 🔀 Rename `get_batch_sample` and add `num_items_in_batch` to
 `compute_loss` (#2246)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 trl/trainer/bco_trainer.py        | 5 +++--
 trl/trainer/cpo_trainer.py        | 5 +++--
 trl/trainer/dpo_trainer.py        | 5 +++--
 trl/trainer/gkd_trainer.py        | 8 +++++---
 trl/trainer/kto_trainer.py        | 5 +++--
 trl/trainer/nash_md_trainer.py    | 4 +++-
 trl/trainer/online_dpo_trainer.py | 4 +++-
 trl/trainer/orpo_trainer.py       | 5 +++--
 trl/trainer/reward_trainer.py     | 1 +
 trl/trainer/xpo_trainer.py        | 4 +++-
 10 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index 91461a9b0d..c6ce2d4902 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -1260,6 +1260,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -1290,7 +1291,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             return None
         return SequentialSampler(self.train_dataset)

-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""

         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1407,7 +1408,7 @@ def evaluation_loop(
                 "prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
                 "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
             }
-            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

             self.log(
                 {
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 5847cb182b..5e74fdaceb 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -828,6 +828,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -847,7 +848,7 @@ def compute_loss(
             return (loss, metrics)
         return loss

-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
         """Generate samples from the model and reference model for the given batch of inputs."""

         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -938,7 +939,7 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)

-            policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+            policy_output_decoded = self.generate_from_model(self.model, random_batch)

             self.log(
                 {
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 8b9843cd91..082a627ce0 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1547,6 +1547,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
         with compute_loss_context_manager:
@@ -1561,7 +1562,7 @@ def compute_loss(
             return (loss, metrics)
         return loss

-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""

         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1672,7 +1673,7 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)

-            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch)

             self.log(
                 {
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 1b7c77557d..49e93e269b 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -215,7 +215,7 @@ def generalized_jsd_loss(
         else:
             return jsd

-    def compute_loss(self, model, inputs, return_outputs=False):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # compute student output
         outputs_student = model(
             input_ids=inputs["input_ids"],
@@ -273,7 +273,9 @@ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=No

         return generated_tokens, new_attention_mask, new_labels

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         """
         Perform a training step for the Generalized Knowledge Distillation (GKD) model.

@@ -298,7 +300,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             inputs["attention_mask"] = new_attention_mask
             inputs["labels"] = new_labels

-        loss = super().training_step(model, inputs)
+        loss = super().training_step(model, inputs, num_items_in_batch)
         return loss

     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index ab9ba87e41..7f32424812 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1234,6 +1234,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -1264,7 +1265,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             return None
         return SequentialSampler(self.train_dataset)

-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model_and_ref(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""

         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -1383,7 +1384,7 @@ def evaluation_loop(
                 "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
                 "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
             }
-            policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

             self.log(
                 {
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index db0c3046b3..73aab7899a 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -328,7 +328,9 @@ def gather_mean(tensor):
         self.stats["beta"].append(self.beta)
         self.stats["mixture_coef"].append(self.mixture_coef)

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         model.train()

         # Apply chat template and tokenize the input
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index ffc407b57d..c480c61fc5 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -366,7 +366,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None

         return self.accelerator.prepare(eval_dataloader)

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         model.train()

         # Apply chat template and tokenize the input.
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 4edbf9b1a5..123f935208 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -844,6 +844,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_dpo_data_collator:
             warnings.warn(
@@ -866,7 +867,7 @@ def compute_loss(
             return (loss, metrics)
         return loss

-    def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
+    def generate_from_model(self, model, batch: Dict[str, torch.LongTensor]) -> str:
         """Generate samples from the model and reference model for the given batch of inputs."""

         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
@@ -957,7 +958,7 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)

-            policy_output_decoded = self.get_batch_samples(self.model, random_batch)
+            policy_output_decoded = self.generate_from_model(self.model, random_batch)

             self.log(
                 {
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 787c6cbd54..0ebdee68b4 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -266,6 +266,7 @@ def compute_loss(
         model: Union[PreTrainedModel, nn.Module],
         inputs: Dict[str, Union[torch.Tensor, Any]],
         return_outputs=False,
+        num_items_in_batch=None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         if not self.use_reward_data_collator:
             warnings.warn(
diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py
index 0255e6206f..a154875821 100644
--- a/trl/trainer/xpo_trainer.py
+++ b/trl/trainer/xpo_trainer.py
@@ -377,7 +377,9 @@ def gather_mean(tensor):
         self.stats["alpha"].append(self.alpha)
         self.stats["beta"].append(self.beta)

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
         model.train()

         # Apply chat template and tokenize the input
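
Note (not part of the patch): a minimal sketch of what these signature changes mean for downstream code. The subclass name `MyDPOTrainer` is hypothetical and the body is illustrative only; the point is that a `compute_loss` override must now accept `num_items_in_batch`, which newer `transformers.Trainer` versions pass through from `training_step`, and may simply forward it to the parent.

from typing import Any, Dict, Optional, Union

import torch
from trl import DPOTrainer


class MyDPOTrainer(DPOTrainer):  # hypothetical subclass, for illustration only
    def compute_loss(
        self,
        model,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        num_items_in_batch: Optional[int] = None,  # new keyword must be accepted
    ):
        # Defer to the parent implementation; the extra keyword is passed through unchanged.
        return super().compute_loss(
            model, inputs, return_outputs=return_outputs, num_items_in_batch=num_items_in_batch
        )

Similarly, code that previously called `trainer.get_batch_samples(model, batch)` for eval-time generation would now call `trainer.generate_from_model_and_ref(model, batch)` on the BCO/DPO/KTO trainers, or `trainer.generate_from_model(model, batch)` on the CPO/ORPO trainers, which have no reference model and therefore return a single decoded output.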