
GRPO - Do not load reference model when beta == 0 #2806

Open · wants to merge 4 commits into main
Conversation

@ingambe commented Feb 8, 2025

What does this PR do?

When beta == 0, we can avoid loading and querying the reference model, since the KL divergence between the reference and current model is never used.
This speeds up training and reduces memory usage in scenarios where beta == 0 is viable (e.g., when using a low gradient clipping value).

This is especially useful in the current setting, where we do not perform multiple iterations per batch and do not stop when the KL divergence reaches a threshold.
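
To make the change concrete, here is a minimal sketch of the idea; the helper name is hypothetical and not the actual diff, only the condition on beta matters:

from trl import create_reference_model

def maybe_create_ref_model(model, beta):
    # Hypothetical helper: with beta == 0 the KL penalty is never computed,
    # so no frozen reference copy of the policy is needed at all.
    if beta == 0.0:
        return None  # saves one full model's worth of memory and its forward passes
    return create_reference_model(model)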

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec (Member)

Do you have any reference suggesting that training without the KL term can give good results?

@ingambe (Author) commented Feb 8, 2025

Running the example script from the docs using this PR's branch:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=1,
    max_grad_norm=0.2,
    beta=0,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    report_to="wandb")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset
)
trainer.train()

On 8*4090 for 15 minutes:
[Screenshot: W&B training curves, 2025-02-08]

https://wandb.ai/ingambe/huggingface/runs/6kg1jviv/workspace?nw=nwuseringambe

The model converges. The grad-clip value was chosen based on what worked best on another dataset.

@qgallouedec (Member) commented Feb 8, 2025

No, I mean: do you have any reason to think the KL term is useless in GRPO? I'm sure the reward increases; that is actually expected. But remember, the KL term in RLHF prevents the fine-tuned policy from diverging too much from the pre-trained model, ensuring stability, safety, and generalization while balancing reward maximization.
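
For context, this is roughly how the per-token KL penalty enters the GRPO objective (a simplified sketch, not the trainer's exact code):

import torch

def per_token_kl(logps, ref_logps):
    # k3-style estimator of KL(pi || pi_ref) per token; non-negative by construction
    diff = ref_logps - logps
    return torch.exp(diff) - diff - 1

# Schematically, the loss is: policy_term + beta * per_token_kl(logps, ref_logps).
# With beta == 0 the penalty term vanishes, which is what makes skipping the
# reference forward pass possible in the first place.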

@ingambe (Author) commented Feb 8, 2025

I wouldn't say it is useless; it is one of the things that prevent overly abrupt policy changes, but loss clipping and gradient clipping also contribute to that.
In classical GRPO or PPO setups, where the model performs multiple updates on the same set of samples, it makes a lot of sense because it directly measures and penalizes divergence at every step. However, in a single-update setting, keeping an extra reference model in memory and running inference on it for the KL penalty might not always be worth it; doing more iterations with a smaller gradient clipping value is an interesting trade-off IMHO.
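
For illustration, the alternative safeguard is plain global-norm gradient clipping (a standalone PyTorch sketch, not the Trainer internals, where this corresponds to max_grad_norm):

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Rescale gradients so their global L2 norm is at most 0.2 before stepping.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.2)
optimizer.step()
optimizer.zero_grad()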

@qgallouedec (Member)

Ok. Here, the reward is definitely not the metric to monitor. Instead, I would monitor the completions and run a long training with an actual dataset and reward function (not a toy example).

@ingambe (Author) commented Feb 9, 2025

I am limited by the resources at my disposal, so I cannot run for long.
But on a math dataset with Qwen2-0.5B I get the expected results.
https://wandb.ai/ingambe/huggingface/runs/zvdenpo8

[Screenshot: W&B reward curve, 2025-02-09]

Code:

import re

import bitsandbytes as bnb
from datasets import load_dataset
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion[0]["content"]) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
    
dataset = load_dataset("ai2-adapt-dev/gsm8k_math_ground_truth", split="train")

# Preprocessing: rename messages to prompt and add a system prompt
def preprocess(example):
    system_prompt = {
        "role": "system", 
        "content": "Please reason step by step, and put your final answer within \\boxed{{}}."
    }
    example["prompt"] = [system_prompt] + example["messages"]
    
    example["completion"] = [{
        "role": "assistant", 
        "content": ""
    }]
    example["ground_truth"] = example.get("ground_truth", "")
    
    return example

dataset = dataset.map(preprocess).remove_columns(["messages"])

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", 
    logging_steps=8,
    max_grad_norm=0.1,
    beta=0,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=100,
    max_prompt_length=650,
    max_completion_length=350,
    num_generations=8,
    report_to="wandb",
    log_completions=True)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

optimizer = bnb.optim.PagedAdamW8bit(
    model.parameters(),
    lr=1e-6,
    betas=(0.9, 0.95),
    weight_decay=0.1
)

scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: min(1.0, (epoch + 1) / training_args.warmup_steps)
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=dataset,
    optimizers=(optimizer, scheduler)
)

trainer.train()

You can check model input/output in the artifacts.

Model Input:

<|im_start|>system
Please reason step by step, and put your final answer within \boxed{{}}.<|im_end|>
<|im_start|>user
Question: Find the domain of the expression $\frac{\sqrt{x-2}}{\sqrt{5-x}}$.}
Answer:The expressions inside each square root must be non-negative.
Therefore, $x-2 \ge 0$, so $x\ge2$, and $5 - x \ge 0$, so $x \le 5$.
Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.
Therefore, the domain of the expression is $\boxed{[2,5)}$.

Question: If $\det \mathbf{A} = 2$ and $\det \mathbf{B} = 12,$ then find $\det (\mathbf{A} \mathbf{B}).$
Answer:We have that $\det (\mathbf{A} \mathbf{B}) = (\det \mathbf{A})(\det \mathbf{B}) = (2)(12) = \boxed{24}.$

Question: Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?
Answer:If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\cdot 12\cdot20=480$ pounds of weight.  If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\cdot15\cdot n=30n$ pounds of weight.  Equating this to 480 pounds, we can solve for $n$: \begin{align*}
30n&=480\\
\Rightarrow\qquad n&=480/30=\boxed{16}
\end{align*}

Question: If the system of equations

\begin{align*}
6x-4y&=a,\\
6y-9x &=b.
\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\frac{a}{b},$ assuming $b$ is nonzero.
Answer:If we multiply the first equation by $-\frac{3}{2}$, we obtain

$$6y-9x=-\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have

$$-\frac{3}{2}a=b\Rightarrow\frac{a}{b}=\boxed{-\frac{2}{3}}.$$

Question: Square $ABCD$ has side length $1$ unit.  Points $E$ and $F$ are on sides $AB$ and $CB$, respectively, with $AE = CF$.  When the square is folded along the lines $DE$ and $DF$, sides $AD$ and $CD$ coincide and lie on diagonal $BD$.  The length of segment $AE$ can be expressed in the form $\sqrt{k}-m$ units. What is the integer value of $k+m$?<|im_end|>
<|im_start|>assistant

Model Output:

 The length of side $AD$ is the diagonal of the square, which is equal to $1$ unit.
When the square is folded along the lines $DE$ and $DF$, sides $AD$ and $CD$ coincide and lie on diagonal $BD$, which is the side length of the square itself.
Let $AE = CF = t$ be the length of the side of the square.
Since $AE = CF$ and $AE + CF = 1$, we have $t + t = 1$, which means $2t = 1$, or $t = \frac{1}{2}$.
Therefore, $AE = \sqrt{\frac{1}{2}} = \frac{\sqrt{2}}{2}$.
The length of side $AD\sqrt{2}$ is also $\frac{\sqrt{2}}{2}$.
So, $k+m = \frac{1}{2} + 2 = \boxed{3}$.
The answer is $\boxed{3}$. The answer is $\boxed{3}$

Note:

  • I am limited in VRAM, so I could not run with a long context, which is far from ideal for this dataset, since some of its inputs are quite long.
  • Qwen2-0.5B sometimes ends up repeating itself; this is a known behavior of this small model. If I had to redo it, I would additionally add a formatting reward on top of the correctness one (see the sketch after this list).
  • I had to use PagedAdamW8bit due to memory pressure, which might have added some instability to the mix, but the results are not too bad for such a small model.
  • Of course, I could not tune the hyperparameters.
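
For reference, such a formatting reward could look like this (an illustrative sketch, not code from the run above):

import re

def format_reward(completions, **kwargs):
    # Reward 1.0 if the completion contains a \boxed{...} answer, 0.0 otherwise.
    pattern = r"\\boxed\{.*?\}"
    return [1.0 if re.search(pattern, c[0]["content"]) else 0.0 for c in completions]

# It would be passed alongside the correctness reward, e.g.
# reward_funcs=[reward_func, format_reward] in GRPOTrainer.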

I cannot run larger experiments. To me, dropping the KL penalty in favor of more aggressive gradient clipping is useful to vastly improve training speed and reduce memory, especially with such a low default KL-penalty coefficient, but your mileage may vary.

@BaohaoLiao

I think adding this option makes sense. SimPO (https://arxiv.org/pdf/2405.14734) has a similar finding: the reference model is not always needed.

@qgallouedec (Member)

cc @edbeeching, this might interest you

@mirceapricop (Contributor)

+1, I would also appreciate having this option. And wouldn't it be a pure optimization for the case where beta == 0, without affecting other runs?
