Merge pull request #38 from ddlBoJack/dev-mzy
merge raw dataset and mel dataset; support wavlm
ddlBoJack authored Jan 27, 2024
2 parents 25f5eea + f6e6468 commit 87fc616
Showing 9 changed files with 1,678 additions and 57 deletions.
4 changes: 2 additions & 2 deletions scripts/compute_wer.sh
@@ -1,7 +1,7 @@
#cd /root/SLAM-LLM

trans="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-qformer64-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"
trans="/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_gt"
preds="/nfs/maziyang.mzy/exps/Llama-2-7b-chat-hf-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-padding30-20240126/asr/3/decode_log_test_clean_beam4_repetition_penalty1_pred"

# python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc
# python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer
1 change: 1 addition & 0 deletions scripts/finetune_asr_vicuna.sh
@@ -95,6 +95,7 @@ hydra.run.dir=$output_dir \
++dataset_config.dataset=speech_dataset \
++dataset_config.train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl \
++dataset_config.val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other_filtered.jsonl \
++dataset_config.input_type=raw \
++train_config.model_name=asr \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
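The single override added above selects which branch of the merged speech dataset is used. As a minimal sketch (not the repository code, with a hypothetical resolved config), the flag is consumed roughly like this, mirroring the assertion added in speech_dataset.py below:

# Minimal sketch of how the new ++dataset_config.input_type override is consumed.
dataset_config = {"input_type": "raw", "normalize": False}  # hypothetical resolved config

input_type = dataset_config.get("input_type", None)
assert input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"
# "raw" -> padded waveforms are fed to waveform encoders such as WavLM
# "mel" -> Whisper log-mel features are computed in the dataset, as before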
19 changes: 10 additions & 9 deletions scripts/inference_asr_batch.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=3
export CUDA_VISIBLE_DEVICES=7
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

@@ -10,29 +10,30 @@ cd /root/SLAM-LLM
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/base.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/small.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/medium.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt
# speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2.pt
speech_encoder_path=/nfs/maziyang.mzy/models/Whisper/large-v2-qwen.pt

# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-intermediate-step-1431k-3T
# llm_path=/nfs/maziyang.mzy/models/TinyLlama-1.1B-Chat-v0.4
llm_path=/nfs/maziyang.mzy/models/phi-2
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
# llm_path=/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
# llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5

output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-steplrwarmupkeep1e-4-whisper-largev2-promptshort-lowergt-padding30-20240125
output_dir=/nfs/maziyang.mzy/exps/phi-2-finetune-asr-ds5-proj2048-steplrwarmup1e-4keep-whisper-largev2-promptshort-lowergt-padding30-20240126
ckpt_path=$output_dir/asr/2
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_clean_beam4_repetition_penalty1
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_other_beam4_repetition_penalty1

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/inference_batch.py \
--config-path "/root/SLAM-LLM/scripts/conf" \
--config-name "asr_vicuna_lora.yaml" \
hydra.run.dir=$ckpt_path \
++model_config.llm_name="vicuna-7b-v1.5" \
++model_config.llm_name="phi-2" \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.llm_dim=2560 \
++model_config.encoder_name=whisper \
++model_config.encoder_ds_rate=2 \
++model_config.encoder_path=$speech_encoder_path \
103 changes: 65 additions & 38 deletions src/llama_recipes/datasets/speech_dataset.py
@@ -46,6 +46,9 @@ def __init__(self,
self.answer_template = "{}"
self.fix_length_audio = dataset_config.get("fix_length_audio", -1)
self.inference_mode = dataset_config.get("inference_mode", False)
self.normalize = dataset_config.get("normalize", False)
self.input_type = dataset_config.get("input_type", None)
assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"

self.data_list = []
if split == "train":
@@ -85,27 +88,33 @@ def __getitem__(self, index):
target = data_dict.get("target", None)
task = data_dict.get("prompt", "ASR")
key = data_dict.get("key", None)

audio_raw = whisper.load_audio(audio_path)
audio_raw = whisper.pad_or_trim(audio_raw)
# audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30]
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0)
if self.input_type == "raw":
audio_raw = torch.from_numpy(audio_raw)
if self.normalize:
audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape)
audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample
audio_length = audio_length // 5 # ad-hoc for 5x fc downsample
elif self.input_type == "mel":
# audio_raw = whisper.pad_or_trim(audio_raw)
# audio_raw = np.concatenate((np.zeros(random.randint(0, 16000)), audio_raw, np.zeros(random.randint(0, 16000)))).astype(audio_raw.dtype)[:16000*30]
audio_mel = whisper.log_mel_spectrogram(audio_raw).permute(1, 0)
audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
audio_length = audio_length // 5 # ad-hoc for 5x fc downsample
# audio_length = calculate_output_length_1d(audio_length, 5, 5, 0) # ad-hoc for 5x cov1d downsample
if self.fix_length_audio > 0:
audio_length = self.fix_length_audio
audio_pseudo = torch.full((audio_length,), -1) # placeholder

prompt = self.prompt
if prompt is None:
# prompt = random.choice(self.prompt_library)
# prompt = "Transcribe speech to text. "
prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "

prompt = self.prompt_template.format(prompt)
prompt_ids = self.tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
audio_length = audio_length // 5 # ad-hoc for 5x fc downsample
# audio_length = calculate_output_length_1d(audio_length, 5, 5, 0) # ad-hoc for 5x cov1d downsample
if self.fix_length_audio > 0:
audio_length = self.fix_length_audio
audio_pseudo = torch.full((audio_length,), -1) # placeholder

if self.inference_mode:
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
@@ -115,13 +124,14 @@ def __getitem__(self, index):
return {
"input_ids": example_ids,
"attention_mask": example_mask,
'audio_mel': audio_mel,
'audio_length': audio_length,
'key': key,
'target': target,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_length": audio_length,
"key": key,
"target": target,
}

answer = self.answer_template.format(target)
answer = self.answer_template.format(target.lower())
example = prompt + answer # FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example) # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos]
@@ -142,9 +152,9 @@ def __getitem__(self, index):
"input_ids": example_ids,
"labels": labels_ids,
"attention_mask": example_mask,
'audio_mel': audio_mel,
'audio_length': audio_length,

"audio": audio_raw if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_length": audio_length,
}

def pad(self, sequence, max_length, padding_idx=0):
@@ -159,6 +169,12 @@ def pad(self, sequence, max_length, padding_idx=0):
(sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, np.ndarray):
if len(sequence) < max_length:
sequence = np.concatenate(
(sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:max_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
@@ -170,13 +186,20 @@ def collator(self, samples):
for s in samples])
attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False)
for s in samples])

audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples])
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
if self.input_type == "raw":
audio_raw_max_length = max([s['audio'].shape[0] for s in samples])
audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0)
for s in samples])
audio_mask = torch.zeros(len(samples), audio_raw_max_length)
for line, sample in enumerate(samples):
audio_mask[line, :sample['audio'].shape[0]] = 1
elif self.input_type == "mel":
audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples])
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1
audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1

modality_mask = torch.zeros_like(attention_mask)
for line, sample in enumerate(samples):
@@ -187,24 +210,28 @@ def collator(self, samples):
targets = [s['target'] for s in samples]

return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'audio_mel': audio_mel,
'audio_mel_post_mask': audio_mel_post_mask,
'modality_mask': modality_mask,
'keys': keys,
'targets': targets
"input_ids": input_ids,
"attention_mask": attention_mask,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mask": audio_mask if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None,
"modality_mask": modality_mask,
"keys": keys,
"targets": targets
}

labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX)
for s in samples])
return {
'input_ids': input_ids,
'labels': labels,
'attention_mask': attention_mask,
'audio_mel': audio_mel,
'audio_mel_post_mask': audio_mel_post_mask,
'modality_mask': modality_mask
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mask": audio_mask if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None,
"modality_mask": modality_mask
}


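The length bookkeeping differs between the two branches above. A rough sketch, not the repository code, of how the per-sample placeholder lengths line up (the 160-sample mel hop is an assumption based on Whisper's defaults):

import numpy as np

def estimate_audio_length(audio_raw: np.ndarray, input_type: str) -> int:
    """Approximate number of encoder output frames reserved for the audio placeholder."""
    if input_type == "raw":
        # fairseq-style waveform encoders (e.g. WavLM) downsample 16 kHz audio by 320x,
        # then the 5x linear projector shortens the sequence again.
        return len(audio_raw) // 320 // 5
    if input_type == "mel":
        # Whisper halves the mel frame count inside its encoder (2x conv stride),
        # then the same 5x projector applies; 160 samples per mel frame is assumed.
        n_mel_frames = len(audio_raw) // 160
        return ((n_mel_frames + 1) // 2) // 5
    raise ValueError("input_type must be one of [raw, mel]")

# 10 seconds of 16 kHz audio ends up at roughly the same placeholder length either way:
print(estimate_audio_length(np.zeros(160000, dtype=np.float32), "raw"))  # 100
print(estimate_audio_length(np.zeros(160000, dtype=np.float32), "mel"))  # 100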
21 changes: 21 additions & 0 deletions src/llama_recipes/models/encoder.py
@@ -1,5 +1,6 @@
import types
import torch
import torch.nn as nn
import torch.nn.functional as F

class WhisperWrappedEncoder:
@@ -45,6 +46,26 @@ def load(cls, model_config):
return BEATs_model


class WavLMEncoder(nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model

@classmethod
def load(cls, model_config):
from .wavlm.WavLM import WavLM, WavLMConfig
checkpoint = torch.load(model_config.encoder_path)
cfg = WavLMConfig(checkpoint['cfg'])
WavLM_model = WavLM(cfg)
WavLM_model.load_state_dict(checkpoint['model'])
assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"

return cls(cfg, WavLM_model)

def extract_features(self, source, padding_mask):
return self.model.extract_features(source, padding_mask)[0]

class AVEncoder:

@classmethod
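For orientation, a hypothetical usage sketch of the new WavLMEncoder; the checkpoint path and config object are placeholders, while the load/extract_features signatures follow the class above:

import torch
from llama_recipes.models.encoder import WavLMEncoder

class DummyModelConfig:
    encoder_path = "/path/to/WavLM-Large.pt"  # placeholder checkpoint path
    normalize = True                          # must match cfg.normalize stored in the checkpoint

encoder = WavLMEncoder.load(DummyModelConfig())

wav = torch.randn(2, 16000 * 4)               # batch of two 4-second 16 kHz waveforms
padding_mask = torch.zeros_like(wav).bool()   # WavLM convention: True marks padded samples
features = encoder.extract_features(wav, padding_mask)  # (batch, frames, hidden_dim)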
5 changes: 5 additions & 0 deletions src/llama_recipes/models/slam_model.py
@@ -42,6 +42,9 @@ def setup_encoder(train_config, model_config, **kwargs):
if encoder_name == "beats":
from llama_recipes.models.encoder import BEATsEncoder
encoder = BEATsEncoder.load(model_config)
if encoder_name == "wavlm":
from llama_recipes.models.encoder import WavLMEncoder
encoder = WavLMEncoder.load(model_config)
if encoder_name == "moco_wav2vec2":
from llama_recipes.models.encoder import AVEncoder
encoder = AVEncoder.load(model_config)
@@ -197,6 +200,8 @@ def forward(self,
encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim
if self.model_config.encoder_name == "beats":
encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim
if self.model_config.encoder_name == "wavlm":
encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
if self.model_config.encoder_name == "moco_wav2vec2":
encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audio_mask, visual, vis_len) ,maskw2v) # bs*seq*dim
if self.encoder is None:
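The FIX(MZY) comment above is about mask conventions: the collator builds an attention-style audio_mask (1 = real sample), while WavLM expects a padding mask (1 = padded sample), hence the 1 - audio_mask flip. A tiny illustration:

import torch

audio_mask = torch.tensor([[1., 1., 1., 0., 0.],
                           [1., 1., 1., 1., 1.]])
wavlm_padding_mask = 1 - audio_mask
# tensor([[0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 0.]])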