[RFC] Judge Framework and Online DPO #2413

Open
sam-pi opened this issue Feb 19, 2025 · 1 comment
Labels
rfc Request for comments

Comments

@sam-pi
Contributor

sam-pi commented Feb 19, 2025

TRL has a concept of multiple different judges which can be used in various online RLHF type methods, see the TRL Judges Doc. As a starting point, we could implement just a pairwise judge that could be useful with techniques like Online DPO similar to TRL Online DPO Trainer.

The basic idea for simple online DPO is to input a prompt dataset, generate two candidate outputs from the policy model, judge these to get chosen and rejected using a specified judge, and then output a preference dataset format.

Here is a first rough cut of a PairwiseJudge that has been working for me in early online DPO experiments. I've included an example below of a pairwise judge based on response length. Please let me know if you have feedback on this approach or if there could be other more useful abstractions to make this something torchtune would want to integrate.

Pairwise Judge Proposal

import concurrent.futures
from typing import List, Literal, Optional

from tqdm import tqdm


class PairwiseJudge:
    """
    Base class for pairwise judges.

    Provides a common interface for executing judgment functions on pairwise completions.
    """

    def __init__(self, rank: Optional[int] = None):
        self.rank = rank

    def judge(self, prompt: str, candidates: List[str]) -> int:
        """
        Abstract method to be implemented by subclasses.

        Args:
            prompt (str): The input prompt.
            candidates (List[str]): List of candidate responses.

        Returns:
            int: Index of the selected candidate.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def judge_all(self, prompts: List[str], completions: List[List[str]], pbar: bool = True) -> List[int]:
        """
        Executes the judge function in parallel.

        Args:
            prompts (List[str]): List of prompts.
            completions (List[List[str]]): List of completion pairs for each prompt.
            pbar (bool, optional): Whether to show a progress bar.

        Returns:
            List[int]: List of chosen completion indices.
        """
        # One judge call is made per prompt (each call covers a full candidate pair).
        total_tasks = len(prompts)
        progress_bar = tqdm(total=total_tasks, desc=f"Judging with {self.__class__.__name__}", leave=False, disable=not pbar)

        def wrapped_judge(prompt, completion):
            result = self.judge(prompt, completion)
            progress_bar.update(1)
            return result

        with concurrent.futures.ThreadPoolExecutor() as executor:
            choices = list(executor.map(wrapped_judge, prompts, completions))

        progress_bar.close()
        return choices

    def run(self, prompts: List[str], completions: List[List[str]], display: bool = False, pbar: bool = True, **kwargs) -> List[int]:
        """
        Evaluates multiple prompts and their response candidates.

        Args:
            prompts (List[str]): List of prompts.
            completions (List[List[str]]): List of completion pairs for each prompt.
            display (bool, optional): Whether to log judgment results. Defaults to False.
            pbar (bool, optional): Whether to display progress bar. Defaults to True.

        Returns:
            List[int]: List of selected response indices.
        """
        choices = self.judge_all(prompts, completions, pbar=pbar, **kwargs)

        if display:
            self._display_results(prompts, completions, choices) # Not shown here, simply logs results

        return choices


class LengthPairwiseJudge(PairwiseJudge):
    """Judges based on response length objective"""

    def __init__(self, objective: Literal["shorter", "longer"], **kwargs):
        super().__init__(**kwargs)
        self.objective = objective

    def judge(self, prompt: str, candidates: List[str]) -> int:
        if self.objective == "shorter":
            return 0 if len(candidates[0]) <= len(candidates[1]) else 1
        elif self.objective == "longer":
            return 0 if len(candidates[0]) >= len(candidates[1]) else 1
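
As a quick illustration of the intended usage (the prompts and completions below are made up):

judge = LengthPairwiseJudge(objective="shorter")
prompts = ["Explain DPO in one sentence.", "What is a reward model?"]
completions = [
    ["A short answer.", "A much longer and more detailed answer about DPO."],
    ["A reward model scores candidate responses.", "Short."],
]
choices = judge.run(prompts, completions, pbar=False)
# choices == [0, 1]: the shorter candidate wins under objective="shorter"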

Online DPO Proposal (WIP)

Starting from the full DPO distributed recipe, create an online DPO distributed recipe with a new method which converts from a prompt batch into a preference batch. This method would run in the training loop between loading a batch of prompts (using padded_collate) and outputting a batch of preference data (padded_collate_dpo).

So this prompt-to-preference function would involve (a rough sketch follows the list):

  1. Get prompt batch
  2. Run two completions per prompt using the policy model
  3. Run the given PairwiseJudge run method to get chosen and rejected
  4. Convert to preference dataset format
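
Below is a minimal, text-level sketch of what this conversion could look like. The generate_completion callable is a placeholder for the policy model's generation step, not an existing torchtune API, and the actual recipe would work on token tensors and collate the output with padded_collate_dpo:

from typing import Callable, List


def prompts_to_preferences(
    prompts: List[str],
    generate_completion: Callable[[str], str],
    judge: "PairwiseJudge",
) -> List[dict]:
    # 1-2. Generate two candidate completions per prompt with the policy model.
    completions = [[generate_completion(p), generate_completion(p)] for p in prompts]

    # 3. Judge each pair to get the index of the chosen completion (0 or 1).
    choices = judge.run(prompts, completions, pbar=False)

    # 4. Convert to a preference-style record per prompt; the real recipe would
    # tokenize these and collate with padded_collate_dpo.
    return [
        {"prompt": p, "chosen": pair[c], "rejected": pair[1 - c]}
        for p, pair, c in zip(prompts, completions, choices)
    ]
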
@ebsmothers
Contributor

Thanks @sam-pi for sharing the RFC! The general design of the judges abstraction looks pretty reasonable to me. A couple comments:

  1. Personally I would split up the parent and child classes a little differently. (Admittedly some of this is personal preference but) I'd lean towards just having PairwiseJudge as a protocol and leaving all implementation details for the child classes. We do this a lot in other parts of our codebase too (e.g. our tokenizers).
  2. (Kind of an extension of (1)..) I wonder whether judge_all could be a utility outside of a class, e.g. it takes in judge_fn or something as an argument, then wraps it in the executor (see the sketch after this list). The advantage of this is that the judges then become even easier to write/understand.
  3. The online DPO proposal sounds reasonable at a high level, but would also like to see some code snippet to get a feel for how the judge interacts with it. Do you envision this as its own training recipe? (Seems that it does change the dataloading and training loop in a nontrivial way)
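
One possible shape for the standalone utility suggested in (2); the judge_fn signature and max_workers argument are illustrative, not a settled API:

import concurrent.futures
from typing import Callable, List, Optional


def judge_all(
    judge_fn: Callable[[str, List[str]], int],
    prompts: List[str],
    completions: List[List[str]],
    max_workers: Optional[int] = None,
) -> List[int]:
    """Run judge_fn(prompt, candidates) over every prompt in a thread pool."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(judge_fn, prompts, completions))


# A judge then only needs to implement judge(); for example:
# choices = judge_all(LengthPairwiseJudge(objective="shorter").judge, prompts, completions)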

ebsmothers added the rfc Request for comments label Feb 23, 2025