Remove tyro #1176
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for working on this @vwxyzjn ! I think it all looks good, modulo a question I have about `ppo_config.py`: shouldn't we also be removing the `tyro` import and related bits about `tyro.conf.Suppress`?
```python
trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})

# LoraConfig
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
```
Not for this PR, but I think we may want to have something like a `ModelArguments` or `ModelConfig` class that collects all hyperparams associated with model loading, since this tends to be pretty similar across SFT/RM/DPO etc.
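Not from this PR, but a minimal sketch of what such a shared model-loading config could look like (field names are illustrative, not the project's final API):

```python
# Illustrative only: a shared config collecting model-loading hyperparameters.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfig:
    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "model checkpoint to load"})
    trust_remote_code: bool = field(default=False, metadata={"help": "enable `trust_remote_code`"})
    # PEFT / LoRA
    use_peft: bool = field(default=False, metadata={"help": "whether to use PEFT"})
    lora_r: int = field(default=16, metadata={"help": "LoRA r"})
    lora_alpha: float = field(default=16, metadata={"help": "LoRA alpha"})
    lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout"})
    # Quantization
    load_in_4bit: bool = field(default=False, metadata={"help": "load the model in 4-bit"})
    load_in_8bit: bool = field(default=False, metadata={"help": "load the model in 8-bit"})
```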
examples/scripts/reward_modeling.py (Outdated)
```python
),
model_name: str = field(default="facebook/opt-350m", metadata={"help": "the model name"})
dataset_name: str = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
```
Not for this PR, but I think we should refactor the `RewardTrainer` to be similar to the `SFTTrainer`, so that all these args can be captured in `RewardConfig` and the preprocessing done on the fly.
```python
ddpo_config.project_kwargs = {
    "logging_dir": "./logs",
    "automatic_checkpoint_naming": True,
    "total_limit": 5,
    "project_dir": "./save",
}
```
This is required; otherwise it errors out.
examples/scripts/reward_modeling.py (Outdated)
```python
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"
parser = HfArgumentParser((ScriptArguments, RewardConfig))
args, reward_config = parser.parse_args_into_dataclasses()
reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
```
This is required; otherwise it errors out.
Scripts seem to run fine :)
Not necessarily in this PR, but we can also do , but currently it errors out. The reason is as follows @younesbelkada — presumably the `bool | Literal[...]` union annotation, which `HfArgumentParser` has only limited support for. So not sure if we need to change peft or :

```python
from dataclasses import dataclass, field
from typing import Literal

from transformers import HfArgumentParser


@dataclass
class TestConfig:
    init_lora_weights: bool | Literal["gaussian", "loftq"] = field(
        default=True,
        metadata={
            "help": (
                "How to initialize the weights of the LoRA layers. Passing True (default) results in the default "
                "initialization from the reference implementation from Microsoft. Passing 'gaussian' results "
                "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization "
                "to False leads to completely random initialization and is discouraged. "
                "Pass `'loftq'` to use LoftQ initialization"
            ),
        },
    )


parser = HfArgumentParser((TestConfig))
args = parser.parse_args_into_dataclasses()
```
If I'm not mistaken, we can encapsulate the BnB and LoRA configs into a single config class.
@lewtun that's a great idea. Let me refactor things using that approach.
Great job adding the `ModelConfig` @vwxyzjn ! This is looking really good and I left a few small comments on things we can remove / tidy up. Regarding the DPO training curves, the better metric to track for overfitting is the train/val loss IMO, but in general DPO does overfit quite quickly and this turns out not to matter much in practice.
@younesbelkada there were some merge conflicts in . With this PR we can keep the code as is: the users can do .
The CI failed but it seems unrelated to the change in this PR. See #1273.
Very nice clean up @vwxyzjn ! Thanks a lot for the refactor - just as a sanity check before merging, could you run `make test_examples` to make sure the example scripts will not fail on the GPU CI? Only DPO with DS-1, 2 & 3 should fail but all other configurations should pass.
````diff
@@ -426,13 +426,13 @@ To use Flash Attention 2, first install the latest `flash-attn` package:
 pip install -U flash-attn
 ```
 
-And add `use_flash_attention_2=True` when calling `from_pretrained`:
+And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
````
Could you add a documentation section here to explain what `ModelConfig` does and how to use it together with the other utility methods that you have exposed on the main init? 🙏
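For reference, a hedged sketch of how the `ModelConfig` pattern might be wired together with those utilities; it assumes `get_peft_config`, `get_quantization_config`, and `get_kbit_device_map` are the helpers exposed on the main init and that `ModelConfig` carries `model_name_or_path` and `trust_remote_code` fields:

```python
# Hedged sketch: assumes these helpers and ModelConfig fields match what the PR exposes.
from transformers import AutoModelForCausalLM, HfArgumentParser, TrainingArguments

from trl import ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config

parser = HfArgumentParser((TrainingArguments, ModelConfig))
training_args, model_config = parser.parse_args_into_dataclasses()

# Turn the CLI flags collected in ModelConfig into model-loading kwargs.
quantization_config = get_quantization_config(model_config)
model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    trust_remote_code=model_config.trust_remote_code,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)

# Returns a peft LoraConfig when --use_peft is passed, otherwise None.
peft_config = get_peft_config(model_config)
```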
Done!
```python
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
```
Just to confirm - `--bf16` is passed directly through the terminal command and parsed thanks to the `HfArgumentParser`?
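For context, `HfArgumentParser` generates a CLI flag for every dataclass field, so a boolean such as `bf16` can be switched on directly from the terminal. A minimal standalone sketch with a hypothetical `DemoArgs` dataclass (not from this PR):

```python
# Standalone sketch with a hypothetical DemoArgs dataclass (not part of the PR).
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArgs:
    bf16: bool = field(default=False, metadata={"help": "use bfloat16 mixed precision"})


parser = HfArgumentParser(DemoArgs)
# Simulate `python script.py --bf16`; HfArgumentParser maps the field to a CLI flag.
(args,) = parser.parse_args_into_dataclasses(args=["--bf16"])
print(args.bf16)  # True
```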
Thanks @younesbelkada for the comment!
Is this expected?
Huge work @vwxyzjn ! Thanks so much for this refactor !
* refactor
* Remove tyro in `ppo.py`
* quick update
* update default args
* quick push
* precommit
* refactor
* quick change
* remove tyro
* quick change
* precommit
* quick change
* fix hello_world
* remove docstring diffences
* add `module load cuda/12.1`
* push changes
* precommit
* make dpo runnable
* fix circular import
* quick fix
* refactor
* quick update
* path change
* update plots
* fix docs
* quick change
* Update trl/trainer/model_config.py (Co-authored-by: lewtun <[email protected]>)
* Update trl/trainer/model_config.py (Co-authored-by: lewtun <[email protected]>)
* Update trl/trainer/utils.py (Co-authored-by: lewtun <[email protected]>)
* Update examples/scripts/dpo.py (Co-authored-by: lewtun <[email protected]>)
* address comments. use attn_implementation
* precommit
* remove duplicate code
* update peft.py
* fix test no op dep
* Update trl/trainer/utils.py (Co-authored-by: Younes Belkada <[email protected]>)
* Apply suggestions from code review (Co-authored-by: lewtun <[email protected]>, Co-authored-by: Younes Belkada <[email protected]>)
* precommit
* add docs

---------

Co-authored-by: lewtun <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
The command to run is simply `python examples/scripts/ppo.py --log_with wandb`. There is no evidence of significant regression in the new refactor, though the learning curves appear less smooth. This was the learning curve before: