Skip to content

Commit

Permalink
Added multiprocessing for cpu processing
Browse files Browse the repository at this point in the history
  • Loading branch information
joiemoie committed Jan 19, 2024
1 parent 72ff979 commit dd68247
Showing 1 changed file with 47 additions and 36 deletions.
83 changes: 47 additions & 36 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
collect_chunks,
get_speech_timestamps,
)
import multiprocessing


class Word(NamedTuple):
Expand Down Expand Up @@ -77,6 +78,46 @@ class TranscriptionInfo(NamedTuple):
transcription_options: TranscriptionOptions
vad_options: VadOptions

# Performs the preprocessing on its own process to make use of all CPU cores.
def cpu_preprocessing(
    logger,
    feature_extractor,
    # NOTE: despite the broad annotation kept for interface compatibility, the
    # body indexes `audio.shape`, so the caller must pass a decoded np.ndarray
    # (the `transcribe` caller decodes str/BinaryIO inputs before this call).
    audio: Union[str, BinaryIO, np.ndarray],
    vad_filter: bool = False,
    vad_parameters: Optional[Union[dict, VadOptions]] = None,
) -> Tuple[np.ndarray, float, float, Optional[List[dict]]]:
    """Run the CPU-bound preprocessing for transcription.

    Optionally applies the Silero VAD filter to drop non-speech audio, then
    computes the log-Mel features. Intended to be executed in a worker
    process (via ``multiprocessing``) so the GIL does not serialize this
    work with the caller; all arguments must therefore be picklable.

    Args:
        logger: Logger used for progress/diagnostic messages.
        feature_extractor: Object exposing ``sampling_rate`` and computing
            features when called with the audio array.
        audio: Decoded waveform as a 1-D ``np.ndarray``.
        vad_filter: When True, remove non-speech chunks before extraction.
        vad_parameters: ``VadOptions``, a dict of its fields, or None for
            defaults. Only consulted when ``vad_filter`` is True.

    Returns:
        Tuple of (features, original duration in seconds, duration after
        VAD in seconds, speech chunks or None when VAD was not applied).
    """
    sampling_rate = feature_extractor.sampling_rate
    duration = audio.shape[0] / sampling_rate
    duration_after_vad = duration

    logger.info(
        "Processing audio with duration %s", format_timestamp(duration)
    )

    if vad_filter:
        # Bug fix: normalize vad_parameters before use. The pre-refactor code
        # in `transcribe` converted None/dict into VadOptions; that step was
        # lost when this logic was extracted, so get_speech_timestamps could
        # receive None or a plain dict.
        if vad_parameters is None:
            vad_parameters = VadOptions()
        elif isinstance(vad_parameters, dict):
            vad_parameters = VadOptions(**vad_parameters)

        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        logger.info(
            "VAD filter removed %s of audio",
            format_timestamp(duration - duration_after_vad),
        )

        # Only build the per-chunk debug string when DEBUG is enabled, since
        # formatting every chunk can be costly on long audio.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "VAD filter kept the following audio segments: %s",
                ", ".join(
                    "[%s -> %s]"
                    % (
                        format_timestamp(chunk["start"] / sampling_rate),
                        format_timestamp(chunk["end"] / sampling_rate),
                    )
                    for chunk in speech_chunks
                ),
            )

    else:
        speech_chunks = None

    features = feature_extractor(audio)

    return features, duration, duration_after_vad, speech_chunks

class WhisperModel:
def __init__(
Expand Down Expand Up @@ -271,49 +312,19 @@ def transcribe(
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""
sampling_rate = self.feature_extractor.sampling_rate

if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate)

duration = audio.shape[0] / sampling_rate
duration_after_vad = duration

self.logger.info(
"Processing audio with duration %s", format_timestamp(duration)
)
audio = decode_audio(audio, sampling_rate=self.feature_extractor.sampling_rate)

if vad_filter:
if vad_parameters is None:
vad_parameters = VadOptions()
elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks)
duration_after_vad = audio.shape[0] / sampling_rate

self.logger.info(
"VAD filter removed %s of audio",
format_timestamp(duration - duration_after_vad),
)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"VAD filter kept the following audio segments: %s",
", ".join(
"[%s -> %s]"
% (
format_timestamp(chunk["start"] / sampling_rate),
format_timestamp(chunk["end"] / sampling_rate),
)
for chunk in speech_chunks
),
)

else:
speech_chunks = None

features = self.feature_extractor(audio)

# Spawns a new process to run preprocessing on CPU
with multiprocessing.Pool() as pool:
features, duration, duration_after_vad, speech_chunks = pool.apply(cpu_preprocessing, (self.logger, self.feature_extractor, audio, vad_filter, vad_parameters))

encoder_output = None
all_language_probs = None
Expand Down Expand Up @@ -384,7 +395,7 @@ def transcribe(
segments = self.generate_segments(features, tokenizer, options, encoder_output)

if speech_chunks:
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
segments = restore_speech_timestamps(segments, speech_chunks, self.feature_extractor.sampling_rate)

info = TranscriptionInfo(
language=language,
Expand Down

0 comments on commit dd68247

Please sign in to comment.