diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index ef29461a70..83926cfd6a 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -138,10 +138,18 @@ def __init__( if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - self.policy_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." + ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int # peft support if not is_peft_available() and peft_config is not None: @@ -220,8 +228,6 @@ def __init__( for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: if module is not None: disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) self.model.config = self.policy_model.config # needed for pushing to hub self.create_optimizer_and_scheduler( @@ -449,9 +455,9 @@ def repeat_generator(): # Response Processing 1. truncate response after the first occurrence of `stop_token_id` postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) # Response Processing 2. run reward model on the truncated responses @@ -706,9 +712,9 @@ def generate_completions(self, sampling: bool = False): ) response = query_response[:, context_length:] postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) table["query"].extend( gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 1228dc7ece..719d952f1f 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -993,9 +993,15 @@ class OnPolicyConfig(TrainingArguments): response_length (`int`, *optional*, defaults to `53`): Length of the response. stop_token (`str` or `None`, *optional*, defaults to `None`): - Stop token. + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. + stop_token_id (`int` or `None`, *optional*, defaults to `None`): - Truncation token id. + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. temperature (`float`, *optional*, defaults to `0.7`): Sampling temperature. missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): @@ -1054,11 +1060,17 @@ class OnPolicyConfig(TrainingArguments): ) stop_token: Optional[Literal["eos"]] = field( default=None, - metadata={"help": "Stop token."}, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." + }, ) stop_token_id: Optional[int] = field( default=None, - metadata={"help": "Truncation token id."}, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." + }, ) temperature: float = field( default=0.7,