Merge pull request #137 from X-LANCE/yxdu

Showing 18 changed files with 110,624 additions and 4 deletions.
@@ -0,0 +1,68 @@
# ST_covost2

## Download Model
We only train the Q-Former projector in this recipe.

|Encoder|Projector|LLM|
|---|---|---|
|[whisper-large-v3](https://huggingface.co/openai/whisper-large-v3)|[q-former](https://huggingface.co/yxdu/cotst)|[Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B)|

```
git lfs clone https://huggingface.co/openai/whisper-large-v3
git lfs clone https://huggingface.co/yxdu/cotst
git lfs clone https://huggingface.co/Qwen/Qwen2-7B
```
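If git-lfs is unavailable, the same checkpoints can be fetched with `huggingface_hub` (a sketch, not part of this recipe; `snapshot_download` is the library's real API):

```python
# Sketch: download the three checkpoints into the local Hugging Face cache.
from huggingface_hub import snapshot_download

for repo in ("openai/whisper-large-v3", "yxdu/cotst", "Qwen/Qwen2-7B"):
    local_dir = snapshot_download(repo_id=repo)
    print(repo, "->", local_dir)
```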

## Data
Download the CoVoST2 dataset (the audio clips come from Common Voice, as the example paths below show):
```
https://github.com/facebookresearch/covost
```
## Data preparation
Prepare the data as JSONL in the following format. A sample test set is provided in "test_st.jsonl":
```
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|en|>", "gt": "\"She'll be all right.\"", "source": "covost_en"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|de|>", "gt": "\"She'll be all right.\"<|de|>Sie wird schon in Ordnung sein.", "source": "covost_ende"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|ja|>", "gt": "\"She'll be all right.\"<|ja|>彼女は大丈夫だろう。", "source": "covost_enja"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "<|zh|>", "gt": "\"She'll be all right.\"<|zh|>她会没事的。", "source": "covost_enzh"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|de|>", "gt": "\"She'll be all right.\"<|de|>Sie wird schon in Ordnung sein.", "source": "covost_enende"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|ja|>", "gt": "\"She'll be all right.\"<|ja|>彼女は大丈夫だろう。", "source": "covost_enenja"}
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|zh|>", "gt": "\"She'll be all right.\"<|zh|>她会没事的。", "source": "covost_enenzh"}
```
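For reference, a minimal Python sketch for emitting records in this format (the output filename `train_st.jsonl` is an assumption; the fields mirror the examples above):

```python
# Sketch: append one JSONL record per (audio, task) pair.
import json

record = {
    "audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3",
    "prompt": "<|de|>",  # target-language tag; *_enen* tasks also put the transcript in the prompt
    "gt": "\"She'll be all right.\"<|de|>Sie wird schon in Ordnung sein.",
    "source": "covost_ende",  # dataset/task identifier
}
with open("train_st.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```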
## Train Stage
We use a four-step training process; each step starts from the checkpoint produced by the previous one.
```
# Step 1: ASR pretraining to acquire speech recognition capability.
bash asr_pretrain.sh

# Step 2: multimodal machine translation training to enhance final performance.
bash mmt.sh

# Step 3: monolingual SRT training.
bash srt.sh

# Step 4: multilingual multitask training.
bash zsrt.sh
```

## Infer Stage
You can try our pre-trained model:
```
bash infer.sh
```
## Citation
Refer to the paper for more results.
```
@article{ma2024embarrassingly,
  title={An Embarrassingly Simple Approach for LLM with Strong ASR Capacity},
  author={Ma, Ziyang and Yang, Guanrou and Yang, Yifan and Gao, Zhifu and Wang, Jiaming and Du, Zhihao and Yu, Fan and Chen, Qian and Zheng, Siqi and Zhang, Shiliang and others},
  journal={arXiv preprint arXiv:2402.08846},
  year={2024}
}
```
@@ -0,0 +1,135 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class ModelConfig:
    file: str = "examples/st_covost2/model/slam_model_st.py"
    llm_name: str = "vicuna-13b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 3584
    encoder_path_hf: Optional[str] = None
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether the model is only pretrained or finetuned, used for models such as hubert"
    })
    ckpt_path: Optional[str] = None
    query_len: Optional[str] = None


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # alternatives: None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.01
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # used only with FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # used only with FSDP
    save_optimizer: bool = False  # used only with FSDP
    use_fast_kernels: bool = False  # enable SDPA from PyTorch Accelerated Transformers, using Flash Attention and xFormers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the LLM when finetuning; should be true when using peft finetuning"
    })
    freeze_encoder: bool = False


@dataclass
class DataConfig:
    dataset: str = "st_dataset"
    file: str = "examples/st_covost2/dataset/st_dataset.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = False
    input_type: str = field(default="raw", metadata={
        "help": "use raw when the input is wav, mel for whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    bf16: bool = True
    source: Optional[str] = None


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # HYBRID_SHARD: full shard within a node, DDP across nodes; SHARD_GRAD_OP: shard only gradients and optimizer states; NO_SHARD: similar to DDP.
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size on resume.
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "test_wandb"
    wandb_entity_name: str = "SLAM"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "./test.log"
    log_interval: int = 50
    decode_log: str = "./test.log"
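Not part of the diff: a small sketch of how these dataclasses can be composed and overridden programmatically (pure stdlib; assumes the classes above are in scope):

```python
# Sketch: override a few training defaults and inspect the nested PEFT config.
from dataclasses import asdict, replace

cfg = TrainConfig(lr=5e-5, num_epochs=1, use_peft=True)
cfg = replace(cfg, batch_size_training=2)  # non-destructive override
print(cfg.peft_config.r, asdict(cfg)["output_dir"])  # 8 PATH/to/save/PEFT/model
```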
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
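A sketch of how a training script might consume this JSON when run under the `deepspeed` launcher; the toy model and the `ds_config.json` filename are placeholders, but `deepspeed.initialize` and its `config` argument are the library's real API:

```python
# Sketch: wrap a model with the ZeRO-3 + CPU-offload config above.
import torch.nn as nn
import deepspeed

model = nn.Linear(10, 10)  # stand-in for the actual SLAM-LLM model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # path to the JSON above (filename assumed)
)
```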
@@ -0,0 +1,4 @@
dataset_config:
  # We put the prompt here because a hydra override in a shell script only supports a small subset of characters.
  # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
  prompt: "<en>"
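A sketch of reading this YAML with OmegaConf (which hydra uses under the hood); the filename is a placeholder:

```python
# Sketch: load the prompt from the YAML instead of passing it on the command line.
from omegaconf import OmegaConf

cfg = OmegaConf.load("prompt.yaml")  # filename assumed
print(cfg.dataset_config.prompt)  # "<en>"
```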