Merge pull request #29 from ddlBoJack/dev-mzy
update llama to auto for hf model
ddlBoJack authored Jan 16, 2024
2 parents a4d5e3f + 5a86bfb commit e0bdc00
Showing 5 changed files with 143 additions and 35 deletions.
4 changes: 2 additions & 2 deletions scripts/compute_wer.sh
@@ -1,7 +1,7 @@
#cd /root/SLAM-LLM

trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-padding30-20240113/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred"
trans="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115/asr/2/decode_log_test_other_beam4_repetition_penalty1_pred"

# python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
# python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
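
For reference, the commented-out lines above normalize the hypothesis file and score it against the reference transcripts. Below is a minimal stand-in sketch of that scoring step; it uses the third-party jiwer package rather than the repo's preprocess_text.py/compute_wer.py, and assumes the two decode logs hold one transcript per line in the same order (the file paths are hypothetical local copies of the files named above):

import jiwer  # pip install jiwer

ref_path = "decode_log_test_other_beam4_repetition_penalty1_gt"    # hypothetical local copy
hyp_path = "decode_log_test_other_beam4_repetition_penalty1_pred"  # hypothetical local copy

with open(ref_path) as f:
    refs = [line.strip().lower() for line in f if line.strip()]
with open(hyp_path) as f:
    hyps = [line.strip().lower() for line in f if line.strip()]

# Word error rate over the whole test set
print(f"WER: {jiwer.wer(refs, hyps):.3%}")
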
100 changes: 100 additions & 0 deletions scripts/finetune_asr_tinyllama.sh
@@ -0,0 +1,100 @@
#!/bin/bash
# export PYTHONPATH=/root/whisper:$PYTHONPATH
export PYTHONPATH=/root/fairseq:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=4,5,6,7
# export CUDA_LAUNCH_BLOCKING=1
export OMP_NUM_THREADS=1

# debug setting for multiple gpus
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export TORCH_DISTRIBUTED_DEBUG=INFO

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4

output_dir=/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-finetune-whisper-large-v2-prompt-padding30-20240115

# -m debugpy --listen 5678 --wait-for-client
if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_encoder \
--freeze_llm \
--llm_name tinyllama-1.1b-chat-v0.4 \
--llm_path $llm_path \
--llm_dim 2048 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset speech_dataset \
--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--num_workers_dataloader 4 \
--lr 1e-4 \
--output_dir $output_dir \
--metric acc \
# --log_file $output_dir/test.log \
# --use_wandb \
# --wandb_dir $output_dir \
# --wandb_entity_name zym22 \
# --wandb_project_name slam-llm \
# --wandb_exp_name test \
# --log_interval 5 \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5/model.pt" \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-lora-prompt/asr/5" \
# --use_peft --peft_method lora \

else
torchrun \
--nnodes 1 \
--nproc_per_node 4 \
--master_port=29501 \
src/llama_recipes/pipeline/finetune.py \
--model_name asr \
--freeze_llm \
--use_fp16 \
--enable_fsdp \
--llm_name tinyllama-1.1b-chat-v0.4 \
--llm_path $llm_path \
--llm_dim 2048 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset speech_dataset \
--speech_dataset.train_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
--speech_dataset.val_data_path /nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
--batching_strategy custom \
--num_epochs 100 \
--batch_size_training 4 \
--val_batch_size 4 \
--num_workers_dataloader 4 \
--lr 1e-4 \
--output_dir $output_dir \
--metric acc \
--log_file $output_dir/train.log \
--use_wandb \
--wandb_dir $output_dir \
--wandb_entity_name zym22 \
--wandb_project_name slam-llm \
--wandb_exp_name test \
--log_interval 5 \
# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \
# --use_peft --peft_method lora \
# --freeze_encoder \
fi
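
Note that --llm_dim has to match the hidden size of the checkpoint that --llm_path points to (2048 for TinyLlama-1.1B, 4096 for the 7B Llama/Vicuna models, 5120 for 13B). A quick sanity check, assuming --llm_path is a standard Hugging Face checkpoint directory:

from transformers import AutoConfig

# Hidden size of the LLM checkpoint; this is the value to pass as --llm_dim.
cfg = AutoConfig.from_pretrained("/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4")
print(cfg.hidden_size)  # expected: 2048 for TinyLlama-1.1B-Chat-v0.4
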
24 changes: 15 additions & 9 deletions scripts/inference_asr.sh
@@ -1,32 +1,38 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
output_dir=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5
ckpt_path=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1/model.pt
peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1

# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5

output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106
ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106/asr/2/model.pt
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102-renew5/asr/1

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/inference.py \
python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/inference.py \
--model_name asr \
--freeze_encoder \
--llm_name llama-2-7b-hf \
--freeze_llm \
--llm_name vicuna-7b-v1.5 \
--llm_path $llm_path \
--llm_dim 4096 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--output_dir $output_dir \
--ckpt_path $ckpt_path \
--wav_path "/cpfs01/shared/Group-speech/beinian.lzr/data/open_data/librispeech_audio/audio/se_librispeech_1001-134707-0032.wav" \
--prompt "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " \
--peft_ckpt $peft_ckpt \
# --use_peft --peft_method lora \
# --freeze_llm \
# --peft_ckpt $peft_ckpt \
# --use_peft --peft_method lora \
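
The --prompt string above is only the instruction part. Judging from the inference() code in slam_model.py further down, it is wrapped in a Vicuna-style chat template before decoding, with the projected speech features supplied as input embeddings. A small illustration; the template string is taken from the diff below, everything else is illustrative:

# Assemble the text prompt the way slam_model.py's inference() does.
asr_instruction = ("Transcribe speech to text. Output the transcription directly "
                   "without redundant content. Ensure that the output is not duplicated. ")
full_prompt = "USER: {}\n ASSISTANT:".format(asr_instruction)
print(full_prompt)
# The tokenized prompt is combined with the Whisper-encoder projections
# (passed as inputs_embeds) before llm.generate() is called.
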
16 changes: 9 additions & 7 deletions scripts/inference_asr_batch.sh
@@ -1,18 +1,20 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=7
export CUDA_VISIBLE_DEVICES=1
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM

# speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4

output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-qwen-prompt-padding30-20240113
output_dir=/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115
ckpt_path=$output_dir/asr/2
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl
@@ -22,9 +24,9 @@ decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1
python src/llama_recipes/pipeline/inference_batch.py \
--model_name asr \
--freeze_encoder \
--llm_name vicuna-7b-v1.5 \
--llm_name tinyllama-1.1b-chat-v0.4 \
--llm_path $llm_path \
--llm_dim 4096 \
--llm_dim 2048 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
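
The decode_log prefix defined in this script is what scripts/compute_wer.sh (above) consumes: judging from that script, batch inference is expected to leave a reference file and a hypothesis file side by side with _gt and _pred suffixes. A small check of that convention; the suffix pairing is an assumption inferred from the two scripts, not an interface the repo documents:

import os

output_dir = ("/nfs/maziyang.mzy/exps/TinyLlama-1.1B-Chat-v0.4-finetune-asr-ds5-proj2048-"
              "lr1e-4-freeze-whisper-large-v2-prompt-padding30-20240115")
decode_log = os.path.join(output_dir, "asr", "2",
                          "decode_log_test_other_beam4_repetition_penalty1")

for suffix in ("_gt", "_pred"):          # reference / hypothesis, per compute_wer.sh
    path = decode_log + suffix
    print(path, "->", "found" if os.path.exists(path) else "missing")
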
34 changes: 17 additions & 17 deletions src/llama_recipes/models/slam_model.py
@@ -6,12 +6,8 @@
import torch.nn.functional as F
import torch.distributed as dist
from typing import List, Optional, Tuple, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
LlamaConfig,
)

from llama_recipes.utils.config_utils import generate_peft_config
from llama_recipes.utils.train_utils import print_module_size
@@ -30,7 +26,7 @@ def setup_tokenizer(train_config, model_config, **kwargs):
def setup_tokenizer(train_config, model_config, **kwargs):
# Load the tokenizer and add special tokens
if "llama" in model_config.llm_name or "vicuna" in model_config.llm_name:
tokenizer = LlamaTokenizer.from_pretrained(model_config.llm_path)
tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
return tokenizer

@@ -75,20 +71,20 @@ def setup_llm(train_config, model_config, **kwargs):
# "please install latest nightly.")
rank = int(os.environ["RANK"])
if rank == 0:
model = LlamaForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
use_cache=use_cache,
)
else:
llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
llama_config = AutoConfig.from_pretrained(model_config.llm_path)
llama_config.use_cache = use_cache
# with torch.device("meta"):
model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`
model = AutoModelForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`

else:
model = LlamaForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_config.llm_path,
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
@@ -282,6 +278,7 @@ def generate(self,
model_outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
max_length=kwargs.get("max_length", 200),
max_new_tokens=kwargs.get("max_new_tokens", 200),
num_beams=kwargs.get("num_beams", 4),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
@@ -315,13 +312,16 @@ def inference(
): # TODO: Now you need to set your customized sampling rate manually

device = kwargs.get("device", "cuda")
assert os.path.exists(wav_path)
audio_raw = whisper.load_audio(wav_path)
# audio_raw = whisper.pad_or_trim(audio_raw)
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1,0)[None, :, :].to(device)

encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))
encoder_outs = self.encoder_projector(encoder_outs)
if os.path.exists(wav_path): # Audio-Text QA
import whisper
audio_raw = whisper.load_audio(wav_path)
audio_raw = whisper.pad_or_trim(audio_raw)
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1,0)[None, :, :].to(device)

encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))
encoder_outs = self.encoder_projector(encoder_outs)
else: # Text QA
encoder_outs = torch.empty(1, 0, self.llm.model.embed_tokens.embedding_dim).to(device)

prompt = "USER: {}\n ASSISTANT:".format(prompt)
prompt_ids = self.tokenizer.encode(prompt)
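
The core of this commit is visible in the hunks above: the Llama-specific classes (LlamaForCausalLM, LlamaTokenizer, LlamaConfig) are replaced with the Hugging Face Auto classes, so that non-Llama checkpoints such as TinyLlama-1.1B-Chat resolve to the correct architecture from their config. A minimal sketch of that loading pattern, with argument values assumed rather than taken from the repo's configs:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

llm_path = "/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4"

# Tokenizer: works for Llama, Vicuna and TinyLlama checkpoints alike.
tokenizer = AutoTokenizer.from_pretrained(llm_path)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Full weights, as in the rank-0 / default branches above.
model = AutoModelForCausalLM.from_pretrained(llm_path, use_cache=True)

# Config-only skeleton for the other FSDP ranks. Note that the Auto classes
# cannot be called directly on a config object; the supported entry point is
# from_config(), unlike the direct call shown in the diff above.
config = AutoConfig.from_pretrained(llm_path)
config.use_cache = True
skeleton = AutoModelForCausalLM.from_config(config)

The inference() change above additionally allows text-only QA: when no wav file is found at wav_path, the speech features are replaced by an empty tensor of shape (1, 0, hidden_size), so only the text prompt drives generation.
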
