diff --git a/nemo_aligner/experimental/self_rewarding/self_rewarding.py b/nemo_aligner/experimental/self_rewarding/self_rewarding.py
index aabb8b3cc..3411aa378 100644
--- a/nemo_aligner/experimental/self_rewarding/self_rewarding.py
+++ b/nemo_aligner/experimental/self_rewarding/self_rewarding.py
@@ -597,7 +597,9 @@ def get_rewards(self, list_of_batches, prepare_for_inference=False):
         orig_last_prompts = [b["orig_last_prompt"] for b in list_of_batches]
         orig_responses = [b["orig_response"] for b in list_of_batches]
         for _ in range(self.num_evals_to_average):
-            reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(list_of_batches, prepare_for_inference=prepare_for_inference)
+            reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(
+                list_of_batches, prepare_for_inference=prepare_for_inference
+            )
             batch_prompts_str, batch_responses_str = [], []
             for t, s, e in zip(reward_responses, prompt_lengths.tolist(), resp_lengths.tolist()):
                 prompt = self.tokenizer.ids_to_text(t[:s].tolist())
@@ -656,7 +658,9 @@ def get_rewards(self, list_of_batches, prepare_for_inference=False):
     def get_rewards_meta(self, list_of_batches, prepare_for_inference=False):
         reward_scores = [[] for _ in range(sum([len(b["prompt_lengths"]) for b in list_of_batches]))]
         reward_scores = []
-        reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(list_of_batches, prepare_for_inference=prepare_for_inference)
+        reward_responses, prompt_lengths, resp_lengths, is_end = self.get_generations(
+            list_of_batches, prepare_for_inference=prepare_for_inference
+        )
         # if (
         #     torch.distributed.get_rank() == 0
         #     and torch.distributed.get_rank() == parallel_state.get_data_parallel_src_rank()
@@ -982,13 +986,15 @@ def augment_dataloader(self, dataloader):
                         prompt = self.tokenizer.ids_to_text(t[:s].tolist()).replace(
                             "System\n\n", ""
                         )
-                        response = self.tokenizer.ids_to_text(t[s:e].tolist()).replace("\n", "")
-
+                        response = self.tokenizer.ids_to_text(t[s:e].tolist()).replace(
+                            "\n", ""
+                        )
+
                         reward_prompt_str = self.template_fn(prompt=prompt, response=response)
                         reward_prompt = self.tokenizer.text_to_ids(reward_prompt_str)
                         if len(reward_prompt) > self.model.cfg.data.train_ds.max_seq_length:
-                            #prompt_and_response = self.tokenizer.ids_to_text(t[:e].tolist())
-                            #try:
+                            # prompt_and_response = self.tokenizer.ids_to_text(t[:e].tolist())
+                            # try:
                             """
                             if self.cfg.trt_llm.get("model_type", "gptnext").lower() == "llama":
                                 prompt_ft = re.findall(
@@ -1036,9 +1042,9 @@ def augment_dataloader(self, dataloader):
                                 )
                                 reward_prompt_str = self.template_fn(prompt=prompt_ft, response=response_ft)
                                 reward_prompt = self.model.tokenizer.text_to_ids(reward_prompt_str)
-                                #prompt = prompt_ft
-                                #response = response_ft
-                            '''
+                                # prompt = prompt_ft
+                                # response = response_ft
+                            """
                             except:
                                 # print(f"*** TOO_LONG: {prompt_and_response}")
                                 # overage = len(reward_prompt) - (self.model.cfg.encoder_seq_length - self.max_gen_seq_len)
@@ -1061,7 +1067,7 @@ def augment_dataloader(self, dataloader):
                                 )
                                 reward_prompt = self.model.tokenizer.text_to_ids(reward_prompt_str)
                                 break
-                            '''
+                            """
                             assert len(reward_prompt) <= (
                                 self.model.cfg.encoder_seq_length - self.max_gen_seq_len - 8
                             ), f"truncation of response only failed [ {len(reward_prompt)} ]: {reward_prompt_str}"
@@ -1263,7 +1269,7 @@ def augment_dataloader(self, dataloader):
                                 "bad_ends": bad_ends,
                             }
                         )
-
+
                     self.model.finish_inference()
                     if self.use_trtllm_generation:
                         self.trtllm_generate.free()
@@ -1384,7 +1390,7 @@ def augment_dataloader(self, dataloader):
                         # at this point self.model is the reference policy from cpu_weight_swap
                         self.trtllm_generate.refit(self.model)
                         clear_memory()
-
+
                     num_rollouts = rollout_len // self.rollout_micro_batch_size
                     meta_buffer_unroll_grp = [(idx, y) for idx, x in enumerate(meta_buffer_pending) for y in x[-1]]
                     for _ in range(num_rollouts):
@@ -1484,7 +1490,7 @@ def augment_dataloader(self, dataloader):
                     del meta_buffer_unroll_grp

                     meta_buffer_pending = [x for x in meta_buffer_pending if len(x[-1]) > 0]
-
+
                     self.model.finish_inference()
                     if self.use_trtllm_generation:
                         self.trtllm_generate.free()