Skip to content

Commit

Permalink
Fixed ref model not used in PPO generation (#1534)
Browse files Browse the repository at this point in the history
  • Loading branch information
ejmejm authored Apr 17, 2024
1 parent edf60e8 commit 3bbe7e0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
49 changes: 49 additions & 0 deletions tests/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,55 @@ def test_generation(self):

assert generations_single == generations_batched

def test_generation_with_ref_model(self):
dummy_dataset = self._init_dummy_dataset()
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Negate the weights in the last layer of the ref model so it never
# outputs the same things as the primary model
ref_model = copy.deepcopy(model)
lm_head_weight = ref_model.pretrained_model.lm_head.weight
lm_head_weight.data = -lm_head_weight.data

ppo_trainer = PPOTrainer(
config=self.ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dummy_dataset,
)

input_texts = ["this is a test", "this is another, longer test"]

generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id}

tokenizer.pad_token = tokenizer.eos_token

model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]

generations_batched, ref_generations_batched = ppo_trainer.generate(
model_inputs, batch_size=2, generate_ref_response=True, **generation_kwargs
)
generations_batched = tokenizer.batch_decode(generations_batched)
ref_generations_batched = tokenizer.batch_decode(ref_generations_batched)

generations_single = []
ref_generations_single = []
for inputs in model_inputs:
generation, ref_generation = ppo_trainer.generate(inputs, generate_ref_response=True, **generation_kwargs)
generations_single.append(generation.squeeze())
ref_generations_single.append(ref_generation.squeeze())

generations_single = tokenizer.batch_decode(generations_single)
ref_generations_single = tokenizer.batch_decode(ref_generations_single)

assert generations_single == generations_batched
assert ref_generations_single == ref_generations_batched

assert generations_batched != ref_generations_batched
assert generations_single != ref_generations_single

def test_grad_accumulation(self):
dummy_dataset = self._init_dummy_dataset()

Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def generate(

if generate_ref_response:
with unwrap_model_for_generation(
self.model, self.accelerator, is_peft_model=self.is_peft_model
ref_model, self.accelerator, is_peft_model=self.is_peft_model
) as unwrapped_model:
ref_response = unwrapped_model.generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
Expand Down

0 comments on commit 3bbe7e0

Please sign in to comment.