From beeb467cca32d06dd4f1b37c897f2c85dc8cb8c9 Mon Sep 17 00:00:00 2001
From: trungkienbkhn
Date: Tue, 16 Jan 2024 19:18:28 +0700
Subject: [PATCH] Add clip_timestamps and hallucination_silence_threshold options

---
 faster_whisper/transcribe.py | 169 +++++++++++++++++++++++++++++++----
 faster_whisper/utils.py      |   7 ++
 2 files changed, 157 insertions(+), 19 deletions(-)

diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 83bbf5cb..e5d5e4f8 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -14,7 +14,7 @@
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
-from faster_whisper.utils import download_model, format_timestamp, get_logger
+from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
 from faster_whisper.vad import (
     SpeechTimestampsMap,
     VadOptions,
@@ -67,6 +67,8 @@ class TranscriptionOptions(NamedTuple):
     prepend_punctuations: str
     append_punctuations: str
     max_new_tokens: Optional[int]
+    clip_timestamps: Union[str, List[float]]
+    hallucination_silence_threshold: Optional[float]


 class TranscriptionInfo(NamedTuple):
@@ -216,6 +218,8 @@ def transcribe(
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
         max_new_tokens: Optional[int] = None,
         chunk_length: Optional[int] = None,
+        clip_timestamps: Union[str, List[float]] = "0",
+        hallucination_silence_threshold: Optional[float] = None,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.

@@ -271,6 +275,12 @@
             set by the default max_length.
           chunk_length: The length of audio segments. If it is not None, it will overwrite the
             default chunk_length of the FeatureExtractor.
+          clip_timestamps: Comma-separated list start,end,start,end,... timestamps
+            (in seconds) of clips to process. The last end timestamp defaults to the
+            end of the file.
+          hallucination_silence_threshold: When word_timestamps is True, skip silent
+            periods longer than this threshold (in seconds) when a possible
+            hallucination is detected.

         Returns:
           A tuple with:
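For context, a minimal usage sketch of the two new options; the model size, audio path, and threshold value below are illustrative, not taken from this patch:

    from faster_whisper import WhisperModel

    model = WhisperModel("large-v3")

    # Transcribe only 0s-30s and 45s-end of the file; with word timestamps
    # enabled, skip silent gaps longer than 2s around suspected hallucinations.
    segments, info = model.transcribe(
        "audio.wav",
        word_timestamps=True,
        clip_timestamps="0,30,45",
        hallucination_silence_threshold=2.0,
    )
    for segment in segments:
        print(segment.start, segment.end, segment.text)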
@@ -387,6 +397,8 @@ def transcribe(
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
             max_new_tokens=max_new_tokens,
+            clip_timestamps=clip_timestamps,
+            hallucination_silence_threshold=hallucination_silence_threshold,
         )

         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -414,8 +426,33 @@ def generate_segments(
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        content_duration = float(content_frames * self.feature_extractor.time_per_frame)
+
+        if isinstance(options.clip_timestamps, str):
+            options = options._replace(clip_timestamps=[
+                float(ts)
+                for ts in (
+                    options.clip_timestamps.split(",")
+                    if options.clip_timestamps
+                    else []
+                )
+            ])
+        seek_points: List[int] = [
+            round(ts * self.frames_per_second) for ts in options.clip_timestamps
+        ]
+        if len(seek_points) == 0:
+            seek_points.append(0)
+        if len(seek_points) % 2 == 1:
+            seek_points.append(content_frames)
+        seek_clips: List[Tuple[int, int]] = list(
+            zip(seek_points[::2], seek_points[1::2])
+        )
+
+        punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
         idx = 0
-        seek = 0
+        clip_idx = 0
+        seek = seek_clips[clip_idx][0]
         all_tokens = []
         prompt_reset_since = 0
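To see what the seek_clips construction above produces, here is the same parsing logic as a standalone sketch; parse_clips is a hypothetical helper, and frames_per_second=100 is an assumption matching a 10 ms feature hop:

    from typing import List, Tuple

    def parse_clips(
        clip_timestamps, content_frames: int, frames_per_second: int = 100
    ) -> List[Tuple[int, int]]:
        # Accept either "0,30,45,60" or [0.0, 30.0, 45.0, 60.0].
        if isinstance(clip_timestamps, str):
            clip_timestamps = [
                float(ts)
                for ts in (clip_timestamps.split(",") if clip_timestamps else [])
            ]
        seek_points = [round(ts * frames_per_second) for ts in clip_timestamps]
        if len(seek_points) == 0:
            seek_points.append(0)
        if len(seek_points) % 2 == 1:
            # An odd number of points means the last clip runs to the end of the file.
            seek_points.append(content_frames)
        return list(zip(seek_points[::2], seek_points[1::2]))

    # "0,30,45" on a 90-second file: clips are (0s, 30s) and (45s, end).
    assert parse_clips("0,30,45", content_frames=9000) == [(0, 3000), (4500, 9000)]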
@@ -428,12 +465,30 @@
             all_tokens.extend(options.initial_prompt)

         last_speech_timestamp = 0.0
-        while seek < content_frames:
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
             time_offset = seek * self.feature_extractor.time_per_frame
-            segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
+            window_end_time = float(
+                (seek + self.feature_extractor.nb_max_frames)
+                * self.feature_extractor.time_per_frame
+            )
             segment_size = min(
-                self.feature_extractor.nb_max_frames, content_frames - seek
+                self.feature_extractor.nb_max_frames,
+                content_frames - seek,
+                seek_clip_end - seek,
             )
+            segment = features[:, seek : seek + segment_size]
             segment_duration = segment_size * self.feature_extractor.time_per_frame

             if self.logger.isEnabledFor(logging.DEBUG):
@@ -486,10 +541,33 @@
             previous_seek = seek
             current_segments = []

+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [w for w in segment["words"] if w["word"] not in punctuation]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             single_timestamp_ending = (
                 len(tokens) >= 2
-                and tokens[-2] < tokenizer.timestamp_begin
-                and tokens[-1] >= tokenizer.timestamp_begin
+                and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
             )

             consecutive_timestamps = [
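A worked example of the anomaly scoring above, with word_anomaly_score copied out of its enclosing function; the word dicts and their values are invented for illustration:

    def word_anomaly_score(word: dict) -> float:
        probability = word.get("probability", 0.0)
        duration = word["end"] - word["start"]
        score = 0.0
        if probability < 0.15:
            score += 1.0
        if duration < 0.133:
            score += (0.133 - duration) * 15
        if duration > 2.0:
            score += duration - 2.0
        return score

    words = [
        {"word": "Thanks", "start": 10.0, "end": 10.05, "probability": 0.05},
        {"word": "for", "start": 10.05, "end": 10.10, "probability": 0.08},
        {"word": "watching!", "start": 10.10, "end": 13.0, "probability": 0.40},
    ]
    score = sum(word_anomaly_score(w) for w in words)
    # ~2.25 + ~2.25 + 0.9 = ~5.4 >= 3, so is_segment_anomaly() would flag a
    # segment made of these words as a likely hallucination.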
@@ -572,18 +650,70 @@
                     last_speech_timestamp=last_speech_timestamp,
                 )

-                word_end_timestamps = [
-                    w["end"] for s in current_segments for w in s["words"]
-                ]
-                if len(word_end_timestamps) > 0:
-                    last_speech_timestamp = word_end_timestamps[-1]
-                if not single_timestamp_ending and len(word_end_timestamps) > 0:
-                    seek_shift = round(
-                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
-                    )
-
-                    if seek_shift > 0:
-                        seek = previous_seek + seek_shift
+                if not single_timestamp_ending:
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * self.frames_per_second)
+
+                # skip silence before possible hallucinations
+                if options.hallucination_silence_threshold is not None:
+                    threshold = options.hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * self.frames_per_second)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap * self.frames_per_second)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1 :]
+                            )
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0]["start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold
+                                or segment["start"] < threshold
+                                or segment["start"] - time_offset < 2.0
+                            )
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold
+                                or is_segment_anomaly(next_segment)
+                                or window_end_time - segment["end"] < 2.0
+                            )
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"])
+                                    * self.frames_per_second
+                                )
+                                if content_duration - segment["end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end

             for segment in current_segments:
                 tokens = segment["tokens"]
@@ -819,6 +949,7 @@ def add_word_timestamps(
     word_durations = np.array([word["end"] - word["start"] for word in alignment])
     word_durations = word_durations[word_durations.nonzero()]
     median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+    median_duration = min(0.7, float(median_duration))
     max_duration = median_duration * 2

     # hack: truncate long words at sentence boundaries.
diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py
index 98769569..0b5f3755 100644
--- a/faster_whisper/utils.py
+++ b/faster_whisper/utils.py
@@ -146,3 +146,10 @@ class disabled_tqdm(tqdm):
     def __init__(self, *args, **kwargs):
         kwargs["disable"] = True
         super().__init__(*args, **kwargs)
+
+
+def get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
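For reference, get_end scans words from the last segment backwards and falls back to the last segment's own end time; a quick sanity check with illustrative segment dicts:

    assert get_end([]) is None
    assert get_end([{"end": 9.0, "words": []}]) == 9.0  # no words: segment end
    assert (
        get_end(
            [
                {"end": 5.0, "words": [{"word": "hi", "end": 1.2}]},
                {"end": 9.0, "words": []},
            ]
        )
        == 1.2  # last aligned word's end, even when a later segment has no words
    )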