From c147d82cc063e7d0adfbe38ac13f2e08a35fc661 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 28 Jun 2024 23:23:30 -0400 Subject: [PATCH] Updating faster-whisper with PRs --- .env | 4 +-- Dockerfile | 2 ++ pre_requirements.txt | 3 +- requirements.txt | 2 +- src/wordcab_transcribe/config.py | 4 +-- .../router/v1/audio_file_endpoint.py | 1 - .../services/asr_service.py | 21 ++++++++++---- .../services/transcribe_service.py | 29 +++++++++++++++++-- 8 files changed, 51 insertions(+), 15 deletions(-) diff --git a/.env b/.env index 9b566a5..84a6aba 100644 --- a/.env +++ b/.env @@ -37,10 +37,10 @@ DEBUG=True # Then in your Dockerfile, copy the converted models to the /app/src/wordcab_transcribe/whisper_models folder. # Example for WHISPER_MODEL: COPY cloned_wordcab_transcribe_repo/src/wordcab_transcribe/whisper_models/large-v3 /app/src/wordcab_transcribe/whisper_models/large-v3 # Example for ALIGN_MODEL: COPY cloned_wordcab_transcribe_repo/src/wordcab_transcribe/whisper_models/tiny /app/src/wordcab_transcribe/whisper_models/tiny -WHISPER_MODEL="large-v3" +WHISPER_MODEL="medium" # You can specify one of two engines, "faster-whisper" or "tensorrt-llm". At the moment, "faster-whisper" is more # stable, adjustable, and accurate, while "tensorrt-llm" is faster but less accurate and adjustable. -WHISPER_ENGINE="tensorrt-llm" +WHISPER_ENGINE="faster-whisper-batched" # This helps adjust some build during the conversion of the Whisper model to TensorRT. If you change this, be sure to # it in pre_requirements.txt. The only available options are "0.9.0.dev2024032600" and "0.11.0.dev2024052100". # Note that version "0.11.0.dev2024052100" is not compatible with T4 or V100 GPUs. diff --git a/Dockerfile b/Dockerfile index e930003..00f6fd9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,6 +62,8 @@ RUN curl -L ${RELEASE_URL} | tar -zx -C /tmp \ RUN python -m pip install pip --upgrade +COPY faster-whisper /app/faster-whisper + COPY pre_requirements.txt . COPY requirements.txt . diff --git a/pre_requirements.txt b/pre_requirements.txt index cc286fa..341faa8 100644 --- a/pre_requirements.txt +++ b/pre_requirements.txt @@ -10,4 +10,5 @@ tensorrt_llm==0.9.0.dev2024032600 Cython==3.0.10 youtokentome @ git+https://github.com/gburlet/YouTokenToMe.git@dependencies deepmultilingualpunctuation==1.0.1 -pyannote.audio==3.2.0 \ No newline at end of file +pyannote.audio==3.2.0 +ipython==8.24.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2d4077e..5260d04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiohttp==3.9.3 aiofiles==23.2.1 boto3 -faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/refs/heads/master.tar.gz +-e /app/faster-whisper ffmpeg-python==0.2.0 transformers==4.38.2 librosa==0.10.1 diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py index aa4ce88..39db791 100644 --- a/src/wordcab_transcribe/config.py +++ b/src/wordcab_transcribe/config.py @@ -126,9 +126,9 @@ def whisper_model_compatibility_check(cls, value: str): # noqa: B902, N805 @field_validator("whisper_engine") def whisper_engine_compatibility_check(cls, value: str): # noqa: B902, N805 """Check that the whisper engine is compatible.""" - if value.lower() not in ["faster-whisper", "tensorrt-llm"]: + if value.lower() not in ["faster-whisper", "faster-whisper-batched", "tensorrt-llm"]: raise ValueError( - "The whisper engine must be one of `faster-whisper` or `tensorrt-llm`." + "The whisper engine must be one of `faster-whisper`, `faster-whisper-batched`, or `tensorrt-llm`." ) return value diff --git a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py index eb56c7b..db0c723 100644 --- a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py +++ b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py @@ -117,7 +117,6 @@ async def inference_with_audio( # noqa: C901 ) background_tasks.add_task(delete_file, filepath=filename) - task = asyncio.create_task( asr.process_input( filepath=filepath, diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index 16e8026..298b851 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -439,7 +439,7 @@ async def process_input( # noqa: C901 filepath (Union[str, List[str]]): Path to the audio file or list of paths to the audio files to process. batch_size (Union[int, None]): - The batch size to use for the transcription. For tensorrt-llm whisper engine only. + The batch size to use for the transcription. For tensorrt-llm and faster-whisper-batch engines only. offset_start (Union[float, None]): The start time of the audio file to process. offset_end (Union[float, None]): @@ -611,11 +611,20 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: if isinstance(task.transcription.execution, LocalExecution): out = await time_and_tell_async( lambda: self.local_services.transcription( - task.audio, + audio=task.audio, model_index=task.transcription.execution.index, - suppress_blank=False, - word_timestamps=True, - **task.transcription.options.model_dump(), + source_lang=task.transcription.options.source_lang, + batch_size=task.batch_size, + num_beams=task.transcription.options.num_beams, + suppress_blank=False, # TODO: Add this to the options + vocab=task.transcription.options.vocab, + word_timestamps=task.word_timestamps, + internal_vad=task.transcription.options.internal_vad, + repetition_penalty=task.transcription.options.repetition_penalty, + compression_ratio_threshold=task.transcription.options.compression_ratio_threshold, + log_prob_threshold=task.transcription.options.log_prob_threshold, + no_speech_threshold=task.transcription.options.no_speech_threshold, + condition_on_previous_text=task.transcription.options.condition_on_previous_text, ), func_name="transcription", debug_mode=debug_mode, @@ -880,7 +889,7 @@ async def remote_diarization( if not settings.debug: headers = {"Content-Type": "application/x-www-form-urlencoded"} auth_url = f"{url}/api/v1/auth" - diarization_timeout = aiohttp.ClientTimeout(total=60) + diarization_timeout = aiohttp.ClientTimeout(total=10) async with AsyncLocationTrustedRedirectSession(timeout=diarization_timeout) as session: async with session.post( url=auth_url, diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index ba5cacd..9ce6ea8 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -21,9 +21,9 @@ from typing import Iterable, List, NamedTuple, Optional, Union import torch -from faster_whisper import WhisperModel from loguru import logger from tensorshare import Backend, TensorShare +from faster_whisper import WhisperModel, BatchedInferencePipeline from wordcab_transcribe.config import settings from wordcab_transcribe.engines.tensorrt_llm.model import WhisperModelTRT @@ -87,6 +87,16 @@ def __init__( device_index=device_index, compute_type=self.compute_type, ) + elif self.model_engine == "faster-whisper-batched": + logger.info("Using faster-whisper-batched model engine.") + self.model = BatchedInferencePipeline( + model=WhisperModel( + self.model_path, + device=self.device, + device_index=device_index, + compute_type=self.compute_type, + ) + ) elif self.model_engine == "tensorrt-llm": logger.info("Using tensorrt-llm model engine.") if "v3" in self.model_path: @@ -126,7 +136,7 @@ def __call__( ], source_lang: str, model_index: int, - batch_size: int = 1, + batch_size: int, num_beams: int = 1, suppress_blank: bool = False, vocab: Union[List[str], None] = None, @@ -220,6 +230,21 @@ def __call__( "window_size_samples": 512, }, ) + elif self.model_engine == "faster-whisper-batched": + print("Batch size: ", batch_size) + segments, _ = self.model.transcribe( + audio, + language=source_lang, + hotwords=prompt, + beam_size=num_beams, + repetition_penalty=repetition_penalty, + compression_ratio_threshold=compression_ratio_threshold, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + suppress_blank=suppress_blank, + word_timestamps=word_timestamps, + batch_size=batch_size, + ) elif self.model_engine == "tensorrt-llm": segments = self.model.transcribe( audio_data=[audio],