Skip to content

Commit

Permalink
Split out packages needed only for evaluation (addresses #97); clean-up
Browse files: browse the repository at this point in the history
  • Loading branch information
SWivid committed Oct 15, 2024
1 parent 423fe4a commit bc63315
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ bash scripts/eval_infer_batch.sh

### Objective Evaluation

Install packages for evaluation:

```bash
pip install -r requirements_eval.txt
```

**Some Notes**

For faster-whisper with CUDA 11:
Expand Down
2 changes: 0 additions & 2 deletions gradio_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import re
import torch
import torchaudio
Expand All @@ -17,7 +16,6 @@
save_spectrogram,
)
from transformers import pipeline
import librosa
import click
import soundfile as sf

Expand Down
13 changes: 6 additions & 7 deletions model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@

import jieba
from pypinyin import lazy_pinyin, Style
import zhconv
from zhon.hanzi import punctuation
from jiwer import compute_measures

from funasr import AutoModel
from faster_whisper import WhisperModel

from model.ecapa_tdnn import ECAPA_TDNN_SMALL
from model.modules import MelSpec
Expand Down Expand Up @@ -432,6 +426,7 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path

def load_asr_model(lang, ckpt_dir = ""):
if lang == "zh":
from funasr import AutoModel
model = AutoModel(
model = os.path.join(ckpt_dir, "paraformer-zh"),
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
Expand All @@ -440,6 +435,7 @@ def load_asr_model(lang, ckpt_dir = ""):
disable_update=True,
) # following seed-tts setting
elif lang == "en":
from faster_whisper import WhisperModel
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
model = WhisperModel(model_size, device="cuda", compute_type="float16")
return model
Expand All @@ -451,17 +447,20 @@ def run_asr_wer(args):
rank, lang, test_set, ckpt_dir = args

if lang == "zh":
import zhconv
torch.cuda.set_device(rank)
elif lang == "en":
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
else:
raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")

asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)


from zhon.hanzi import punctuation
punctuation_all = punctuation + string.punctuation
wers = []

from jiwer import compute_measures
for gen_wav, prompt_wav, truth in tqdm(test_set):
if lang == "zh":
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
Expand Down
9 changes: 1 addition & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,19 @@ datasets
einops>=0.8.0
einx>=0.3.0
ema_pytorch>=0.5.2
faster_whisper
funasr
gradio
jieba
jiwer
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
# torch>=2.0
# torchaudio>=2.3.0
tomli
torchdiffeq
tqdm>=4.65.0
transformers
vocos
wandb
x_transformers>=1.31.14
zhconv
zhon
tomli
5 changes: 5 additions & 0 deletions requirements_eval.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
faster_whisper
funasr
jiwer
zhconv
zhon

0 comments on commit bc63315

Please sign in to comment.