diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 7996321e..352f8fc3 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -14,7 +14,7 @@
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
-from faster_whisper.utils import download_model, format_timestamp, get_logger
+from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
 from faster_whisper.vad import (
     SpeechTimestampsMap,
     VadOptions,
@@ -66,6 +66,8 @@ class TranscriptionOptions(NamedTuple):
     word_timestamps: bool
     prepend_punctuations: str
     append_punctuations: str
+    clip_timestamps: Union[str, List[float]]
+    hallucination_silence_threshold: Optional[float]


 class TranscriptionInfo(NamedTuple):
@@ -213,6 +215,8 @@
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
+        clip_timestamps: Union[str, List[float]] = "0",
+        hallucination_silence_threshold: Optional[float] = None,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.

@@ -264,6 +268,12 @@
             https://github.com/snakers4/silero-vad.
           vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
             parameters and default values in the class `VadOptions`).
+          clip_timestamps: Union[str, List[float]]
+            Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
+            The last end timestamp defaults to the end of the file.
+          hallucination_silence_threshold: Optional[float]
+            When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
+            when a possible hallucination is detected.

         Returns:
           A tuple with:
@@ -379,6 +389,8 @@
             word_timestamps=word_timestamps,
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
+            clip_timestamps=clip_timestamps,
+            hallucination_silence_threshold=hallucination_silence_threshold,
         )

         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -406,8 +418,33 @@
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        content_duration = float(content_frames * self.feature_extractor.time_per_frame)
+
+        if isinstance(options.clip_timestamps, str):
+            options = options._replace(clip_timestamps=[
+                float(ts)
+                for ts in (
+                    options.clip_timestamps.split(",")
+                    if options.clip_timestamps
+                    else []
+                )
+            ])
+        seek_points: List[int] = [
+            round(ts * self.frames_per_second) for ts in options.clip_timestamps
+        ]
+        if len(seek_points) == 0:
+            seek_points.append(0)
+        if len(seek_points) % 2 == 1:
+            seek_points.append(content_frames)
+        seek_clips: List[Tuple[int, int]] = list(
+            zip(seek_points[::2], seek_points[1::2])
+        )
+
+        punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
         idx = 0
-        seek = 0
+        clip_idx = 0
+        seek = seek_clips[clip_idx][0]
         all_tokens = []
         prompt_reset_since = 0

@@ -420,12 +457,30 @@
             all_tokens.extend(options.initial_prompt)

         last_speech_timestamp = 0.0
-        while seek < content_frames:
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end:
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
             time_offset = seek * self.feature_extractor.time_per_frame
-            segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
+            window_end_time = float(
+                (seek + self.feature_extractor.nb_max_frames)
+                * self.feature_extractor.time_per_frame
+            )
             segment_size = min(
-                self.feature_extractor.nb_max_frames, content_frames - seek
+                self.feature_extractor.nb_max_frames,
+                content_frames - seek,
+                seek_clip_end - seek,
             )
+            segment = features[:, seek : seek + segment_size]
             segment_duration = segment_size * self.feature_extractor.time_per_frame

             if self.logger.isEnabledFor(logging.DEBUG):
@@ -478,10 +533,33 @@
             previous_seek = seek
             current_segments = []

+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [w for w in segment["words"] if w["word"] not in punctuation]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             single_timestamp_ending = (
                 len(tokens) >= 2
-                and tokens[-2] < tokenizer.timestamp_begin
-                and tokens[-1] >= tokenizer.timestamp_begin
+                and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
             )

             consecutive_timestamps = [
@@ -564,18 +642,70 @@
                         last_speech_timestamp=last_speech_timestamp,
                     )

-                word_end_timestamps = [
-                    w["end"] for s in current_segments for w in s["words"]
-                ]
-                if len(word_end_timestamps) > 0:
-                    last_speech_timestamp = word_end_timestamps[-1]
-                if not single_timestamp_ending and len(word_end_timestamps) > 0:
-                    seek_shift = round(
-                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
-                    )
-
-                    if seek_shift > 0:
-                        seek = previous_seek + seek_shift
+                if not single_timestamp_ending:
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * self.frames_per_second)
+
+                # skip silence before possible hallucinations
+                if options.hallucination_silence_threshold is not None:
+                    threshold = options.hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * self.frames_per_second)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap * self.frames_per_second)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1 :]
+                            )
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0]["start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold
+                                or segment["start"] < threshold
+                                or segment["start"] - time_offset < 2.0
+                            )
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold
+                                or is_segment_anomaly(next_segment)
+                                or window_end_time - segment["end"] < 2.0
+                            )
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"])
+                                    * self.frames_per_second
+                                )
+                                if content_duration - segment["end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end

             for segment in current_segments:
                 tokens = segment["tokens"]
@@ -794,6 +924,7 @@ def add_word_timestamps(
         word_durations = np.array([word["end"] - word["start"] for word in alignment])
         word_durations = word_durations[word_durations.nonzero()]
         median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+        median_duration = min(0.7, float(median_duration))
         max_duration = median_duration * 2

         # hack: truncate long words at sentence boundaries.
diff --git a/faster_whisper/utils.py b/faster_whisper/utils.py
index 343a6357..26289c49 100644
--- a/faster_whisper/utils.py
+++ b/faster_whisper/utils.py
@@ -143,3 +143,10 @@ class disabled_tqdm(tqdm):
     def __init__(self, *args, **kwargs):
         kwargs["disable"] = True
         super().__init__(*args, **kwargs)
+
+
+def get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
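
Notes for review (outside the patch):

The clip parsing added at the top of generate_segments reduces to the
standalone sketch below. parse_seek_clips is a hypothetical helper written
only for illustration, and frames_per_second is hard-coded to 100, the value
implied by the default feature extractor (16 kHz sampling rate, 160-sample
hop); in the patch it comes from self.frames_per_second.

    from typing import List, Tuple

    def parse_seek_clips(
        clip_timestamps: str, content_frames: int, frames_per_second: int = 100
    ) -> List[Tuple[int, int]]:
        # "start,end,start,end,..." in seconds -> [(start_frame, end_frame), ...]
        seek_points = [
            round(float(ts) * frames_per_second)
            for ts in (clip_timestamps.split(",") if clip_timestamps else [])
        ]
        if len(seek_points) == 0:
            seek_points.append(0)  # no clips given: start at the beginning
        if len(seek_points) % 2 == 1:
            seek_points.append(content_frames)  # last end defaults to end of file
        return list(zip(seek_points[::2], seek_points[1::2]))

    print(parse_seek_clips("0,30,60", content_frames=9000))
    # [(0, 3000), (6000, 9000)]: the second clip runs from 60s to the end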
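For intuition on the anomaly heuristic: a word with probability 0.10 and
duration 0.05 s scores 1.0 + (0.133 - 0.05) * 15 ≈ 2.25, so two such words
among a segment's first eight non-punctuation words already clear the
score >= 3 bar, while the score + 0.01 >= len(words) condition catches short
segments whose words are anomalous on average.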
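A minimal end-to-end sketch of the two new options working together (model
size, file name, and threshold below are placeholder values):

    from faster_whisper import WhisperModel

    model = WhisperModel("small")

    # Decode only 0s-30s and 60s-90s of the file; with word timestamps on,
    # skip silent stretches longer than 2s around suspected hallucinations.
    segments, info = model.transcribe(
        "audio.wav",
        word_timestamps=True,
        clip_timestamps="0,30,60,90",
        hallucination_silence_threshold=2.0,
    )
    for segment in segments:
        print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))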