use jiwer instead of evaluate in benchmarks #1159

Merged (3 commits) on Nov 20, 2024
43 changes: 20 additions & 23 deletions benchmark/evaluate_yt_commons.py
@@ -5,9 +5,9 @@
 from io import BytesIO
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from pytubefix import YouTube
-from torch.utils.data import DataLoader
+from pytubefix.exceptions import VideoUnavailable
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
 
@@ -17,15 +17,19 @@
 def url_to_audio(row):
     buffer = BytesIO()
     yt = YouTube(row["link"])
-    video = (
-        yt.streams.filter(only_audio=True, mime_type="audio/mp4")
-        .order_by("bitrate")
-        .desc()
-        .first()
-    )
-    video.stream_to_buffer(buffer)
-    buffer.seek(0)
-    row["audio"] = decode_audio(buffer)
+    try:
+        video = (
+            yt.streams.filter(only_audio=True, mime_type="audio/mp4")
+            .order_by("bitrate")
+            .desc()
+            .last()
+        )
+        video.stream_to_buffer(buffer)
+        buffer.seek(0)
+        row["audio"] = decode_audio(buffer)
+    except VideoUnavailable:
+        print(f'Failed to download: {row["link"]}')
+        row["audio"] = []
     return row


@@ -39,27 +43,22 @@ def url_to_audio(row):
 )
 args = parser.parse_args()
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
 dataset = load_dataset("mobiuslabsgmbh/youtube-commons-asr-eval", streaming=True).map(
     url_to_audio
 )
-dataset = iter(
-    DataLoader(dataset["test"], batch_size=1, prefetch_factor=4, num_workers=2)
-)
 
 model = WhisperModel("large-v3", device="cuda")
 pipeline = BatchedInferencePipeline(model, device="cuda")
 
-
 all_transcriptions = []
 all_references = []
 # iterate over the dataset and run inference
-for i, row in tqdm(enumerate(dataset), desc="Evaluating..."):
+for i, row in tqdm(enumerate(dataset["test"]), desc="Evaluating..."):
+    if not row["audio"]:
+        continue
     result, info = pipeline.transcribe(
         row["audio"][0],
         batch_size=8,
@@ -77,7 +76,5 @@ def url_to_audio(row):
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)
1 change: 0 additions & 1 deletion benchmark/requirements.benchmark.txt
@@ -1,6 +1,5 @@
 transformers
 jiwer
-evaluate
 datasets
 memory_profiler
 py3nvml
11 changes: 3 additions & 8 deletions benchmark/wer_benchmark.py
@@ -3,7 +3,7 @@
 import os
 
 from datasets import load_dataset
-from evaluate import load
+from jiwer import wer
 from tqdm import tqdm
 from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
 
@@ -25,9 +25,6 @@
 # load the dataset with streaming mode
 dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
 
-# define the evaluation metric
-wer_metric = load("wer")
-
 with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
     normalizer = EnglishTextNormalizer(json.load(f))
 
@@ -58,7 +55,5 @@ def inference(batch):
 all_references = [normalizer(reference) for reference in all_references]
 
 # compute the WER metric
-wer = 100 * wer_metric.compute(
-    predictions=all_transcriptions, references=all_references
-)
-print("WER: %.3f" % wer)
+word_error_rate = 100 * wer(hypothesis=all_transcriptions, reference=all_references)
+print("WER: %.3f" % word_error_rate)