Add clip_timestamps and hallucination_silence_threshold options
trungkienbkhn committed Jan 16, 2024
1 parent 44f7e58 commit 9efdffb
Showing 2 changed files with 157 additions and 19 deletions.
faster_whisper/transcribe.py (150 additions, 19 deletions)

@@ -14,7 +14,7 @@
 from faster_whisper.audio import decode_audio
 from faster_whisper.feature_extractor import FeatureExtractor
 from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
-from faster_whisper.utils import download_model, format_timestamp, get_logger
+from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
 from faster_whisper.vad import (
     SpeechTimestampsMap,
     VadOptions,
@@ -66,6 +66,8 @@ class TranscriptionOptions(NamedTuple):
     word_timestamps: bool
     prepend_punctuations: str
     append_punctuations: str
+    clip_timestamps: Union[str, List[float]]
+    hallucination_silence_threshold: Optional[float]
 
 
 class TranscriptionInfo(NamedTuple):
@@ -213,6 +215,8 @@ def transcribe(
         append_punctuations: str = "\"'.。,,!!??::”)]}、",
         vad_filter: bool = False,
         vad_parameters: Optional[Union[dict, VadOptions]] = None,
+        clip_timestamps: Union[str, List[float]] = "0",
+        hallucination_silence_threshold: Optional[float] = None,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
@@ -264,6 +268,12 @@ def transcribe(
             https://github.com/snakers4/silero-vad.
           vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
             parameters and default values in the class `VadOptions`).
+          clip_timestamps: Union[str, List[float]]
+            Comma-separated list of "start,end,start,end,..." timestamps (in seconds) of
+            clips to process. The last end timestamp defaults to the end of the file.
+          hallucination_silence_threshold: Optional[float]
+            When word_timestamps is True, skip silent periods longer than this threshold
+            (in seconds) when a possible hallucination is detected.
 
         Returns:
           A tuple with:
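
To make the new API concrete, here is a minimal usage sketch of the two options added in this commit (the audio path and model size are placeholders):

```python
from faster_whisper import WhisperModel

model = WhisperModel("large-v3")

# Transcribe only 0s-30s and 60s-end of the file, and re-seek past silent
# gaps longer than 2s whenever a window looks like a hallucination.
segments, info = model.transcribe(
    "audio.wav",
    clip_timestamps="0,30,60",
    word_timestamps=True,  # hallucination_silence_threshold needs word timings
    hallucination_silence_threshold=2.0,
)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
```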
@@ -379,6 +389,8 @@ def transcribe(
             word_timestamps=word_timestamps,
             prepend_punctuations=prepend_punctuations,
             append_punctuations=append_punctuations,
+            clip_timestamps=clip_timestamps,
+            hallucination_silence_threshold=hallucination_silence_threshold,
         )
 
         segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -406,8 +418,33 @@ def generate_segments(
         encoder_output: Optional[ctranslate2.StorageView] = None,
     ) -> Iterable[Segment]:
         content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
+        content_duration = float(content_frames * self.feature_extractor.time_per_frame)
+
+        # Parse a "start,end,start,end,..." string into a list of floats.
+        # (NamedTuple fields are read-only, so build an updated copy via _replace.)
+        if isinstance(options.clip_timestamps, str):
+            options = options._replace(
+                clip_timestamps=[
+                    float(ts)
+                    for ts in (
+                        options.clip_timestamps.split(",")
+                        if options.clip_timestamps
+                        else []
+                    )
+                ]
+            )
+        seek_points: List[int] = [
+            round(ts * self.frames_per_second) for ts in options.clip_timestamps
+        ]
+        if len(seek_points) == 0:
+            seek_points.append(0)
+        if len(seek_points) % 2 == 1:
+            seek_points.append(content_frames)
+        seek_clips: List[Tuple[int, int]] = list(
+            zip(seek_points[::2], seek_points[1::2])
+        )
+
+        punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
+
         idx = 0
-        seek = 0
+        clip_idx = 0
+        seek = seek_clips[clip_idx][0]
         all_tokens = []
         prompt_reset_since = 0

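To illustrate the parsing above, a standalone sketch of how a `clip_timestamps` string becomes `seek_clips` (the frame rate and frame counts are invented for the demo):

```python
from typing import List, Tuple


def parse_clips(
    clip_timestamps: str, frames_per_second: int, content_frames: int
) -> List[Tuple[int, int]]:
    # "start,end,start,end,..." in seconds -> (start_frame, end_frame) pairs.
    timestamps = (
        [float(ts) for ts in clip_timestamps.split(",")] if clip_timestamps else []
    )
    seek_points = [round(ts * frames_per_second) for ts in timestamps]
    if len(seek_points) == 0:
        seek_points.append(0)  # default: start at the beginning
    if len(seek_points) % 2 == 1:
        seek_points.append(content_frames)  # last end defaults to end of file
    return list(zip(seek_points[::2], seek_points[1::2]))


# 100 frames per second, 9000-frame (90-second) file:
print(parse_clips("0,30,60", 100, 9000))  # [(0, 3000), (6000, 9000)]
print(parse_clips("0", 100, 9000))        # [(0, 9000)] -- the "0" default
```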
@@ -420,12 +457,30 @@ def generate_segments(
                 all_tokens.extend(options.initial_prompt)
 
         last_speech_timestamp = 0.0
-        while seek < content_frames:
+        # NOTE: This loop is obscurely flattened to make the diff readable.
+        # A later commit should turn this into a simpler nested loop.
+        # for seek_clip_start, seek_clip_end in seek_clips:
+        #     while seek < seek_clip_end:
+        while clip_idx < len(seek_clips):
+            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
+            if seek < seek_clip_start:
+                seek = seek_clip_start
+            if seek >= seek_clip_end:
+                clip_idx += 1
+                if clip_idx < len(seek_clips):
+                    seek = seek_clips[clip_idx][0]
+                continue
             time_offset = seek * self.feature_extractor.time_per_frame
-            segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
+            window_end_time = float(
+                (seek + self.feature_extractor.nb_max_frames)
+                * self.feature_extractor.time_per_frame
+            )
             segment_size = min(
-                self.feature_extractor.nb_max_frames, content_frames - seek
+                self.feature_extractor.nb_max_frames,
+                content_frames - seek,
+                seek_clip_end - seek,
             )
+            segment = features[:, seek : seek + segment_size]
             segment_duration = segment_size * self.feature_extractor.time_per_frame
 
             if self.logger.isEnabledFor(logging.DEBUG):
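
The flattened loop above is equivalent to the nested loop sketched in the NOTE; reduced to just the seek arithmetic (frame counts invented for the demo), it behaves like this:

```python
nb_max_frames = 3000  # one 30-second window at 100 frames per second
content_frames = 9000
seek_clips = [(0, 3000), (6000, 9000)]

for seek_clip_start, seek_clip_end in seek_clips:
    seek = seek_clip_start
    while seek < seek_clip_end:
        segment_size = min(
            nb_max_frames, content_frames - seek, seek_clip_end - seek
        )
        print(f"decode frames [{seek}, {seek + segment_size})")
        # The real loop usually re-seeks to the last decoded word/timestamp;
        # advancing by the full window is the simplest case.
        seek += segment_size
```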
@@ -478,10 +533,33 @@ def generate_segments(
             previous_seek = seek
             current_segments = []
 
+            # anomalous words are very long/short/improbable
+            def word_anomaly_score(word: dict) -> float:
+                probability = word.get("probability", 0.0)
+                duration = word["end"] - word["start"]
+                score = 0.0
+                if probability < 0.15:
+                    score += 1.0
+                if duration < 0.133:
+                    score += (0.133 - duration) * 15
+                if duration > 2.0:
+                    score += duration - 2.0
+                return score
+
+            def is_segment_anomaly(segment: Optional[dict]) -> bool:
+                if segment is None or not segment["words"]:
+                    return False
+                words = [w for w in segment["words"] if w["word"] not in punctuation]
+                words = words[:8]
+                score = sum(word_anomaly_score(w) for w in words)
+                return score >= 3 or score + 0.01 >= len(words)
+
+            def next_words_segment(segments: List[dict]) -> Optional[dict]:
+                return next((s for s in segments if s["words"]), None)
+
             single_timestamp_ending = (
                 len(tokens) >= 2
-                and tokens[-2] < tokenizer.timestamp_begin
-                and tokens[-1] >= tokenizer.timestamp_begin
+                and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
             )
 
             consecutive_timestamps = [
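
A small self-contained demo of the anomaly scoring (the function body is copied from the diff; the word dicts are invented):

```python
def word_anomaly_score(word: dict) -> float:
    probability = word.get("probability", 0.0)
    duration = word["end"] - word["start"]
    score = 0.0
    if probability < 0.15:
        score += 1.0
    if duration < 0.133:
        score += (0.133 - duration) * 15
    if duration > 2.0:
        score += duration - 2.0
    return score


# A confident word of normal length scores 0:
print(word_anomaly_score({"probability": 0.9, "start": 1.0, "end": 1.4}))  # 0.0
# An improbable, clipped word is penalized twice:
print(word_anomaly_score({"probability": 0.05, "start": 1.0, "end": 1.05}))  # ~2.25
# A four-second "word" is suspicious too:
print(word_anomaly_score({"probability": 0.9, "start": 1.0, "end": 5.0}))  # 2.0
```

`is_segment_anomaly` then sums the scores of a segment's first eight non-punctuation words and flags the segment once the total reaches 3, or roughly one point per word.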
@@ -564,18 +642,70 @@ def generate_segments(
                     last_speech_timestamp=last_speech_timestamp,
                 )
 
-                word_end_timestamps = [
-                    w["end"] for s in current_segments for w in s["words"]
-                ]
-                if len(word_end_timestamps) > 0:
-                    last_speech_timestamp = word_end_timestamps[-1]
-                if not single_timestamp_ending and len(word_end_timestamps) > 0:
-                    seek_shift = round(
-                        (word_end_timestamps[-1] - time_offset) * self.frames_per_second
-                    )
-
-                    if seek_shift > 0:
-                        seek = previous_seek + seek_shift
+                if not single_timestamp_ending:
+                    last_word_end = get_end(current_segments)
+                    if last_word_end is not None and last_word_end > time_offset:
+                        seek = round(last_word_end * self.frames_per_second)
+
+                # skip silence before possible hallucinations
+                if options.hallucination_silence_threshold is not None:
+                    threshold = options.hallucination_silence_threshold
+                    if not single_timestamp_ending:
+                        last_word_end = get_end(current_segments)
+                        if last_word_end is not None and last_word_end > time_offset:
+                            remaining_duration = window_end_time - last_word_end
+                            if remaining_duration > threshold:
+                                seek = round(last_word_end * self.frames_per_second)
+                            else:
+                                seek = previous_seek + segment_size
+
+                    # if first segment might be a hallucination, skip leading silence
+                    first_segment = next_words_segment(current_segments)
+                    if first_segment is not None and is_segment_anomaly(first_segment):
+                        gap = first_segment["start"] - time_offset
+                        if gap > threshold:
+                            seek = previous_seek + round(gap * self.frames_per_second)
+                            continue
+
+                    # skip silence before any possible hallucination that is surrounded
+                    # by silence or more hallucinations
+                    hal_last_end = last_speech_timestamp
+                    for si in range(len(current_segments)):
+                        segment = current_segments[si]
+                        if not segment["words"]:
+                            continue
+                        if is_segment_anomaly(segment):
+                            next_segment = next_words_segment(
+                                current_segments[si + 1 :]
+                            )
+                            if next_segment is not None:
+                                hal_next_start = next_segment["words"][0]["start"]
+                            else:
+                                hal_next_start = time_offset + segment_duration
+                            silence_before = (
+                                segment["start"] - hal_last_end > threshold
+                                or segment["start"] < threshold
+                                or segment["start"] - time_offset < 2.0
+                            )
+                            silence_after = (
+                                hal_next_start - segment["end"] > threshold
+                                or is_segment_anomaly(next_segment)
+                                or window_end_time - segment["end"] < 2.0
+                            )
+                            if silence_before and silence_after:
+                                seek = round(
+                                    max(time_offset + 1, segment["start"])
+                                    * self.frames_per_second
+                                )
+                                if content_duration - segment["end"] < threshold:
+                                    seek = content_frames
+                                current_segments[si:] = []
+                                break
+                        hal_last_end = segment["end"]
+
+                last_word_end = get_end(current_segments)
+                if last_word_end is not None:
+                    last_speech_timestamp = last_word_end
 
             for segment in current_segments:
                 tokens = segment["tokens"]
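
Putting the hallucination handling together: a suspicious segment is only dropped (and the window re-seeked past it) when it is bounded by silence or further anomalies on both sides. A simplified sketch of that decision with the loop's bookkeeping passed as plain arguments (the logic is lifted from the diff; the function itself is illustrative, not part of the API):

```python
def should_skip_hallucination(
    segment: dict,
    hal_last_end: float,     # end of the last confirmed speech (seconds)
    hal_next_start: float,   # start of the next segment that has words
    next_is_anomaly: bool,   # is_segment_anomaly(next_segment)
    time_offset: float,      # start of the current window (seconds)
    window_end_time: float,  # end of the current window (seconds)
    threshold: float,        # hallucination_silence_threshold
) -> bool:
    silence_before = (
        segment["start"] - hal_last_end > threshold
        or segment["start"] < threshold
        or segment["start"] - time_offset < 2.0
    )
    silence_after = (
        hal_next_start - segment["end"] > threshold
        or next_is_anomaly
        or window_end_time - segment["end"] < 2.0
    )
    return silence_before and silence_after
```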
@@ -794,6 +924,7 @@ def add_word_timestamps(
         word_durations = np.array([word["end"] - word["start"] for word in alignment])
         word_durations = word_durations[word_durations.nonzero()]
         median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
+        median_duration = min(0.7, float(median_duration))
         max_duration = median_duration * 2
 
         # hack: truncate long words at sentence boundaries.
faster_whisper/utils.py (7 additions, 0 deletions)

@@ -143,3 +143,10 @@ class disabled_tqdm(tqdm):
     def __init__(self, *args, **kwargs):
         kwargs["disable"] = True
         super().__init__(*args, **kwargs)
+
+
+def get_end(segments: List[dict]) -> Optional[float]:
+    return next(
+        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+        segments[-1]["end"] if segments else None,
+    )
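
A quick sketch of what the new `get_end` helper returns (the segment dicts are invented):

```python
from faster_whisper.utils import get_end

segments = [
    {"end": 5.0, "words": [{"word": " hi", "end": 1.2}]},
    {"end": 9.0, "words": []},  # no word timings for this segment
]
# Scans segments from the back and returns the last word's end time:
print(get_end(segments))  # 1.2
# With no word timings anywhere, falls back to the last segment's end:
print(get_end([{"end": 9.0, "words": []}]))  # 9.0
print(get_end([]))  # None
```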
