Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yxdu #179

Merged
merged 3 commits into from
Nov 27, 2024
Merged

Yxdu #179

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15,529 changes: 0 additions & 15,529 deletions examples/st_covost2/covost2_zh.jsonl

This file was deleted.

53 changes: 11 additions & 42 deletions examples/st_covost2/dataset/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,12 @@ def __init__(self,
super().__init__()
self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3

rank = dist.get_rank()


data_name = "yxdu/covost2_en_x"
local_dataset_path= data_name.split("/")[-1]+"_"+split+"_cache"

if os.path.exists(local_dataset_path):
ds = load_from_disk(local_dataset_path)
print(ds)
else:
if rank==0:
ds = load_dataset(data_name, split=split)
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
print(ds)



def prepare_dataset(example):
audio_raw = whisper.pad_or_trim(example["audio"]["array"])

audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0)

example["audio_mel"] = audio_mel


return example

ds = ds.map(prepare_dataset, remove_columns="audio")

ds.save_to_disk(local_dataset_path)

dist.barrier()
if rank != 0:
if os.path.exists(local_dataset_path):
ds = load_from_disk(local_dataset_path)
else:
raise FileNotFoundError("No Dataset。")



if split=="val":
split="validation"
ds = load_dataset("yxdu/covost2_en_x",split=split)
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
print(ds)


self.ds = ds
self.tokenizer = tokenizer
Expand Down Expand Up @@ -111,8 +76,12 @@ def __getitem__(self, index):
print(target)
self.printed = True

audio_raw = whisper.pad_or_trim(data_dict["audio"]["array"])
audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0)

if self.bf16:
audio_mel = torch.tensor(data_dict["audio_mel"], dtype=torch.bfloat16)
audio_mel = audio_mel.to(torch.bfloat16)


if self.fix_length_audio > 0:
Expand Down
10 changes: 6 additions & 4 deletions examples/st_covost2/inference_asr_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ def __len__(self):
def Inference(kwargs: DictConfig):

# Update the configuration for the training and sharding process
train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \
train_config, fsdp_config, model_config, log_config, dataset_config,ckpt_path = kwargs.train_config, \
kwargs.fsdp_config, \
kwargs.model_config, \
kwargs.log_config, \
kwargs.dataset_config
kwargs.dataset_config, \
kwargs.ckpt_path

OmegaConf.set_struct(kwargs,False)
del kwargs["train_config"]
Expand Down Expand Up @@ -114,8 +115,8 @@ def Inference(kwargs: DictConfig):

config = AutoConfig.from_pretrained("Qwen/Qwen2-7B") # load the Qwen2-7B configuration
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
model = CustomSLM(config,ckpt_path="cotst/model.pt")
model = CustomSLM(config,ckpt_path=ckpt_path)
# model = AutoModel.from_pretrained("/home/yxdu/hit/SLAM-LLM/examples/st_covost2/output/step_10/test")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device.
Expand Down Expand Up @@ -143,6 +144,7 @@ def Inference(kwargs: DictConfig):
batch_size=train_config.val_batch_size,
drop_last=False,
prefetch_factor=1000,
persistent_workers=True,
collate_fn=dataset_test.collator
)

Expand Down
20 changes: 10 additions & 10 deletions examples/st_covost2/model/slm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,10 @@ def forward(self,
audio_mel = kwargs.get("audio_mel", None)
audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper


encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state # bs*seq*dim
encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)

input_ids = input_ids[:, 80:]

inputs_embeds = self.llm.model.embed_tokens(input_ids)
inputs_embeds = torch.cat((encoder_outs, inputs_embeds), dim=1)

Expand All @@ -80,14 +78,16 @@ def forward(self,


model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels,)
acc = -1
if self.metric:
with torch.no_grad():
preds = torch.argmax(input=model_outputs.logits, dim=-1)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)


return model_outputs, acc


with torch.no_grad():
preds = torch.argmax(input=model_outputs.logits, dim=-1)
acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
print(acc)

return model_outputs

# return model_outputs, acc

@torch.no_grad()
def generate(self,
Expand Down
18 changes: 7 additions & 11 deletions examples/st_covost2/scripts/all.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# export TOKENIZERS_PARALLELISM=false
export WANDB_MODE=offline
# export HYDRA_FULL_ERROR=1

export CUDA_VISIBLE_DEVICES=0,1
if command -v nvidia-smi &> /dev/null; then
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
Expand All @@ -15,7 +15,7 @@ current_dir=$(dirname "$current_script")
code_dir=$(realpath "$current_dir/../../../../")
cd ${code_dir}/SLAM-LLM

source=all
source=zh

checkpoint_dir=${code_dir}/speech/data/qwen/spt-all-7B-4
output_dir=${code_dir}/speech/data/qwen/cotst-all
Expand All @@ -24,11 +24,6 @@ encoder_path_hf=${code_dir}/speech/models/whisper-large-v3
llm_path=${code_dir}/speech/models/Qwen2-7B


# Change these paths to point to your training and validation data
train_data_path=${code_dir}/SLAM-LLM/examples/st_covost2/test_st.jsonl
val_data_path=${code_dir}/SLAM-LLM/examples/st_covost2/test_st.jsonl




max_epoch=$(ls -d ${checkpoint_dir}/asr_epoch_*_step_* | sed -n 's/.*asr_epoch_\([0-9]*\)_step_\([0-9]*\).*/\1/p' | sort -n | tail -1)
Expand All @@ -40,7 +35,7 @@ final_path="${checkpoint_dir}/asr_epoch_${max_epoch}_step_${max_step}"


ckpt_name=$final_path/model.pt

ckpt_name=/home/yxdu/hit/SLAM-LLM/cotst/model.pt
# Use the find command to search for all .pt files and pick the most recently modified one


Expand All @@ -62,7 +57,8 @@ hydra.run.dir=$output_dir \
++model_config.encoder_dim=1280 \
++model_config.encoder_projector=q-former \
++model_config.query_len=80 \
++dataset_config.dataset=st_dataset \
++dataset_config.dataset=hf_dataset \
++dataset_config.file=examples/st_covost2/dataset/hf_dataset.py:get_speech_dataset \
++dataset_config.train_data_path=$train_data_path \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.input_type=mel \
Expand All @@ -74,7 +70,7 @@ hydra.run.dir=$output_dir \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.gradient_accumulation_steps=1 \
++train_config.gradient_accumulation_steps=8 \
++train_config.warmup_steps=1000 \
++train_config.total_steps=1000000 \
++train_config.lr=1e-5 \
Expand All @@ -101,7 +97,7 @@ torchrun \
++fsdp_config.pure_bf16=true \
++log_config.use_wandb=true \
++log_config.wandb_project_name=cot \
++train_config.validation_interval=100 \
++train_config.validation_interval=10000 \
++log_config.wandb_exp_name=all \
++train_config.use_peft=false \
$hydra_args
Expand Down
3 changes: 2 additions & 1 deletion examples/st_covost2/scripts/infer_enzh.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export MASTER_ADDR=localhost
export MASTER_PORT=12345
export WANDB_MODE=offline

export CUDA_VISIBLE_DEVICES=2,3
if command -v nvidia-smi &> /dev/null; then
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
Expand Down Expand Up @@ -32,6 +32,7 @@ if [ ! -f "$ckpt_path" ]; then
echo "Download ckpt..."
git clone https://huggingface.co/yxdu/cotst
fi

echo $ckpt_path


Expand Down
Loading
Loading