diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 2095df3074..764ca3c929 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -84,7 +84,8 @@ def main(script_args, training_args, model_args): tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) - tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token ################ # Dataset