TRL has a concept of multiple different judges that can be used in various online RLHF-style methods; see the TRL Judges Doc. As a starting point, we could implement just a pairwise judge that would be useful with techniques like Online DPO, similar to the TRL Online DPO Trainer.
The basic idea for simple online DPO is to take a prompt dataset as input, generate two candidate outputs per prompt from the policy model, judge these with a specified judge to get chosen and rejected responses, and then output the results in a preference dataset format.
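For illustration, a single judged record might look roughly like the following (the keys here are just a sketch; the real layout would follow whatever preference format the existing DPO recipe and collate functions expect):

# Hypothetical example of one judged record, before tokenization/collation.
# Key names are illustrative only.
example = {
    "prompt": "Explain what a learning rate is.",
    "chosen": "A learning rate controls how large each optimizer step is ...",  # candidate picked by the judge
    "rejected": "It is a number.",                                              # candidate not picked
}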
Here is a first rough cut of a PairwiseJudge that has been working for me in early online DPO experiments. I've included an example below of a pairwise judge based on response length. Please let me know if you have feedback on this approach or if there could be other more useful abstractions to make this something torchtune would want to integrate.
Pairwise Judge Proposal
import concurrent.futures
from typing import List, Literal, Optional

from tqdm import tqdm


class PairwiseJudge:
    """
    Base class for pairwise judges. Provides a common interface for executing
    judgment functions on pairwise completions.
    """

    def __init__(self, rank: Optional[int] = None):
        self.rank = rank

    def judge(self, prompt: str, candidates: List[str]) -> int:
        """
        Abstract method to be implemented by subclasses.

        Args:
            prompt (str): The input prompt.
            candidates (List[str]): List of candidate responses.

        Returns:
            int: Index of the selected candidate.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def judge_all(
        self, prompts: List[str], completions: List[List[str]], pbar: bool = True
    ) -> List[int]:
        """
        Executes the judge function in parallel.

        Args:
            prompts (List[str]): List of prompts.
            completions (List[List[str]]): List of completion pairs for each prompt.
            pbar (bool, optional): Whether to show a progress bar.

        Returns:
            List[int]: List of chosen completion indices.
        """
        # One judge call per prompt, so the progress bar total is len(prompts)
        progress_bar = tqdm(
            total=len(prompts),
            desc=f"Judging with {self.__class__.__name__}",
            leave=False,
            disable=not pbar,
        )

        def wrapped_judge(prompt, completion_pair):
            result = self.judge(prompt, completion_pair)
            progress_bar.update(1)
            return result

        with concurrent.futures.ThreadPoolExecutor() as executor:
            choices = list(executor.map(wrapped_judge, prompts, completions))
        progress_bar.close()
        return choices

    def run(
        self,
        prompts: List[str],
        completions: List[List[str]],
        display: bool = False,
        pbar: bool = True,
        **kwargs,
    ) -> List[int]:
        """
        Evaluates multiple prompts and their response candidates.

        Args:
            prompts (List[str]): List of prompts.
            completions (List[List[str]]): List of completion pairs for each prompt.
            display (bool, optional): Whether to log judgment results. Defaults to False.
            pbar (bool, optional): Whether to display a progress bar. Defaults to True.

        Returns:
            List[int]: List of selected response indices.
        """
        choices = self.judge_all(prompts, completions, pbar=pbar, **kwargs)
        if display:
            self._display_results(prompts, completions, choices)  # Not shown here, simply logs results
        return choices


class LengthPairwiseJudge(PairwiseJudge):
    """Judges based on a response-length objective."""

    def __init__(self, objective: Literal["shorter", "longer"], **kwargs):
        super().__init__(**kwargs)
        self.objective = objective

    def judge(self, prompt: str, candidates: List[str]) -> int:
        if self.objective == "shorter":
            return 0 if len(candidates[0]) <= len(candidates[1]) else 1
        elif self.objective == "longer":
            return 0 if len(candidates[0]) >= len(candidates[1]) else 1
        raise ValueError(f"Unknown objective: {self.objective}")
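For example, usage in an experiment could look like the following (the prompts/completions below are just dummy data):

# Pick the shorter of two candidate completions for each prompt.
judge = LengthPairwiseJudge(objective="shorter")
prompts = ["What is DPO?", "Name a color."]
completions = [
    ["Direct Preference Optimization.", "DPO is a method for aligning language models with preferences ..."],
    ["Blue.", "One color that people often mention first is blue."],
]
choices = judge.run(prompts, completions, pbar=False)  # e.g. [0, 0]
chosen = [pair[idx] for pair, idx in zip(completions, choices)]
rejected = [pair[1 - idx] for pair, idx in zip(completions, choices)]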
Online DPO Proposal (WIP)
Starting from the full DPO distributed recipe, create an online DPO distributed recipe with a new method that converts a prompt batch into a preference batch. This method would run in the training loop between loading a batch of prompts (using padded_collate) and outputting a batch of preference data (padded_collate_dpo).
So this prompt-to-preference function would involve the following steps (a rough sketch follows the list):
1. Get a prompt batch.
2. Generate two completions per prompt using the policy model.
3. Run the given PairwiseJudge's run method to get chosen and rejected indices.
4. Convert to the preference dataset format.
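To make the shape of the change concrete, here is a rough sketch of that step. Everything here is hypothetical: generate_two_candidates stands in for whatever generation utility the recipe ends up using, and the output record layout would need to match what the tokenizer and padded_collate_dpo actually expect.

from typing import Callable, Dict, List

def prompt_batch_to_preference_batch(
    prompts: List[str],
    generate_two_candidates: Callable[[str], List[str]],  # placeholder for policy-model generation
    judge: "PairwiseJudge",
) -> List[Dict[str, str]]:
    """Hypothetical helper: turn a batch of prompts into preference examples."""
    # Steps 1-2: generate two candidate completions per prompt with the policy model.
    completions = [generate_two_candidates(p) for p in prompts]
    # Step 3: let the judge pick the preferred completion in each pair.
    choices = judge.run(prompts, completions, pbar=False)
    # Step 4: convert to (prompt, chosen, rejected) records; these would then be
    # tokenized and collated (e.g. via padded_collate_dpo) for the DPO loss.
    return [
        {"prompt": p, "chosen": pair[i], "rejected": pair[1 - i]}
        for p, pair, i in zip(prompts, completions, choices)
    ]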
Thanks @sam-pi for sharing the RFC! The general design of the judges abstraction looks pretty reasonable to me. A couple comments:
1. Personally I would split up the parent and child classes a little differently. (Admittedly some of this is personal preference, but) I'd lean towards just having PairwiseJudge as a protocol and leaving all implementation details to the child classes. We do this a lot in other parts of our codebase too (e.g. our tokenizers).
2. (Kind of an extension of (1).) I wonder whether judge_all could be a utility outside of a class, e.g. it takes a judge_fn or something as an argument, then wraps it in the executor. The advantage of this is that the judges then become even easier to write/understand; see the sketch after this list.
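As a rough illustration of what I mean (names here are just placeholders, not a concrete proposal):

import concurrent.futures
from typing import Callable, List, Protocol

from tqdm import tqdm


class PairwiseJudge(Protocol):
    """Protocol only: concrete judges just implement judge()."""

    def judge(self, prompt: str, candidates: List[str]) -> int:
        ...


def judge_all(
    judge_fn: Callable[[str, List[str]], int],
    prompts: List[str],
    completions: List[List[str]],
    pbar: bool = True,
) -> List[int]:
    """Standalone utility that runs any judge_fn over a batch in a thread pool."""
    progress_bar = tqdm(total=len(prompts), desc="Judging", leave=False, disable=not pbar)

    def wrapped(prompt, pair):
        result = judge_fn(prompt, pair)
        progress_bar.update(1)
        return result

    with concurrent.futures.ThreadPoolExecutor() as executor:
        choices = list(executor.map(wrapped, prompts, completions))
    progress_bar.close()
    return choices

Then something like LengthPairwiseJudge only needs to define judge(), and the parallelism lives in one place.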
The online DPO proposal sounds reasonable at a high level, but I'd also like to see a code snippet to get a feel for how the judge interacts with it. Do you envision this as its own training recipe? (It seems that it does change the dataloading and training loop in a nontrivial way.)