Remove tyro #1176

Merged · 42 commits · Jan 26, 2024
Conversation

@vwxyzjn (Contributor) commented Jan 4, 2024

The command to run is simply `python examples/scripts/ppo.py --log_with wandb`. There is no evidence of significant regression in the new refactor, though the learning curves appear less smooth.

This was the learning curve before (plot not reproduced here).

@vwxyzjn requested a review from lewtun, January 4, 2024 15:07
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lewtun (Member) left a comment

Thanks a lot for working on this @vwxyzjn! I think it all looks good, modulo a question about `ppo_config.py`: shouldn't we also remove the `tyro` import and the related bits about `tyro.conf.Suppress`?

e.g. here https://github.com/huggingface/trl/pull/1176/files#diff-070022da4e4a14f0782200ebeab47af78224f10933e97853a0475d26007d4a1eR22-R31
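
For readers without the diff handy, here is a rough sketch of the leftover pattern being flagged (the field is illustrative, not the literal `ppo_config.py` contents): `tyro.conf.Suppress[...]` hides a field from tyro's auto-generated CLI, so once tyro is gone both the import and the annotation become dead weight.

```python
from dataclasses import dataclass

import tyro  # <- leftover import to remove


@dataclass
class PPOConfig:
    # before: hidden from tyro's generated CLI via the Suppress annotation
    world_size: tyro.conf.Suppress[int] = None
    # after removing tyro, a plain annotation is enough:
    # world_size: int = None
```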

```python
trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})

# LoraConfig
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
```
Member comment:

Not for this PR, but I think we may want to have something like a `ModelArguments` or `ModelConfig` class that collects all hyperparams associated with model loading, since these tend to be pretty similar across SFT/RM/DPO, etc.
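
To make that concrete, a rough sketch of what such a class might collect (field names here are assumptions, not the final API that landed):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfig:
    """Sketch: every hyperparameter tied to model loading in one place."""

    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "the model checkpoint for weights initialization"}
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    # LoRA / PEFT
    use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
    lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    # quantization
    load_in_4bit: bool = field(default=False, metadata={"help": "load the model in 4-bit precision"})
```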

```python
),
model_name: str = field(default="facebook/opt-350m", metadata={"help": "the model name"})
dataset_name: str = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
```
Member comment:

Not for this PR, but I think we should refactor the `RewardTrainer` to be similar to the `SFTTrainer`, so that all these args can be captured in `RewardConfig` and the preprocessing done on the fly. A sketch of that idea follows.
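
As a rough sketch (field names are assumptions): the script-level dataset args move into a `RewardConfig` that extends `TrainingArguments`, so the trainer can tokenize and preprocess on the fly the way `SFTTrainer` does.

```python
from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class RewardConfig(TrainingArguments):
    # Sketch only: dataset handling moves out of the script and into the
    # config, so preprocessing can happen inside the trainer.
    dataset_name: str = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
    dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
    max_length: int = field(default=512, metadata={"help": "max sequence length after tokenization"})
```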

Comment on lines +189 to +194
```python
ddpo_config.project_kwargs = {
    "logging_dir": "./logs",
    "automatic_checkpoint_naming": True,
    "total_limit": 5,
    "project_dir": "./save",
}
```
@vwxyzjn (author):

Required; otherwise it errors out.

```python
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"
parser = HfArgumentParser((ScriptArguments, RewardConfig))
args, reward_config = parser.parse_args_into_dataclasses()
reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
```
@vwxyzjn (author):

Required; otherwise it errors out.

@vwxyzjn (author) commented Jan 9, 2024

Also removed duplicate and slightly different docstrings to make things more consistent.

@vwxyzjn (author) commented Jan 9, 2024

Not necessarily in this PR, but we could also do

```python
parser = HfArgumentParser((TrainingArguments, BitsAndBytesConfig, LoraConfig))
```

but currently

```python
parser = HfArgumentParser((LoraConfig))
```

errors out. The reason is shown below, @younesbelkada. I'm not sure whether the fix should go in peft or in HfArgumentParser.

```python
from dataclasses import dataclass, field
from typing import Literal

from transformers import HfArgumentParser


@dataclass
class TestConfig:
    init_lora_weights: bool | Literal["gaussian", "loftq"] = field(
        default=True,
        metadata={
            "help": (
                "How to initialize the weights of the LoRA layers. Passing True (default) results in the default "
                "initialization from the reference implementation from Microsoft. Passing 'gaussian' results "
                "in Gaussian initialization scaled by the LoRA rank for linear layers. Setting the initialization "
                "to False leads to completely random initialization and is discouraged. "
                "Pass `'loftq'` to use LoftQ initialization."
            ),
        },
    )


parser = HfArgumentParser((TestConfig))
args = parser.parse_args_into_dataclasses()
```
```
$ python g.py --help
Traceback (most recent call last):
  File "/fsx/costa/trl/g.py", line 21, in <module>
    parser = HfArgumentParser((TestConfig))
  File "/fsx/costa/pyenv/versions/mambaforge-22.9.0-3/envs/trl-310/lib/python3.10/site-packages/transformers/hf_argparser.py", line 136, in __init__
    self._add_dataclass_arguments(dtype)
  File "/fsx/costa/pyenv/versions/mambaforge-22.9.0-3/envs/trl-310/lib/python3.10/site-packages/transformers/hf_argparser.py", line 263, in _add_dataclass_arguments
    self._parse_dataclass_field(parser, field)
  File "/fsx/costa/pyenv/versions/mambaforge-22.9.0-3/envs/trl-310/lib/python3.10/site-packages/transformers/hf_argparser.py", line 159, in _parse_dataclass_field
    raise ValueError(
ValueError: Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because the argument parser only supports one type per argument. Problem encountered in field 'init_lora_weights'.
```
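
One conceivable workaround on the config side (an assumption for illustration, not something peft or transformers actually does) is to collapse the union into a single string-typed field and normalize it after parsing, so `HfArgumentParser` only ever sees one type per argument:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class TestConfigParsable:
    # A single str annotation keeps HfArgumentParser happy; the bool/Literal
    # semantics are restored in __post_init__.
    init_lora_weights: str = field(
        default="true",
        metadata={"help": "one of 'true', 'false', 'gaussian', 'loftq'"},
    )

    def __post_init__(self):
        if self.init_lora_weights in ("true", "false"):
            # normalize the string back into the boolean the code expects
            self.init_lora_weights = self.init_lora_weights == "true"


parser = HfArgumentParser(TestConfigParsable)
(args,) = parser.parse_args_into_dataclasses(args=["--init_lora_weights", "gaussian"])
print(args.init_lora_weights)  # 'gaussian'
```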

@vwxyzjn requested a review from lvwerra, January 9, 2024 16:01
@vwxyzjn changed the title from "Remove tyro from PPO" to "Remove tyro" on Jan 10, 2024
@lewtun (Member) commented Jan 10, 2024

> Not necessarily in this PR, but we could also do `parser = HfArgumentParser((TrainingArguments, BitsAndBytesConfig, LoraConfig))`, but currently `parser = HfArgumentParser((LoraConfig))` errors out. [full repro quoted above]

If I'm not mistaken, we can encapsulate the BnB and LoRA configs into a single `ModelConfig` class that contains everything associated with model loading. This is what we do in the handbook here, although admittedly it only covers the args we found most useful. This way we can reuse the same config class across all the scripts (except possibly PPO, which requires one to also specify the reward model kwargs).

@younesbelkada (Contributor) commented:

@vwxyzjn Hmm, indeed. Maybe the fix for that specific problem should go in PEFT; as a quick workaround, could we go with @lewtun's solution?

> encapsulate the BnB and LoRA configs into a single `ModelConfig` class that contains everything associated with model loading

@vwxyzjn (author) commented Jan 10, 2024

@lewtun that's a great idea. Let me refactor things using ModelArguments then!

@lvwerra requested a review from younesbelkada, January 11, 2024 12:04
@vwxyzjn (author) commented Jan 11, 2024

So everything looks good, except maybe the DPO train accuracies: they're approaching 100%, which seems a bit off given that reward_modeling.py reaches about 65% accuracy on the same dataset...

@lewtun (Member) left a comment

Great job adding the `ModelConfig`, @vwxyzjn! This is looking really good; I left a few small comments on things we can remove or tidy up.

Regarding the DPO training curves, the better metric to track for overfitting is the train/val loss IMO, but in general DPO does overfit quite quickly, and this turns out not to matter much in practice.

@vwxyzjn (author) commented Jan 24, 2024

@younesbelkada there were some merge conflicts in `dpo.py`, such as the one shown in the screenshot (not reproduced here).

With this PR we can keep the code as is: users can pass `--bf16` if they want to, and it will be picked up by the `TrainingArguments` :)

@vwxyzjn (author) commented Jan 24, 2024

The CI failed, but it seems unrelated to the change in this PR. See #1273.

@younesbelkada (Contributor) left a comment

Very nice cleanup @vwxyzjn! Thanks a lot for the refactor. Just as a sanity check before merging, could you run `make test_examples` to make sure the example scripts will not fail on the GPU CI? Only DPO with DS-1, 2 & 3 should fail; all other configurations should pass.

```diff
@@ -426,13 +426,13 @@ To use Flash Attention 2, first install the latest `flash-attn` package:
 pip install -U flash-attn

-And add `use_flash_attention_2=True` when calling `from_pretrained`:
+And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
```
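
For illustration, the new-style call looks like this (the checkpoint name is just an example; it has to be a model with Flash Attention 2 support):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # example checkpoint; any FA2-supported model works
    torch_dtype=torch.bfloat16,  # FA2 requires fp16/bf16 weights
    attn_implementation="flash_attention_2",  # replaces use_flash_attention_2=True
)
```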
Contributor comment:

Could you add a documentation section here to explain what `ModelConfig` does and how to use it together with the other utility methods that you have exposed on the main init? 🙏

@vwxyzjn (author) replied:

Done!



```python
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
```
Contributor comment:

Just to confirm: `--bf16` is passed directly through the terminal command and parsed thanks to the `HfArgumentParser`?

@vwxyzjn (author) replied:

That's correct.
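
For readers following along, a minimal sketch of that mechanism (paths and values are illustrative): `HfArgumentParser` exposes boolean dataclass fields as optional flags, so `--bf16` maps straight onto `TrainingArguments.bf16` with no extra wiring.

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "./out", "--bf16"]  # simulate the CLI invocation
)
print(training_args.bf16)  # True
```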


vwxyzjn and others added 4 commits January 26, 2024 06:45
@vwxyzjn (author) commented Jan 26, 2024

Thanks @younesbelkada for the comment! `make test_examples` seems to launch 4 experiments, and 2 of them failed with:

```
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.HalfTensor instead (while checking arguments for embedding)
[2024-01-26 14:57:32,456] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2212929 closing signal SIGTERM
```

Is this expected?


@younesbelkada (Contributor) left a comment

Huge work @vwxyzjn! Thanks so much for this refactor!

@vwxyzjn merged commit 9a71e67 into huggingface:main on Jan 26, 2024 (6 of 9 checks passed)
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* refactor
* Remove tyro in `ppo.py`
* quick update
* update default args
* quick push
* precommit
* refactor
* quick change
* remove tyro
* quick change
* precommit
* quick change
* fix hello_world
* remove docstring differences
* add `module load cuda/12.1`
* push changes
* precommit
* make dpo runnable
* fix circular import
* quick fix
* refactor
* quick update
* path change
* update plots
* fix docs
* quick change
* Update trl/trainer/model_config.py (Co-authored-by: lewtun <[email protected]>)
* Update trl/trainer/model_config.py (Co-authored-by: lewtun <[email protected]>)
* Update trl/trainer/utils.py (Co-authored-by: lewtun <[email protected]>)
* Update examples/scripts/dpo.py (Co-authored-by: lewtun <[email protected]>)
* address comments. use attn_implementation
* precommit
* remove duplicate code
* update peft.py
* fix test no op dep
* Update trl/trainer/utils.py (Co-authored-by: Younes Belkada <[email protected]>)
* Apply suggestions from code review (Co-authored-by: lewtun <[email protected]>, Younes Belkada <[email protected]>)
* precommit
* add docs

---------

Co-authored-by: lewtun <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>