From ca250eded33ce3c13713dca00e3181c5b9e804ae Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 20 Nov 2024 22:27:15 +0200 Subject: [PATCH 1/3] use `jiwer` instead of `evaluate` --- benchmark/evaluate_yt_commons.py | 18 ++++-------------- benchmark/requirements.benchmark.txt | 1 - benchmark/wer_benchmark.py | 11 +++-------- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/benchmark/evaluate_yt_commons.py b/benchmark/evaluate_yt_commons.py index 0511be6d..94803d79 100644 --- a/benchmark/evaluate_yt_commons.py +++ b/benchmark/evaluate_yt_commons.py @@ -5,9 +5,8 @@ from io import BytesIO from datasets import load_dataset -from evaluate import load +from jiwer import wer from pytubefix import YouTube -from torch.utils.data import DataLoader from tqdm import tqdm from transformers.models.whisper.english_normalizer import EnglishTextNormalizer @@ -39,19 +38,12 @@ def url_to_audio(row): ) args = parser.parse_args() -# define the evaluation metric -wer_metric = load("wer") - with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f: normalizer = EnglishTextNormalizer(json.load(f)) dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map( url_to_audio ) -dataset = iter( - DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2) -) - model = WhisperModel("large-v3", device="cuda") pipeline = BatchedInferencePipeline(model, device="cuda") @@ -59,7 +51,7 @@ def url_to_audio(row): all_transcriptions = [] all_references = [] # iterate over the dataset and run inference -for i, row in tqdm(enumerate(dataset), desc="Evaluating..."): +for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."): result, info = pipeline.transcribe( row["audio"][0], batch_size=8, @@ -77,7 +69,5 @@ def url_to_audio(row): all_references = [normalizer(reference) for reference in all_references] # compute the WER metric -wer = 100 * wer_metric.compute( - predictions=all_transcriptions, references=all_references -) -print("WER: %.3f" % wer) +word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references) +print("WER: %.3f" % word_error_rate) diff --git a/benchmark/requirements.benchmark.txt b/benchmark/requirements.benchmark.txt index c49dccaf..674c23ec 100644 --- a/benchmark/requirements.benchmark.txt +++ b/benchmark/requirements.benchmark.txt @@ -1,6 +1,5 @@ transformers jiwer -evaluate datasets memory_profiler py3nvml diff --git a/benchmark/wer_benchmark.py b/benchmark/wer_benchmark.py index f7a0b792..2bc1bfb3 100644 --- a/benchmark/wer_benchmark.py +++ b/benchmark/wer_benchmark.py @@ -3,7 +3,7 @@ import os from datasets import load_dataset -from evaluate import load +from jiwer import wer from tqdm import tqdm from transformers.models.whisper.english_normalizer import EnglishTextNormalizer @@ -25,9 +25,6 @@ # load the dataset with streaming mode dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True) -# define the evaluation metric -wer_metric = load("wer") - with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f: normalizer = EnglishTextNormalizer(json.load(f)) @@ -58,7 +55,5 @@ def inference(batch): all_references = [normalizer(reference) for reference in all_references] # compute the WER metric -wer = 100 * wer_metric.compute( - predictions=all_transcriptions, references=all_references -) -print("WER: %.3f" % wer) +word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references) +print("WER: %.3f" % word_error_rate) From b64bdfbde44397d06fbd7ff99d3334faad0012d1 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 20 Nov 2024 22:45:52 +0200 Subject: [PATCH 2/3] skip failing audios --- benchmark/evaluate_yt_commons.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/benchmark/evaluate_yt_commons.py b/benchmark/evaluate_yt_commons.py index 94803d79..15353da8 100644 --- a/benchmark/evaluate_yt_commons.py +++ b/benchmark/evaluate_yt_commons.py @@ -16,15 +16,19 @@ def url_to_audio(row): buffer = BytesIO() yt = YouTube(row["link"]) - video = ( - yt.streams.filter(only_audio=True, mime_type="audio/mp4") - .order_by("bitrate") - .desc() - .first() - ) - video.stream_to_buffer(buffer) - buffer.seek(0) - row["audio"] = decode_audio(buffer) + try: + video = ( + yt.streams.filter(only_audio=True, mime_type="audio/mp4") + .order_by("bitrate") + .desc() + .last() + ) + video.stream_to_buffer(buffer) + buffer.seek(0) + row["audio"] = decode_audio(buffer) + except: + print(f'Failed to download: {row["link"]}') + row["audio"] = [] return row @@ -52,6 +56,8 @@ def url_to_audio(row): all_references = [] # iterate over the dataset and run inference for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."): + if not row["audio"]: + continue result, info = pipeline.transcribe( row["audio"][0], batch_size=8, From a17ca258c91a2e5c84b58269b21dbbe1141a1674 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 20 Nov 2024 22:49:35 +0200 Subject: [PATCH 3/3] remove bare except --- benchmark/evaluate_yt_commons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark/evaluate_yt_commons.py b/benchmark/evaluate_yt_commons.py index 15353da8..cbdce4f1 100644 --- a/benchmark/evaluate_yt_commons.py +++ b/benchmark/evaluate_yt_commons.py @@ -7,6 +7,7 @@ from datasets import load_dataset from jiwer import wer from pytubefix import YouTube +from pytubefix.exceptions import VideoUnavailable from tqdm import tqdm from transformers.models.whisper.english_normalizer import EnglishTextNormalizer @@ -26,7 +27,7 @@ def url_to_audio(row): video.stream_to_buffer(buffer) buffer.seek(0) row["audio"] = decode_audio(buffer) - except: + except VideoUnavailable: print(f'Failed to download: {row["link"]}') row["audio"] = [] return row