diff --git a/scripts/conf/asr_vicuna_lora.yaml b/scripts/conf/asr_vicuna_lora.yaml index 4c9cdb79..78d94a9c 100644 --- a/scripts/conf/asr_vicuna_lora.yaml +++ b/scripts/conf/asr_vicuna_lora.yaml @@ -82,7 +82,7 @@ train_config: freeze_encoder: false dataset_config: - dataset: "samsum_dataset" + dataset: "speech_dataset" file: "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset" train_data_path: null val_data_path: null @@ -92,6 +92,7 @@ dataset_config: max_words: null max_mel: null fix_length_audio: -1 + inference_mode: false fsdp_config: mixed_precision: true diff --git a/scripts/inference_asr_batch.sh b/scripts/inference_asr_batch.sh index cd608405..2d6d752e 100644 --- a/scripts/inference_asr_batch.sh +++ b/scripts/inference_asr_batch.sh @@ -1,6 +1,6 @@ #!/bin/bash #export PYTHONPATH=/root/whisper:$PYTHONPATH -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=0 export TOKENIZERS_PARALLELISM=false # export CUDA_LAUNCH_BLOCKING=1 @@ -16,39 +16,44 @@ speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt # llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T # llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4 # llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf -llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf -# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 -output_dir=/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124 -ckpt_path=$output_dir/asr/4 +output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240125 +ckpt_path=$output_dir/asr/2 # peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4 val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl decode_log=$ckpt_path/decode_log_test_clean_beam4_repetition_penalty1 # -m debugpy --listen 5678 --wait-for-client python src/llama_recipes/pipeline/inference_batch.py \ ---model_name asr \ ---llm_name Llama-2-7b-chat-hf \ ---llm_path $llm_path \ ---llm_dim 4096 \ ---encoder_name whisper \ ---encoder_ds_rate 2 \ ---encoder_path $speech_encoder_path \ ---encoder_dim 1280 \ ---encoder_projector linear \ ---encoder_projector_ds_rate 5 \ ---dataset speech_dataset \ ---speech_dataset.file src/llama_recipes/datasets/speech_dataset_inference.py:get_speech_dataset \ ---speech_dataset.val_data_path $val_data_path \ ---batching_strategy custom \ ---num_epochs 1 \ ---val_batch_size 4 \ ---num_workers_dataloader 4 \ ---output_dir $output_dir \ ---ckpt_path $ckpt_path/model.pt \ ---decode_log $decode_log \ ---freeze_llm \ ---freeze_encoder \ -# --speech_dataset.prompt "Transcribe speech to text." \ +--config-path "/root/SLAM-LLM/scripts/conf" \ +--config-name "asr_vicuna_lora.yaml" \ +hydra.run.dir=$ckpt_path \ +++model_config.llm_name="vicuna-7b-v1.5" \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=whisper \ +++model_config.encoder_ds_rate=2 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1280 \ +++model_config.encoder_projector=linear \ +++model_config.encoder_projector_ds_rate=5 \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.prompt="Transcribe speech to text. 
" \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.inference_mode=true \ +++train_config.model_name=asr \ +++train_config.batching_strategy=custom \ +++train_config.num_epochs=1 \ +++train_config.val_batch_size=4 \ +++train_config.num_workers_dataloader=4 \ +++train_config.output_dir=$output_dir \ +++ckpt_path=$ckpt_path/model.pt \ +++decode_log=$decode_log \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +# ++model_config.encoder_projector=q-former \ +# ++dataset_config.fix_length_audio=64 \ # --peft_ckpt $peft_ckpt \ # --use_peft --peft_method lora \ \ No newline at end of file diff --git a/src/llama_recipes/datasets/speech_dataset.py b/src/llama_recipes/datasets/speech_dataset.py index c91380c8..7d6f450a 100644 --- a/src/llama_recipes/datasets/speech_dataset.py +++ b/src/llama_recipes/datasets/speech_dataset.py @@ -45,6 +45,7 @@ def __init__(self, self.prompt_template = "USER: {}\n ASSISTANT:" self.answer_template = "{}" self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + self.inference_mode = dataset_config.get("inference_mode", False) self.data_list = [] if split == "train": @@ -83,6 +84,7 @@ def __getitem__(self, index): audio_path = data_dict.get("source") target = data_dict.get("target", None) task = data_dict.get("prompt", "ASR") + key = data_dict.get("key", None) audio_raw = whisper.load_audio(audio_path) audio_raw = whisper.pad_or_trim(audio_raw) @@ -94,11 +96,9 @@ def __getitem__(self, index): # prompt = random.choice(self.prompt_library) # prompt = "Transcribe speech to text. " prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " + prompt = self.prompt_template.format(prompt) - answer = self.answer_template.format(target) - prompt_ids = self.tokenizer.encode(prompt) - prompt_length = len(prompt_ids) audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats audio_length = audio_length // 5 # ad-hoc for 5x fc downsample @@ -107,6 +107,21 @@ def __getitem__(self, index): audio_length = self.fix_length_audio audio_pseudo = torch.full((audio_length,), -1) # placeholder + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "input_ids": example_ids, + "attention_mask": example_mask, + 'audio_mel': audio_mel, + 'audio_length': audio_length, + 'key': key, + 'target': target, + } + + answer = self.answer_template.format(target) example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 
example_ids = self.tokenizer.encode(example) # [prompt,answer] example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] @@ -153,8 +168,6 @@ def collator(self, samples): input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) for s in samples]) - labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) - for s in samples]) attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) for s in samples]) @@ -168,7 +181,23 @@ def collator(self, samples): modality_mask = torch.zeros_like(attention_mask) for line, sample in enumerate(samples): modality_mask[line, :sample['audio_length']] = 1 - + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'audio_mel': audio_mel, + 'audio_mel_post_mask': audio_mel_post_mask, + 'modality_mask': modality_mask, + 'keys': keys, + 'targets': targets + } + + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) return { 'input_ids': input_ids, 'labels': labels, diff --git a/src/llama_recipes/datasets/speech_dataset_inference.py b/src/llama_recipes/datasets/speech_dataset_inference.py deleted file mode 100644 index 8b6ecb9e..00000000 --- a/src/llama_recipes/datasets/speech_dataset_inference.py +++ /dev/null @@ -1,167 +0,0 @@ -import os.path as osp -import random -import json, yaml -import copy - -import numpy as np -from scipy import signal -import soundfile as sf - -import torch -import torchaudio -from torch.utils.data import Dataset -import whisper -from llama_recipes.utils.compute_utils import calculate_output_length_1d - - -class SpeechDatasetJsonl(torch.utils.data.Dataset): - - def __init__(self, - dataset_config, - tokenizer=None, - split='train', - ): - super().__init__() - self.dataset_config = dataset_config - self.tokenizer = tokenizer - # data_parallel_size = dist.get_world_size() - data_parallel_size = 1 - - # self.data_list = contents - self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss - self.prompt = dataset_config.prompt - # self.prompt_library = [ - # "Begin by converting the spoken words into written text. ", - # "Can you transcribe the speech into a written format? ", - # "Focus on translating the audible content into text. ", - # "Transcribe the speech by carefully listening to it. ", - # "Would you kindly write down the content of the speech? ", - # "Analyze the speech and create a written transcription. ", - # "Engage with the speech to produce a text-based version. ", - # "Can you document the speech in written form? ", - # "Transform the spoken words into text accurately. ", - # "How about putting the speech's content into writing? 
" - # ] - self.prompt_template = "USER: {}\n ASSISTANT:" - self.fix_length_audio = dataset_config.fix_length_audio - - self.data_list = [] - if split == "train": - with open(dataset_config.train_data_path, encoding='utf-8') as fin: - for line in fin: - data_dict = json.loads(line.strip()) - self.data_list.append(data_dict) - else: - with open(dataset_config.val_data_path, encoding='utf-8') as fin: - for line in fin: - data_dict = json.loads(line.strip()) - self.data_list.append(data_dict) - - # # debug - # if split == "train": - # self.data_list = contents[:80] - # else: - # self.data_list = contents[80:100] - - def get_source_len(self, data_dict): - return data_dict["source_len"] - - def get_target_len(self, data_dict): - - return data_dict["target_len"] if "target_len" in data_dict else 0 - - def __len__(self): - return len(self.data_list) - - def __getitem__(self, index): - data_dict = self.data_list[index] - audio_path = data_dict.get("source") - target = data_dict.get("target", None) - task = data_dict.get("prompt", "ASR") - key = data_dict.get("key", None) - - audio_raw = whisper.load_audio(audio_path) - audio_raw = whisper.pad_or_trim(audio_raw) - # audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30] - audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0) - - prompt = self.prompt - if prompt is None: - # prompt = random.choice(self.prompt_library) - prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " - - prompt = self.prompt_template.format(prompt) - prompt_ids = self.tokenizer.encode(prompt) - prompt_length = len(prompt_ids) - audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats - audio_length = audio_length // 5 # ad-hoc for 5x cov1d downsample - if self.fix_length_audio > 0: - audio_length = self.fix_length_audio - audio_pseudo = torch.full((audio_length,), -1) # placeholder - prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) - - example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] - example_mask = example_ids.ge(-1) # [True,True] - - return { - "input_ids": example_ids, - "attention_mask": example_mask, - 'audio_mel': audio_mel, - 'audio_length': audio_length, - 'key': key, - 'target':target - } - - def pad(self, sequence, max_length, padding_idx=0): - if isinstance(sequence, (int, list, tuple)): - if len(sequence) < max_length: - sequence = sequence + [padding_idx] * (max_length - len(sequence)) - else: - sequence = sequence[:max_length] - elif isinstance(sequence, torch.Tensor): - if len(sequence) < max_length: - sequence = torch.cat( - (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) - else: - sequence = sequence[:max_length] - else: - raise Exception("Type mismatch during padding!") - return sequence - - def collator(self, samples): - assert samples is not None - input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) - input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) - for s in samples]) - attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) - for s in samples]) - - audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) - audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) - for s in samples]) - audio_mel_post_mask = 
torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats - for line, sample in enumerate(samples): - audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 - - modality_mask = torch.zeros_like(attention_mask) - for line, sample in enumerate(samples): - modality_mask[line, :sample['audio_length']] = 1 - keys = [s['key'] for s in samples] - targets = [s['target'] for s in samples] - - return { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'audio_mel': audio_mel, - 'audio_mel_post_mask': audio_mel_post_mask, - 'modality_mask': modality_mask, - 'keys': keys, - 'targets': targets - } - - - -def get_speech_dataset(dataset_config, tokenizer, split): - dataset = SpeechDatasetJsonl(dataset_config, tokenizer, split) - - return dataset diff --git a/src/llama_recipes/pipeline/inference_batch.py b/src/llama_recipes/pipeline/inference_batch.py index 60f07160..8d6f1f2a 100644 --- a/src/llama_recipes/pipeline/inference_batch.py +++ b/src/llama_recipes/pipeline/inference_batch.py @@ -9,7 +9,7 @@ # from llama_recipes.configs import train_config as TRAIN_CONFIG # from llama_recipes.configs import model_config as MODEL_CONFIG # from llama_recipes.configs import log_config as LOG_CONFIG -from llama_recipes.utils.config_utils import generate_dataset_config + from llama_recipes.pipeline.model_factory import model_factory from llama_recipes.utils.dataset_utils import get_preprocessed_dataset import os @@ -53,6 +53,13 @@ def main(kwargs: DictConfig): kwargs.model_config, \ kwargs.log_config, \ kwargs.dataset_config + + del kwargs.train_config + del kwargs.fsdp_config + del kwargs.model_config + del kwargs.log_config + del kwargs.dataset_config + # Set log if not os.path.exists(os.path.dirname(log_config.log_file)): os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True) @@ -92,7 +99,7 @@ def main(kwargs: DictConfig): model.to(device) model.eval() - dataset_config = generate_dataset_config(train_config, kwargs) + # dataset_config = generate_dataset_config(train_config, kwargs) logger.info("dataset_config: {}".format(dataset_config)) dataset_test = get_preprocessed_dataset( tokenizer, @@ -119,7 +126,7 @@ def main(kwargs: DictConfig): with open(pred_path, "w") as pred, open(gt_path, "w") as gt: for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)): for key in batch.keys(): - batch[key] = batch[key].to(device) if key not in ["keys", "targets"] else batch[key] + batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key] model_outputs = model.generate(**batch) output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) for key, text, target in zip(batch["keys"], output_text, batch["targets"]):
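
Usage note (not part of the patch): with this change, decoding goes through the same SpeechDatasetJsonl as training. Setting ++dataset_config.inference_mode=true makes __getitem__ return [audio placeholder, prompt] input_ids together with the utterance key and reference target, and makes collator() return "keys"/"targets" string lists instead of "labels"; that is what allows speech_dataset_inference.py to be deleted. The sketch below is only an illustration of that batch contract and of the tensor-only device move now used in inference_batch.py; the shapes and values are synthetic placeholders, not real Whisper features, and model.generate is assumed to accept the collated keyword arguments as in the pipeline above.

    import torch

    def to_device(batch: dict, device: torch.device) -> dict:
        # Mirror the inference loop: move only torch.Tensor values; leave the
        # "keys"/"targets" string lists untouched on the host side.
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    # Synthetic stand-in for one collated inference-mode batch (batch size 2).
    batch = {
        "input_ids": torch.zeros(2, 16, dtype=torch.int64),     # [audio pseudo tokens, prompt ids], padded
        "attention_mask": torch.ones(2, 16, dtype=torch.bool),
        "audio_mel": torch.zeros(2, 3000, 80),                  # padded log-mel features
        "audio_mel_post_mask": torch.ones(2, 1500),             # mask after 2x mel-to-feature downsample
        "modality_mask": torch.zeros(2, 16, dtype=torch.bool),  # marks the audio placeholder positions
        "keys": ["utt-0001", "utt-0002"],                       # from the "key" field of the jsonl
        "targets": ["hello world", "good morning"],             # reference transcripts for the gt file
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = to_device(batch, device)
    # model.generate(**batch) would run here; the decoded hypotheses are then written
    # alongside batch["targets"], matching the pred/gt files emitted by inference_batch.py.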