
Add support for ChatML dataset format in SFTTrainer #1208

Merged
14 commits merged into huggingface:main on Jan 12, 2024

Conversation

philschmid
Contributor

@philschmid philschmid commented Jan 9, 2024

What does this PR do?

This PR adds support for standardized dataset formats to be automatically formatted for training in the SFTTrainer, using apply_chat_template from transformers.
This allows users to pass a dataset to the SFTTrainer without needing a formatting_func. Example below:

```python
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

# model and training_args are assumed to be defined as usual
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    packing=True,
)
```

In its init method, the SFTTrainer tries to find the correct formatting function based on the dataset structure. Currently supported formats are:
- ChatML with [{"role": str, "content": str}]
- instruction with [{"prompt": str, "completion": str}]

Based on the dataset, it returns a callable that uses the model's tokenizer and its corresponding apply_chat_template method. This allows continued fine-tuning of, e.g., Llama2-chat or other models that already have a defined chat format.

The nice part is that you can also use the "extras" outside of the SFTTrainer, e.g. if you want to format DPO datasets with the same methods.
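The format detection described above can be sketched as follows. This is a minimal illustration of the idea, not the actual trl implementation; the `messages` column name and the `guess_format` helper are assumptions for this example.

```python
# Minimal sketch of the format-detection idea: inspect the first example's
# structure and decide which formatting path applies.
# Hypothetical helper, not the real trl code.

def guess_format(example: dict) -> str:
    """Classify an example as 'chatml', 'instruction', or 'unknown'."""
    messages = example.get("messages")
    if isinstance(messages, list) and messages and set(messages[0]) == {"role", "content"}:
        return "chatml"
    if {"prompt", "completion"} <= set(example):
        return "instruction"
    return "unknown"

print(guess_format({"messages": [{"role": "user", "content": "Hi"}]}))  # chatml
print(guess_format({"prompt": "Q?", "completion": "A."}))               # instruction
```

A real implementation would also validate every row (not just the first) or rely on the dataset's feature schema rather than a sample.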

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@younesbelkada younesbelkada left a comment


Amazing addition, thanks! I went through the PR and it looks great to me!
Can you add a few lines in the documentation to explain what it does under the hood? I think the doc section should live inside the SFTTrainer docs - wdyt?

trl/extras/dataset_formatting.py (outdated, resolved)
trl/extras/dataset_formatting.py (resolved)
@philschmid
Contributor Author

> Can you add a few lines in the documentation to explain what it does under the hood? I think the doc section should live inside the SFTTrainer docs - wdyt?

I haven't worked on the documentation yet, since I wanted to see what you think first. Will work on the docs next.

Contributor

@younesbelkada younesbelkada left a comment


Looking great to me, thanks! I just left one question about the documentation.

```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```

If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
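To make the prompt/completion case concrete, converting such a record into the chat-message list that a chat template consumes can be sketched like this. The `to_messages` helper is illustrative only, not part of trl.

```python
# Illustrative conversion of an instruction-style record into the
# [{"role": ..., "content": ...}] message list that chat templates consume.
# Hypothetical helper for this example only.

def to_messages(record: dict) -> list:
    return [
        {"role": "user", "content": record["prompt"]},
        {"role": "assistant", "content": record["completion"]},
    ]

msgs = to_messages({"prompt": "What is 2+2?", "completion": "4"})
print(msgs[0]["role"])  # user
```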
Contributor


Maybe worth adding a line saying that the user needs to make sure the tokenizer supports apply_chat_template - otherwise it'll fail, I think, no?

Contributor Author

@philschmid philschmid Jan 10, 2024


It is always supported. If there is no template defined, it falls back to a default template, which is the ChatML format from OAI. cc @Rocketknight1 to confirm

Contributor


ok perfect then!

Member


Yep, the default format for tokenizers with no chat_template or class-level default_chat_template is ChatML.
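For reference, the ChatML layout that this default template produces can be sketched by hand. This is a simplified rendering for illustration; in practice the template is applied via the tokenizer's apply_chat_template.

```python
# Hand-rolled rendering of the ChatML format for illustration only;
# in practice you would call tokenizer.apply_chat_template instead.

def render_chatml(messages):
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )

print(render_chatml([{"role": "user", "content": "Hello!"}]))
# <|im_start|>user
# Hello!<|im_end|>
```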

Member

@lvwerra lvwerra left a comment


Thanks a lot @philschmid! Overall very clean, just left a few small nits.

tests/test_sft_trainer.py (resolved)
docs/source/sft_trainer.mdx (resolved)
Contributor

@younesbelkada younesbelkada left a comment


Still looks really good to me! Is it ok to merge, @philschmid?

@younesbelkada younesbelkada merged commit 776939d into huggingface:main Jan 12, 2024
9 checks passed
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* Add support for ChatML dataset format in
SFTTrainer

* fix formatting

* fix tests

* more comment

* fix intent

* fix doc string

* Update dataset_formatting.py

* Update dataset_formatting.py

* add documentation

* Update sft_trainer.mdx

* add leonardos comment and more tests

* added more tests and fixed batching

* style

* comment in
5 participants