support DPO training (2305.18290)
hiyouga committed Aug 10, 2023
1 parent 685dae4 commit 3ec4351
Showing 34 changed files with 517 additions and 1,027,304 deletions.
60 changes: 39 additions & 21 deletions README.md
@@ -12,6 +12,8 @@

## Changelog

[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).

[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.

[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
@@ -54,24 +56,18 @@
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |

> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc.
- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.

## Supported Training Approaches

- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
  - Full-parameter tuning
  - Partial-parameter tuning
  - [LoRA](https://arxiv.org/abs/2106.09685)
  - [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
  - Full-parameter tuning
  - Partial-parameter tuning
  - [LoRA](https://arxiv.org/abs/2106.09685)
  - [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
  - [LoRA](https://arxiv.org/abs/2106.09685)
  - [QLoRA](https://arxiv.org/abs/2305.14314)
| Approach               | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | -------------- | ----------------- | ---- | ----- |
| Pre-Training           | ✅             | ✅                | ✅   | ✅    |
| Supervised Fine-Tuning | ✅             | ✅                | ✅   | ✅    |
| Reward Model Training  |                |                   | ✅   | ✅    |
| PPO Training           |                |                   | ✅   | ✅    |
| DPO Training           | ✅             |                   | ✅   | ✅    |

## Provided Datasets

@@ -88,7 +84,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +98,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward modelling:
- For reward modelling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +134,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
### Dependence Installation (optional)

```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py

Currently the web UI only supports training on **a single GPU**.

### (Continually) Pre-Training
### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
@@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
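
For orientation, the number of samples behind each optimizer step is the product of `--per_device_train_batch_size`, `--gradient_accumulation_steps`, and the number of GPUs. The sketch below assumes a single GPU (as selected by `CUDA_VISIBLE_DEVICES=0`) and reads the two `--per_device_train_batch_size` values in this hunk as the old setting (4) and the new setting (2). The function name is hypothetical and is not part of the repository.

```python
def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int = 1) -> int:
    """Samples contributing to one optimizer step (hypothetical helper)."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


# Assumed old setting in this hunk: batch size 4, accumulation 4, one GPU.
old_size = effective_batch_size(4, 4)
# Assumed new setting in this hunk: batch size 2, accumulation 4, one GPU.
new_size = effective_batch_size(2, 4)

print(old_size, new_size)
```

Under these assumptions the change halves the samples per optimizer step while `--gradient_accumulation_steps` stays at 4, which may be worth keeping in mind when comparing learning rates across runs.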

### PPO Training (RLHF)
### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss
```

### DPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
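
As background for `--stage dpo`, the per-pair objective from the DPO paper (arXiv 2305.18290) can be sketched in plain Python. This is an illustrative reading of the paper's loss, not the repository's implementation: the function name, the argument names, and the default `beta=0.1` are assumptions. Each argument is taken to be the summed log-probability of a full response under either the policy being trained or the frozen reference model.

```python
import math


def dpo_loss(policy_chosen_logp: float,
             policy_rejected_logp: float,
             ref_chosen_logp: float,
             ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Per-pair DPO loss: -log(sigmoid(beta * (policy margin - reference margin)))."""
    # How much more the policy prefers the chosen response than the rejected one.
    policy_margin = policy_chosen_logp - policy_rejected_logp
    # The same preference margin under the frozen reference model.
    ref_margin = ref_chosen_logp - ref_rejected_logp
    logits = beta * (policy_margin - ref_margin)
    # -log(sigmoid(x)) == log(1 + exp(-x)); branch to avoid overflow in exp().
    if logits >= 0:
        return math.log1p(math.exp(-logits))
    return -logits + math.log1p(math.exp(logits))


# When the policy and the reference agree, the loss equals log(2).
print(dpo_loss(-1.0, -2.0, -1.0, -2.0))
```

The loss falls below log(2) when the policy favors the chosen response more strongly than the reference does, and rises above it in the opposite case.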

### Distributed Training

```bash
72 changes: 45 additions & 27 deletions README_zh.md
@@ -12,7 +12,9 @@

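## Changelog

[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Note that using the Qwen-7B-Chat model requires adding the `--template chatml` argument.
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) for details (experimental feature).

[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. When using the Qwen-7B-Chat model, add the `--template chatml` argument.

[23/07/31] Now we support streaming of training data. Try the `--streaming` and `--max_steps 100` arguments to stream your dataset.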

@@ -54,41 +56,34 @@
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |

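> * **Default module** is the default value of the `--lora_target` argument. Use `python src/train_bash.py -h` to see all available options.
> * For all "base" models, the `--template` argument can be `default`, `alpaca`, `vicuna`, etc.
## Fine-Tuning Methods

- [Continued pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
  - Full-parameter tuning
  - Partial-parameter tuning
  - [LoRA](https://arxiv.org/abs/2106.09685)
  - [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised instruction fine-tuning](https://arxiv.org/abs/2109.01652)
  - Full-parameter tuning
  - Partial-parameter tuning
  - [LoRA](https://arxiv.org/abs/2106.09685)
  - [QLoRA](https://arxiv.org/abs/2305.14314)
- [Reinforcement learning from human feedback (RLHF)](https://arxiv.org/abs/2203.02155)
  - [LoRA](https://arxiv.org/abs/2106.09685)
  - [QLoRA](https://arxiv.org/abs/2305.14314)
- **Default module** lists some of the available values of the `--lora_target` argument. Use `python src/train_bash.py -h` to see all available options.
- For all "base" models, the `--template` argument can be any of `default`, `alpaca`, `vicuna`, etc. For "chat" models, be sure to use the corresponding template.

## Training Methods

| Method                 | Full-parameter | Partial-parameter | LoRA | QLoRA |
| ---------------------- | -------------- | ----------------- | ---- | ----- |
| Pre-Training           | ✅             | ✅                | ✅   | ✅    |
| Supervised Fine-Tuning | ✅             | ✅                | ✅   | ✅    |
| Reward Model Training  |                |                   | ✅   | ✅    |
| PPO Training           |                |                   | ✅   | ✅    |
| DPO Training           | ✅             |                   | ✅   | ✅    |

## Datasets

- For continued pre-training:
- For pre-training: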
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised instruction fine-tuning:
- For supervised instruction fine-tuning
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +98,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- For reward model training:
- For reward model or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +134,6 @@ huggingface-cli login
### Environment Setup (optional)

```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py

Currently the web UI only supports **single-GPU training**.

### Continued Pre-Training
### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
@@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```

### RLHF Training
### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss
```

### DPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```

### Multi-GPU Distributed Training

```bash
20 changes: 0 additions & 20 deletions data/dataset_info.json
@@ -49,26 +49,6 @@
"history": "history"
}
},
"refgpt_zh_p1": {
"file_name": "refgpt_zh_50k_p1.json",
"file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"refgpt_zh_p2": {
"file_name": "refgpt_zh_50k_p2.json",
"file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"lima": {
"file_name": "lima.json",
"file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",