Skip to content

Commit

Permalink
merge inference dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ddlBoJack committed Jan 26, 2024
1 parent de6d84e commit 4e96ea6
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 205 deletions.
3 changes: 2 additions & 1 deletion scripts/conf/asr_vicuna_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ train_config:
freeze_encoder: false

dataset_config:
dataset: "samsum_dataset"
dataset: "speech_dataset"
file: "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset"
train_data_path: null
val_data_path: null
Expand All @@ -92,6 +92,7 @@ dataset_config:
max_words: null
max_mel: null
fix_length_audio: -1
inference_mode: false

fsdp_config:
mixed_precision: true
Expand Down
61 changes: 33 additions & 28 deletions scripts/inference_asr_batch.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

Expand All @@ -16,39 +16,44 @@ speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5

output_dir=/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplr-whisper-largev2-prompt-lowergt-padding30-20240124
ckpt_path=$output_dir/asr/4
output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240125
ckpt_path=$output_dir/asr/2
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_clean_beam4_repetition_penalty1

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/inference_batch.py \
--model_name asr \
--llm_name Llama-2-7b-chat-hf \
--llm_path $llm_path \
--llm_dim 4096 \
--encoder_name whisper \
--encoder_ds_rate 2 \
--encoder_path $speech_encoder_path \
--encoder_dim 1280 \
--encoder_projector linear \
--encoder_projector_ds_rate 5 \
--dataset speech_dataset \
--speech_dataset.file src/llama_recipes/datasets/speech_dataset_inference.py:get_speech_dataset \
--speech_dataset.val_data_path $val_data_path \
--batching_strategy custom \
--num_epochs 1 \
--val_batch_size 4 \
--num_workers_dataloader 4 \
--output_dir $output_dir \
--ckpt_path $ckpt_path/model.pt \
--decode_log $decode_log \
--freeze_llm \
--freeze_encoder \
# --speech_dataset.prompt "Transcribe speech to text." \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$ckpt_path \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=whisper \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=5 \
++dataset_config.dataset=speech_dataset \
++dataset_config.prompt="Transcribe speech to text. " \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.inference_mode=true \
++train_config.model_name=asr \
++train_config.batching_strategy=custom \
++train_config.num_epochs=1 \
++train_config.val_batch_size=4 \
++train_config.num_workers_dataloader=4 \
++train_config.output_dir=$output_dir \
++ckpt_path=$ckpt_path/model.pt \
++decode_log=$decode_log \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
# ++model_config.encoder_projector=q-former \
# ++dataset_config.fix_length_audio=64 \
# --peft_ckpt $peft_ckpt \
# --use_peft --peft_method lora \
41 changes: 35 additions & 6 deletions src/llama_recipes/datasets/speech_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self,
self.prompt_template = "USER: {}\n ASSISTANT:"
self.answer_template = "{}"
self.fix_length_audio = dataset_config.get("fix_length_audio", -1)
self.inference_mode = dataset_config.get("inference_mode", False)

self.data_list = []
if split == "train":
Expand Down Expand Up @@ -83,6 +84,7 @@ def __getitem__(self, index):
audio_path = data_dict.get("source")
target = data_dict.get("target", None)
task = data_dict.get("prompt", "ASR")
key = data_dict.get("key", None)

audio_raw = whisper.load_audio(audio_path)
audio_raw = whisper.pad_or_trim(audio_raw)
Expand All @@ -94,11 +96,9 @@ def __getitem__(self, index):
# prompt = random.choice(self.prompt_library)
# prompt = "Transcribe speech to text. "
prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "

prompt = self.prompt_template.format(prompt)
answer = self.answer_template.format(target)

prompt_ids = self.tokenizer.encode(prompt)

prompt_length = len(prompt_ids)
audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
audio_length = audio_length // 5 # ad-hoc for 5x fc downsample
Expand All @@ -107,6 +107,21 @@ def __getitem__(self, index):
audio_length = self.fix_length_audio
audio_pseudo = torch.full((audio_length,), -1) # placeholder

if self.inference_mode:
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt]
example_mask = example_ids.ge(-1) # [True,True]

return {
"input_ids": example_ids,
"attention_mask": example_mask,
'audio_mel': audio_mel,
'audio_length': audio_length,
'key': key,
'target': target,
}

answer = self.answer_template.format(target)
example = prompt + answer # FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example) # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos]
Expand Down Expand Up @@ -153,8 +168,6 @@ def collator(self, samples):
input_ids_max_length = max([s['input_ids'].shape[0] for s in samples])
input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id)
for s in samples])
labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
for s in samples])
attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
for s in samples])

Expand All @@ -168,7 +181,23 @@ def collator(self, samples):
modality_mask = torch.zeros_like(attention_mask)
for line, sample in enumerate(samples):
modality_mask[line, :sample['audio_length']] = 1


if self.inference_mode:
keys = [s['key'] for s in samples]
targets = [s['target'] for s in samples]

return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'audio_mel': audio_mel,
'audio_mel_post_mask': audio_mel_post_mask,
'modality_mask': modality_mask,
'keys': keys,
'targets': targets
}

labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
for s in samples])
return {
'input_ids': input_ids,
'labels': labels,
Expand Down
167 changes: 0 additions & 167 deletions src/llama_recipes/datasets/speech_dataset_inference.py

This file was deleted.

13 changes: 10 additions & 3 deletions src/llama_recipes/pipeline/inference_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# from llama_recipes.configs import train_config as TRAIN_CONFIG
# from llama_recipes.configs import model_config as MODEL_CONFIG
# from llama_recipes.configs import log_config as LOG_CONFIG
from llama_recipes.utils.config_utils import generate_dataset_config

from llama_recipes.pipeline.model_factory import model_factory
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
import os
Expand Down Expand Up @@ -53,6 +53,13 @@ def main(kwargs: DictConfig):
kwargs.model_config, \
kwargs.log_config, \
kwargs.dataset_config

del kwargs.train_config
del kwargs.fsdp_config
del kwargs.model_config
del kwargs.log_config
del kwargs.dataset_config

# Set log
if not os.path.exists(os.path.dirname(log_config.log_file)):
os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True)
Expand Down Expand Up @@ -92,7 +99,7 @@ def main(kwargs: DictConfig):
model.to(device)
model.eval()

dataset_config = generate_dataset_config(train_config, kwargs)
# dataset_config = generate_dataset_config(train_config, kwargs)
logger.info("dataset_config: {}".format(dataset_config))
dataset_test = get_preprocessed_dataset(
tokenizer,
Expand All @@ -119,7 +126,7 @@ def main(kwargs: DictConfig):
with open(pred_path, "w") as pred, open(gt_path, "w") as gt:
for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
for key in batch.keys():
batch[key] = batch[key].to(device) if key not in ["keys", "targets"] else batch[key]
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
model_outputs = model.generate(**batch)
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
for key, text, target in zip(batch["keys"], output_text, batch["targets"]):
Expand Down

0 comments on commit 4e96ea6

Please sign in to comment.