From 5a86bfba5a600d1ca00ca99d340de1b17006fb4e Mon Sep 17 00:00:00 2001 From: ddlBoJack Date: Tue, 16 Jan 2024 09:59:28 +0800 Subject: [PATCH] update llama to auto for hf model --- scripts/compute_wer.sh | 4 +- scripts/finetune_asr_tinyllama.sh | 100 +++++++++++++++++++++++++ scripts/inference_asr.sh | 24 +++--- scripts/inference_asr_batch.sh | 16 ++-- src/llama_recipes/models/slam_model.py | 34 ++++----- 5 files changed, 143 insertions(+), 35 deletions(-) create mode 100644 scripts/finetune_asr_tinyllama.sh diff --git a/scripts/compute_wer.sh b/scripts/compute_wer.sh index c6035eda..9bf4c0bc 100644 --- a/scripts/compute_wer.sh +++ b/scripts/compute_wer.sh @@ -1,7 +1,7 @@ #cd /root/SLAM-LLM -trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt" -preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred" +trans="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt" +preds="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred" # python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc # python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer diff --git a/scripts/finetune_asr_tinyllama.sh b/scripts/finetune_asr_tinyllama.sh new file mode 100644 index 00000000..a38a6243 --- /dev/null +++ b/scripts/finetune_asr_tinyllama.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=4,5,6,7 +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt + +llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4 + +output_dir=/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-finetune-whisper-large-v2-prompt-padding30-20240115 + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_encoder \ +--freeze_llm \ +--llm_name vicuna-13b-v1.5 \ +--llm_path $llm_path \ +--llm_dim 5120 \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_dim 1280 \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset speech_dataset \ +--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ +--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 4 \ +--val_batch_size 4 \ +--num_workers_dataloader 4 \ +--lr 1e-4 \ +--output_dir $output_dir \ +--metric acc \ +# --log_file $output_dir/test.log \ +# --use_wandb \ +# --wandb_dir $output_dir \ +# --wandb_entity_name zym22 \ +# --wandb_project_name slam-llm \ +# --wandb_exp_name test \ +# --log_interval 5 \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \ +# --use_peft --peft_method lora \ + +else +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +--master_port=29501 \ +src/llama_recipes/pipeline/finetune.py \ +--model_name asr \ +--freeze_llm \ +--use_fp16 \ +--enable_fsdp \ +--llm_name tinyllama-1.1b-chat-v0.4 \ +--llm_path $llm_path \ +--llm_dim 2048 \ +--encoder_name whisper \ +--encoder_ds_rate 2 \ +--encoder_path $speech_encoder_path \ +--encoder_dim 1280 \ +--encoder_projector linear \ +--encoder_projector_ds_rate 5 \ +--dataset speech_dataset \ +--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \ +--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \ +--batching_strategy custom \ +--num_epochs 100 \ +--batch_size_training 4 \ +--val_batch_size 4 \ +--num_workers_dataloader 4 \ +--lr 1e-4 \ +--output_dir $output_dir \ +--metric acc \ +--log_file /$output_dir/train.log \ +--use_wandb \ +--wandb_dir $output_dir \ +--wandb_entity_name zym22 \ +--wandb_project_name slam-llm \ +--wandb_exp_name test \ +--log_interval 5 \ +# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \ +# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \ +# --use_peft --peft_method lora \ +# --freeze_encoder \ +fi \ No newline at end of file diff --git a/scripts/inference_asr.sh b/scripts/inference_asr.sh index da7c3b32..65677ff9 100644 --- a/scripts/inference_asr.sh +++ b/scripts/inference_asr.sh @@ -1,32 +1,38 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false # export CUDA_LAUNCH_BLOCKING=1 cd /root/SLAM-LLM speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt # speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt -llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5 -ckpt_path=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1/model.pt -peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1 + +# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106 +ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106/asr/2/model.pt +# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1 # -m debugpy --listen 5678 --wait-for-client -python src/llama_recipes/pipeline/inference.py \ +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/inference.py \ --model_name asr \ --freeze_encoder \ ---llm_name llama-2-7b-hf \ +--freeze_llm \ +--llm_name vicuna-7b-v1.5 \ --llm_path $llm_path \ +--llm_dim 4096 \ --encoder_name whisper \ --encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ +--encoder_dim 1280 \ --encoder_projector linear \ --encoder_projector_ds_rate 5 \ --output_dir $output_dir \ --ckpt_path $ckpt_path \ --wav_path "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \ --prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \ ---peft_ckpt $peft_ckpt \ -# --use_peft --peft_method lora \ -# --freeze_llm \ \ No newline at end of file +# --peft_ckpt $peft_ckpt \ +# --use_peft --peft_method lora \ \ No newline at end of file diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh index 2a0efc8f..32fd8811 100644 --- a/scripts/inference_asr_batch.sh +++ b/scripts/inference_asr_batch.sh @@ -1,18 +1,20 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=7 +export CUDA_VISIBLE_DEVICES=1 +export TOKENIZERS_PARALLELISM=false # export CUDA_LAUNCH_BLOCKING=1 cd /root/SLAM-LLM -# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt -speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt +speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt +# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf # llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf -llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4 -output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113 +output_dir=/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115 ckpt_path=$output_dir/asr/2 # peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4 val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl @@ -22,9 +24,9 @@ decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1 python src/llama_recipes/pipeline/inference_batch.py \ --model_name asr \ --freeze_encoder \ ---llm_name vicuna-7b-v1.5 \ +--llm_name tinyllama-1.1b-chat-v0.4 \ --llm_path $llm_path \ ---llm_dim 4096 \ +--llm_dim 2048 \ --encoder_name whisper \ --encoder_ds_rate 2 \ --encoder_path $speech_encoder_path \ diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py index bf661fb5..ad3d1e8a 100644 --- a/src/llama_recipes/models/slam_model.py +++ b/src/llama_recipes/models/slam_model.py @@ -6,12 +6,8 @@ import torch.nn.functional as F import torch.distributed as dist from typing import List, Optional, Tuple, Union +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training -from transformers import ( - LlamaForCausalLM, - LlamaTokenizer, - LlamaConfig, -) from llama_recipes.utils.config_utils import generate_peft_config from llama_recipes.utils.train_utils import print_module_size @@ -30,7 +26,7 @@ def setup_model(tokenizer, train_config, model_config, **kwargs): def setup_tokenizer(train_config, model_config, **kwargs): # Load the tokenizer and add special tokens if "llama" in model_config.llm_name or "vicuna" in model_config.llm_name: - tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path) + tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path) tokenizer.pad_token_id = tokenizer.eos_token_id return tokenizer @@ -75,20 +71,20 @@ def setup_llm(train_config, model_config, **kwargs): # "please install latest nightly.") rank = int(os.environ["RANK"]) if rank == 0: - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_config.llm_path, load_in_8bit=True if train_config.quantization else None, device_map="auto" if train_config.quantization else None, use_cache=use_cache, ) else: - llama_config = LlamaConfig.from_pretrained(model_config.llm_path) + llama_config = AutoConfig.from_pretrained(model_config.llm_path) llama_config.use_cache = use_cache # with torch.device("meta"): - model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta` + model = AutoModelForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta` else: - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_config.llm_path, load_in_8bit=True if train_config.quantization else None, device_map="auto" if train_config.quantization else None, @@ -282,6 +278,7 @@ def generate(self, model_outputs = self.llm.generate( inputs_embeds=inputs_embeds, max_length=kwargs.get("max_length", 200), + max_new_tokens=kwargs.get("max_new_tokens", 200), num_beams=kwargs.get("num_beams", 4), do_sample=kwargs.get("do_sample", False), min_length=kwargs.get("min_length", 1), @@ -315,13 +312,16 @@ def inference( ): # TODO: Now you need to set your customized sampling rate manually device = kwargs.get("device", "cuda") - assert os.path.exists(wav_path) - audio_raw = whisper.load_audio(wav_path) - # audio_raw = whisper.pad_or_trim(audio_raw) - audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1,0)[None, :, :].to(device) - - encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) - encoder_outs = self.encoder_projector(encoder_outs) + if os.path.exists(wav_path): # Audio-Text QA + import whisper + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1,0)[None, :, :].to(device) + + encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty(1, 0, self.llm.model.embed_tokens.embedding_dim).to(device) prompt = "USER: {}\n ASSISTANT:".format(prompt) prompt_ids = self.tokenizer.encode(prompt)