Merge pull request #35 from ddlBoJack/dev-mzy
ddp, hydra, lr_scheduler, q-former
ddlBoJack authored Jan 25, 2024
2 parents 6d30313 + de6d84e commit 6e0f191
Showing 14 changed files with 323 additions and 185 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -7,4 +7,5 @@ debug.py
transformers
wandb/
log/
*.log
*.log
outputs/
4 changes: 2 additions & 2 deletions scripts/compute_wer.sh
@@ -1,7 +1,7 @@
#cd /root/SLAM-LLM

trans="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred"
trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"

# python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
# python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
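
The commented lines above feed the paired ground-truth/prediction decode logs into a text normalizer and a WER scorer. A minimal sketch of the WER computation over such files, assuming one "<key> <text>" utterance per line with matching keys; the repo's compute_wer.py may differ in key parsing and normalization:

```python
# Minimal WER sketch; assumes each line is "<key> <text>" with non-empty text
# and matching keys across both files. Not the repo's exact script.
def edit_distance(ref, hyp):
    # Word-level Levenshtein distance with a rolling 1-D DP table.
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
    return d[-1]

def wer(ref_file, hyp_file):
    def load(path):
        with open(path, encoding="utf-8") as f:
            return {k: v.split() for k, v in
                    (line.strip().split(maxsplit=1) for line in f if line.strip())}
    refs, hyps = load(ref_file), load(hyp_file)
    errors = sum(edit_distance(refs[k], hyps.get(k, [])) for k in refs)
    total = sum(len(words) for words in refs.values())
    return errors / max(total, 1)
```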
12 changes: 8 additions & 4 deletions scripts/conf/asr_vicuna_lora.yaml
@@ -34,15 +34,19 @@ model_config:

train_config:
model_name: "PATH/to/LLAMA/7B"
enable_ddp: false
enable_fsdp: false
low_cpu_fsdp: false
run_validation: true
batch_size_training: 4
batching_strategy: "packing" #alternative: padding
context_length: 4096
gradient_accumulation_steps: 1
num_epochs: 3
num_epochs: 100
num_workers_dataloader: 1
warmup_steps: 1000
total_steps: 100000
validation_interval: 1000
lr: 1e-4
weight_decay: 0.0
gamma: 0.85
@@ -57,7 +61,7 @@ train_config:
r: 8
lora_alpha: 32
target_modules: [ "q_proj", "v_proj" ]
bias: null
bias: "none"
task_type: "CAUSAL_LM"
lora_dropout: 0.05
inference_mode: false
@@ -93,8 +97,8 @@ fsdp_config:
mixed_precision: true
use_fp16: false
# sharding_strategy: "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP
checkpoint_type: "StateDictType.SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: true
fsdp_cpu_offload: false
pure_bf16: false
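
The new warmup_steps, total_steps, and validation_interval fields back the lr_scheduler change named in the commit message (the experiment directories switch to a "steplrwarmupkeep1e-4" naming). A minimal sketch of a warmup-then-hold schedule, assuming a LambdaLR-style implementation rather than the repo's exact scheduler:

```python
from torch.optim.lr_scheduler import LambdaLR

def warmup_keep_scheduler(optimizer, warmup_steps=1000):
    # Linear warmup from 0 to the base lr over warmup_steps, then hold it flat;
    # total_steps (100000 above) would then only cap how long training runs.
    def lr_lambda(step):
        return min((step + 1) / warmup_steps, 1.0)
    return LambdaLR(optimizer, lr_lambda)

# Usage sketch: step the scheduler once per optimizer update.
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
# scheduler = warmup_keep_scheduler(optimizer, warmup_steps=1000)
```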
61 changes: 39 additions & 22 deletions scripts/finetune_asr_vicuna.sh
@@ -1,7 +1,8 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export PYTHONPATH=/root/fairseq:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=2,3,4,5
export CUDA_VISIBLE_DEVICES=0,1
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

@@ -12,19 +13,28 @@ export OMP_NUM_THREADS=1

cd /root/SLAM-LLM

# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
# speech_encoder_path=//nfs/maziyang.mzy/models/Whisper/small.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# lm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5

output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113
output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-prompt-padding30-20240125-test

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python src/llama_recipes/pipeline/finetune.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$output_dir \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
@@ -41,12 +51,16 @@ python src/llama_recipes/pipeline/finetune.py \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=100 \
++train_config.warmup_steps=1000 \
++train_config.total_steps=100000 \
++train_config.lr=1e-4 \
++train_config.validation_interval=1000 \
++train_config.batch_size_training=4 \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=4 \
++train_config.lr=1e-4 \
++train_config.output_dir=$output_dir \
++train_config.use_peft=true \
++train_config.peft_config.peft_method=lora \
++metric=acc \
#++log_config.log_file=/$output_dir/train.log \
@@ -64,49 +78,52 @@ python src/llama_recipes/pipeline/finetune.py \
else
torchrun \
--nnodes 1 \
--nproc_per_node 4 \
--master_port=29502 \
--nproc_per_node 2 \
src/llama_recipes/pipeline/finetune.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$output_dir \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=whisper \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector=q-former \
++model_config.encoder_projector_ds_rate=5 \
++dataset_config.dataset=speech_dataset \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++dataset_config.fix_length_audio=64 \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=100 \
++train_config.warmup_steps=1000 \
++train_config.total_steps=100000 \
++train_config.lr=1e-4 \
++train_config.validation_interval=1000 \
++train_config.batch_size_training=4 \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=4 \
++train_config.lr=1e-4 \
++train_config.output_dir=$output_dir \
++train_config.peft_config.peft_method=lora \
++train_config.enable_fsdp=true \
++train_config.enable_ddp=false \
++train_config.enable_fsdp=false \
++train_config.enable_ddp=true \
++train_config.use_fp16=true \
++metric=acc \
#++log_config.log_file=/$output_dir/train.log \
#++log_config.use_wandb=true \
#++log_config.wandb_dir=$output_dir \
#++log_config.wandb_entity_name=zym22 \
#++log_config.wandb_project_name=slam-llm \
#++log_config.wandb_exp_name=test \
#++log_config.log_interval 5 \

++log_config.log_file=/$output_dir/train.log \
++log_config.use_wandb=true \
++log_config.wandb_dir=$output_dir \
++log_config.wandb_entity_name=zym22 \
++log_config.wandb_project_name=slam-llm \
++log_config.wandb_exp_name=${0##*/%.*} \
++log_config.log_interval=5 \
# ++train_config.use_peft=true \
# ++train_config.peft_config.peft_method=lora \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \
# --use_peft --peft_method lora \
# --master_port=29501 \
fi

# {"key": "1001-134707-0000_ASR", "prompt": "<ASR>", "source": "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0000.wav", "target": "1 little recks the laborer. How near his work is holding him to God, The loving laborer through space and time, after all, not to create, only or found only.", "target_len": 157, "source_len": 1581, "text-type": "Transcribe", "audio_language": "en", "text_language": "en", "task-type": "<ASR>"}
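
With this change the entry point is Hydra-driven (--config-path/--config-name plus ++key=value overrides of the YAML above), and the multi-GPU branch switches from FSDP to plain DDP (enable_fsdp=false, enable_ddp=true) under torchrun. A minimal sketch of what such an entry point can look like; build_model and train are hypothetical helpers, and the real finetune.py may wire things differently:

```python
# Hedged sketch of a Hydra-driven DDP entry point; config fields mirror the
# ++train_config.* overrides in the script above, helper names are assumptions.
import os
import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig
from torch.nn.parallel import DistributedDataParallel as DDP

@hydra.main(version_base=None, config_path=None, config_name=None)
def main(cfg: DictConfig):  # config path/name come from --config-path/--config-name
    train_cfg = cfg.train_config
    local_rank = 0
    if train_cfg.enable_ddp:
        # torchrun exports RANK / LOCAL_RANK / WORLD_SIZE for every process.
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
    model = build_model(cfg)  # hypothetical helper
    model = model.cuda()
    if train_cfg.enable_ddp:
        # Frozen encoder/LLM parameters make find_unused_parameters a common need.
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
    train(model, cfg)  # hypothetical helper

if __name__ == "__main__":
    main()
```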
26 changes: 16 additions & 10 deletions scripts/inference_asr_batch.sh
@@ -6,27 +6,31 @@ export TOKENIZERS_PARALLELISM=false

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/tiny.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/small.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4

output_dir=/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115
ckpt_path=$output_dir/asr/2
output_dir=/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124
ckpt_path=$output_dir/asr/4
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_clean_beam4_repetition_penalty1

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/inference_batch.py \
--model_name asr \
--freeze_encoder \
--llm_name tinyllama-1.1b-chat-v0.4 \
--llm_name Llama-2-7b-chat-hf \
--llm_path $llm_path \
--llm_dim 2048 \
--llm_dim 4096 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
@@ -44,5 +48,7 @@ python src/llama_recipes/pipeline/inference_batch.py \
--ckpt_path $ckpt_path/model.pt \
--decode_log $decode_log \
--freeze_llm \
--freeze_encoder \
# --speech_dataset.prompt "Transcribe speech to text." \
# --peft_ckpt $peft_ckpt \
# --use_peft --peft_method lora \
30 changes: 27 additions & 3 deletions src/llama_recipes/datasets/speech_dataset.py
@@ -29,6 +29,19 @@ def __init__(self,

# self.data_list = contents
self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
self.prompt = dataset_config.get("prompt", None)
# self.prompt_library = [
# "Begin by converting the spoken words into written text. ",
# "Can you transcribe the speech into a written format? ",
# "Focus on translating the audible content into text. ",
# "Transcribe the speech by carefully listening to it. ",
# "Would you kindly write down the content of the speech? ",
# "Analyze the speech and create a written transcription. ",
# "Engage with the speech to produce a text-based version. ",
# "Can you document the speech in written form? ",
# "Transform the spoken words into text accurately. ",
# "How about putting the speech's content into writing? "
# ]
self.prompt_template = "USER: {}\n ASSISTANT:"
self.answer_template = "{}"
self.fix_length_audio = dataset_config.get("fix_length_audio", -1)
@@ -45,7 +58,15 @@ def __init__(self,
data_dict = json.loads(line.strip())
self.data_list.append(data_dict)


# # debug
# with open(dataset_config.train_data_path, encoding='utf-8') as fin:
# for line in fin:
# data_dict = json.loads(line.strip())
# self.data_list.append(data_dict)
# if split == "train":
# self.data_list = self.data_list[:80]
# else:
# self.data_list = self.data_list[80:100]

def get_source_len(self, data_dict):
return data_dict["source_len"]
@@ -68,8 +89,11 @@ def __getitem__(self, index):
# audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30]
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0)

prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "

prompt = self.prompt
if prompt is None:
# prompt = random.choice(self.prompt_library)
# prompt = "Transcribe speech to text. "
prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
prompt = self.prompt_template.format(prompt)
answer = self.answer_template.format(target)

18 changes: 17 additions & 1 deletion src/llama_recipes/datasets/speech_dataset_inference.py
@@ -29,6 +29,19 @@ def __init__(self,

# self.data_list = contents
self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
self.prompt = dataset_config.prompt
# self.prompt_library = [
# "Begin by converting the spoken words into written text. ",
# "Can you transcribe the speech into a written format? ",
# "Focus on translating the audible content into text. ",
# "Transcribe the speech by carefully listening to it. ",
# "Would you kindly write down the content of the speech? ",
# "Analyze the speech and create a written transcription. ",
# "Engage with the speech to produce a text-based version. ",
# "Can you document the speech in written form? ",
# "Transform the spoken words into text accurately. ",
# "How about putting the speech's content into writing? "
# ]
self.prompt_template = "USER: {}\n ASSISTANT:"
self.fix_length_audio = dataset_config.fix_length_audio

@@ -72,7 +85,10 @@ def __getitem__(self, index):
# audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30]
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0)

prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
prompt = self.prompt
if prompt is None:
# prompt = random.choice(self.prompt_library)
prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "

prompt = self.prompt_template.format(prompt)
prompt_ids = self.tokenizer.encode(prompt)
5 changes: 4 additions & 1 deletion src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -169,7 +169,10 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1))
os.makedirs(save_dir, exist_ok=True)
if not cfg.freeze_llm:
model.llm.save_pretrained(save_dir)
if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP
model.module.llm.save_pretrained(save_dir)
else:
model.llm.save_pretrained(save_dir)
logger.info(f"llm saved at {save_dir}")

save_full_path = os.path.join(save_dir, "model.pt")
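
The hasattr(model, "module") check unwraps the DDP-wrapped model before save_pretrained is called; the same pattern is often factored into a helper (the helper below is a sketch, not part of the repo):

```python
import torch.nn as nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # DDP (and DataParallel) keep the original network under .module;
    # checkpointing should always go through the underlying model.
    return model.module if hasattr(model, "module") else model

# e.g. unwrap_model(model).llm.save_pretrained(save_dir)
```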
15 changes: 9 additions & 6 deletions src/llama_recipes/models/slam_model.py
@@ -25,14 +25,15 @@ def setup_model(tokenizer, train_config, model_config, **kwargs):

def setup_tokenizer(train_config, model_config, **kwargs):
# Load the tokenizer and add special tokens
if "llama" in model_config.llm_name or "vicuna" in model_config.llm_name:
tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer


def setup_encoder(train_config, model_config, **kwargs):
encoder_list = model_config.encoder_name.split(",")
encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else []
if len(encoder_list) == 0:
return None
if len(encoder_list) == 1:
encoder_name = encoder_list[0]
if encoder_name == "whisper" or encoder_name == "qwen-audio":
@@ -198,6 +199,8 @@ def forward(self,
encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim
if self.model_config.encoder_name == "moco_wav2vec2":
encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audio_mask, visual, vis_len) ,maskw2v) # bs*seq*dim
if self.encoder is None:
encoder_outs = audio_mel if audio_mel is not None else audio

if self.model_config.encoder_projector == "q-former":
encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
@@ -309,7 +312,7 @@ def inference(
negative_prompt_ids = None,
negative_prompt_attention_mask = None,
**kwargs,
): # TODO: Now you need to set your customized sampling rate manually
):

device = kwargs.get("device", "cuda")
if os.path.exists(wav_path): # Audio-Text QA
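
The new q-former choice for model_config.encoder_projector (paired with fix_length_audio=64 in the torchrun branch of the finetune script) replaces the linear projector with learned query tokens that cross-attend to the Whisper features before being projected to the LLM dimension, so the LLM always receives a fixed number of audio tokens. A simplified sketch assuming a single vanilla cross-attention block; the repo's actual Q-Former is likely deeper (BLIP-2 style):

```python
import torch
import torch.nn as nn

class QFormerProjectorSketch(nn.Module):
    """Simplified stand-in for a Q-Former projector: a fixed set of learned
    queries cross-attends to encoder features and is mapped to the LLM dim."""

    def __init__(self, encoder_dim=1280, llm_dim=4096, num_queries=64, num_heads=8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, num_queries, encoder_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(encoder_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(encoder_dim)
        self.proj = nn.Linear(encoder_dim, llm_dim)

    def forward(self, encoder_outs, encoder_mask=None):
        # encoder_outs: (batch, seq, encoder_dim); encoder_mask: (batch, seq),
        # True for valid frames (mirrors audio_mel_post_mask in the diff above).
        q = self.queries.expand(encoder_outs.size(0), -1, -1)
        key_padding_mask = None if encoder_mask is None else ~encoder_mask.bool()
        out, _ = self.cross_attn(q, encoder_outs, encoder_outs,
                                 key_padding_mask=key_padding_mask)
        return self.proj(self.norm(out))  # (batch, num_queries, llm_dim)
```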
(Diffs for the remaining 5 changed files are not shown.)
