
Commit bfc81c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Signed-off-by: NeMo-Aligner CI <[email protected]>
pre-commit-ci[bot] committed Feb 7, 2025
1 parent d3e55ee commit bfc81c6
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions nemo_aligner/experimental/self_rewarding/self_rewarding.py
@@ -597,7 +597,9 @@ def get_rewards(self, list_of_batches, prepare_for_inference=False):
 orig_last_prompts = [b["orig_last_prompt"] for b in list_of_batches]
 orig_responses = [b["orig_response"] for b in list_of_batches]
 for _ in range(self.num_evals_to_average):
-reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(list_of_batches, prepare_for_inference=prepare_for_inference)
+reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(
+list_of_batches, prepare_for_inference=prepare_for_inference
+)
 batch_prompts_str, batch_responses_str = [], []
 for t, s, e in zip(reward_responses, prompt_lengths.tolist(), resp_lengths.tolist()):
 prompt = self.tokenizer.ids_to_text(t[:s].tolist())
@@ -656,7 +658,9 @@ def get_rewards(self, list_of_batches, prepare_for_inference=False):
 def get_rewards_meta(self, list_of_batches, prepare_for_inference=False):
 reward_scores = [[] for _ in range(sum([len(b["prompt_lengths"]) for b in list_of_batches]))]
 reward_scores = []
-reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(list_of_batches, prepare_for_inference=prepare_for_inference)
+reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(
+list_of_batches, prepare_for_inference=prepare_for_inference
+)
 # if (
 # torch.distributed.get_rank() == 0
 # and torch.distributed.get_rank() == parallel_state.get_data_parallel_src_rank()
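Both hunks above merely re-wrap the same long call; the surrounding logic is a generate-then-slice pattern: `get_generations` returns token sequences together with per-sample prompt lengths and response end offsets, and the loop that follows decodes `t[:s]` as the prompt and `t[s:e]` as the response. A minimal runnable sketch of that slicing step, using a toy tokenizer and invented names (`ToyTokenizer` and `split_prompt_response` are not part of the repository):

```python
from typing import List, Tuple


class ToyTokenizer:
    """Stand-in for the NeMo tokenizer; only ids_to_text is sketched."""

    def ids_to_text(self, ids: List[int]) -> str:
        # Pretend every token id maps to a single letter.
        return "".join(chr(ord("a") + (i % 26)) for i in ids)


def split_prompt_response(
    tokens: List[List[int]],
    prompt_lengths: List[int],
    resp_ends: List[int],
    tokenizer: ToyTokenizer,
) -> Tuple[List[str], List[str]]:
    """Decode t[:s] (prompt) and t[s:e] (response) per sample, mirroring the
    zip(...) loop that follows each get_generations() call in the diff."""
    prompts, responses = [], []
    for t, s, e in zip(tokens, prompt_lengths, resp_ends):
        prompts.append(tokenizer.ids_to_text(t[:s]))
        responses.append(tokenizer.ids_to_text(t[s:e]))
    return prompts, responses


if __name__ == "__main__":
    tok = ToyTokenizer()
    tokens = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]
    print(split_prompt_response(tokens, [2, 3], [6, 6], tok))
```

The real code works on padded tensors and calls `.tolist()` on the length tensors before zipping; the list-based version above only keeps the example self-contained.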
@@ -982,13 +986,15 @@ def augment_dataloader(self, dataloader):
 prompt = self.tokenizer.ids_to_text(t[:s].tolist()).replace(
 "<extra_id_0>System\n\n", ""
 )
-response = self.tokenizer.ids_to_text(t[s:e].tolist()).replace("\n<extra_id_1>", "")
-
+response = self.tokenizer.ids_to_text(t[s:e].tolist()).replace(
+"\n<extra_id_1>", ""
+)
+
 reward_prompt_str = self.template_fn(prompt=prompt, response=response)
 reward_prompt = self.tokenizer.text_to_ids(reward_prompt_str)
 if len(reward_prompt) > self.model.cfg.data.train_ds.max_seq_length:
-#prompt_and_response = self.tokenizer.ids_to_text(t[:e].tolist())
-#try:
+# prompt_and_response = self.tokenizer.ids_to_text(t[:e].tolist())
+# try:
 """
 if self.cfg.trt_llm.get("model_type", "gptnext").lower() == "llama":
 prompt_ft = re.findall(
@@ -1036,9 +1042,9 @@ def augment_dataloader(self, dataloader):
 )
 reward_prompt_str = self.template_fn(prompt=prompt_ft, response=response_ft)
 reward_prompt = self.model.tokenizer.text_to_ids(reward_prompt_str)
-#prompt = prompt_ft
-#response = response_ft
-'''
+# prompt = prompt_ft
+# response = response_ft
+"""
 except:
 # print(f"*** TOO_LONG: {prompt_and_response}")
 # overage = len(reward_prompt) - (self.model.cfg.encoder_seq_length - self.max_gen_seq_len)
@@ -1061,7 +1067,7 @@ def augment_dataloader(self, dataloader):
 )
 reward_prompt = self.model.tokenizer.text_to_ids(reward_prompt_str)
 break
-'''
+"""
 assert len(reward_prompt) <= (
 self.model.cfg.encoder_seq_length - self.max_gen_seq_len - 8
 ), f"truncation of response only failed [ {len(reward_prompt)} ]: {reward_prompt_str}"
@@ -1263,7 +1269,7 @@ def augment_dataloader(self, dataloader):
 "bad_ends": bad_ends,
 }
 )
-
+
 self.model.finish_inference()
 if self.use_trtllm_generation:
 self.trtllm_generate.free()
@@ -1384,7 +1390,7 @@ def augment_dataloader(self, dataloader):
 # at this point self.model is the reference policy from cpu_weight_swap
 self.trtllm_generate.refit(self.model)
 clear_memory()
-
+
 num_rollouts = rollout_len // self.rollout_micro_batch_size
 meta_buffer_unroll_grp = [(idx, y) for idx, x in enumerate(meta_buffer_pending) for y in x[-1]]
 for _ in range(num_rollouts):
@@ -1484,7 +1490,7 @@ def augment_dataloader(self, dataloader):
 
 del meta_buffer_unroll_grp
 meta_buffer_pending = [x for x in meta_buffer_pending if len(x[-1]) > 0]
-
+
 self.model.finish_inference()
 if self.use_trtllm_generation:
 self.trtllm_generate.free()
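The last three hunks are whitespace-only, but they border the rollout scheduling code: a pending buffer of groups is flattened into (index, item) pairs and then consumed in num_rollouts = rollout_len // rollout_micro_batch_size passes. A toy illustration of that bookkeeping; every name here is a placeholder, and the per-iteration slicing is an assumption since the loop body is truncated in the diff:

```python
rollout_micro_batch_size = 4
meta_buffer_pending = [(0, ["a", "b"]), (1, ["c"]), (2, ["d", "e", "f"])]

# Flatten to (group_index, item) pairs, as the diff's list comprehension does.
meta_buffer_unroll_grp = [(idx, y) for idx, x in enumerate(meta_buffer_pending) for y in x[-1]]

rollout_len = len(meta_buffer_unroll_grp)  # 6 items in this toy example
num_rollouts = rollout_len // rollout_micro_batch_size  # -> 1 full micro-batch

for i in range(num_rollouts):
    # Take one micro-batch of (group_index, item) pairs per rollout pass.
    chunk = meta_buffer_unroll_grp[i * rollout_micro_batch_size : (i + 1) * rollout_micro_batch_size]
    print(chunk)
```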
