From d1754d2c46b52dc61dccb80c7bf11b9c05902a38 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Wed, 31 Jan 2024 17:37:03 +0530 Subject: [PATCH 1/3] fix: model_size, no_speech, segment timings --- whisper_live/server.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/whisper_live/server.py b/whisper_live/server.py index be0e9598..1d22f501 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -397,6 +397,7 @@ def __init__( language=self.language, task=self.task ) + self.warmup() # threading self.trans_thread = threading.Thread(target=self.speech_to_text) @@ -410,7 +411,13 @@ def __init__( } ) ) - + + def warmup(self, warmup_steps=10): + logging.info("[INFO:] Warming up TensorRT engine..") + mel, duration = self.transcriber.log_mel_spectrogram("tests/jfk.flac") + for i in range(warmup_steps): + last_segment = self.transcriber.transcribe(mel) + def set_eos(self, eos): self.lock.acquire() self.eos = eos @@ -553,7 +560,7 @@ def __init__( multilingual=False, language=None, client_uid=None, - model="small", + model="small.en", initial_prompt=None, vad_parameters=None, ): @@ -587,6 +594,7 @@ def __init__( self.task = task self.initial_prompt = initial_prompt self.vad_parameters = vad_parameters or {"threshold": 0.5} + self.no_speech_thresh = 0.45 device = "cuda" if torch.cuda.is_available() else "cpu" @@ -721,6 +729,7 @@ def speech_to_text(self): if time.time() - self.t_start > self.add_pause_thresh: self.text.append('') + if not len(segments): continue try: self.websocket.send( json.dumps({ @@ -773,6 +782,10 @@ def update_segments(self, segments, duration): text_ = s.text self.text.append(text_) start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end) + + if start >= end: continue + if s.no_speech_prob > self.no_speech_thresh: continue + self.transcript.append(self.format_segment(start, end, text_)) offset = min(duration, s.end) From f4027de343f2aebe688dec39819fc4c8f7223015 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Wed, 31 Jan 2024 17:37:30 +0530 Subject: [PATCH 2/3] add: save_transcript to srt file --- whisper_live/client.py | 46 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index 96971c96..d61e4b53 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -13,6 +13,29 @@ import time +def format_time(s): + """Convert seconds (float) to SRT time format.""" + hours = int(s // 3600) + minutes = int((s % 3600) // 60) + seconds = int(s % 60) + milliseconds = int((s - int(s)) * 1000) + return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}" + +def create_srt_file(segments, output_file): + with open(output_file, 'w', encoding='utf-8') as srt_file: + segment_number = 1 + for segment in segments: + start_time = format_time(float(segment['start'])) + end_time = format_time(float(segment['end'])) + text = segment['text'] + + srt_file.write(f"{segment_number}\n") + srt_file.write(f"{start_time} --> {end_time}\n") + srt_file.write(f"{text}\n\n") + + segment_number += 1 + + def resample(file: str, sr: int = 16000): """ # https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/audio.py#L22 @@ -57,6 +80,7 @@ def __init__( lang=None, translate=False, model="small", + srt_file_path="output.srt" ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -89,6 +113,7 @@ def __init__( self.language = lang self.model = model self.server_error = False + self.srt_file_path = srt_file_path if translate: self.task = "translate" @@ -127,6 +152,7 @@ def __init__( self.ws_thread.start() self.frames = b"" + self.transcript = [] print("[INFO]: * recording") def on_message(self, ws, message): @@ -181,12 +207,21 @@ def on_message(self, ws, message): message = message["segments"] text = [] - if len(message): - for seg in message: + n_segments = len(message) + + if n_segments: + for i, seg in enumerate(message): if text and text[-1] == seg["text"]: # already got it continue text.append(seg["text"]) + + if i == n_segments-1: + self.last_segment = seg + else: + if not len(self.transcript) or float(seg['start']) >= float(self.transcript[-1]['end']): + self.transcript.append(seg) + # keep only last 3 if len(text) > 3: text = text[-3:] @@ -302,6 +337,7 @@ def play_file(self, filename): assert self.last_response_recieved while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for: continue + self.write_srt_file(self.srt_file_path) self.stream.close() self.close_websocket() @@ -311,6 +347,7 @@ def play_file(self, filename): self.stream.close() self.p.terminate() self.close_websocket() + self.write_srt_file(self.srt_file_path) print("[INFO]: Keyboard interrupt.") def close_websocket(self): @@ -438,6 +475,7 @@ def record(self, out_file="output_recording.wav"): t.start() n_audio_file += 1 self.frames = b"" + self.write_srt_file(self.srt_file_path) except KeyboardInterrupt: if len(self.frames): @@ -451,6 +489,7 @@ def record(self, out_file="output_recording.wav"): self.close_websocket() self.write_output_recording(n_audio_file, out_file) + self.write_srt_file(self.srt_file_path) def write_output_recording(self, n_audio_file, out_file): """ @@ -487,6 +526,9 @@ def write_output_recording(self, n_audio_file, out_file): os.remove(in_file) wavfile.close() + def write_srt_file(self, output_path="output.srt"): + self.transcript.append(self.last_segment) + create_srt_file(self.transcript, output_path) class TranscriptionClient: """ From 08575a03c2cef87869acffd4c01281ea4ff88b62 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Thu, 1 Feb 2024 14:22:45 +0530 Subject: [PATCH 3/3] write srt file only for faster_whisper backend --- whisper_live/client.py | 18 +++++++++++++----- whisper_live/server.py | 6 ++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/whisper_live/client.py b/whisper_live/client.py index d61e4b53..3832556b 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -192,6 +192,8 @@ def on_message(self, ws, message): if "message" in message.keys() and message["message"] == "SERVER_READY": self.recording = True + self.server_backend = message["backend"] + print(f"[INFO]: Server Running with backend {self.server_backend}") return if "language" in message.keys(): @@ -218,7 +220,7 @@ def on_message(self, ws, message): if i == n_segments-1: self.last_segment = seg - else: + elif self.server_backend == "faster_whisper": if not len(self.transcript) or float(seg['start']) >= float(self.transcript[-1]['end']): self.transcript.append(seg) @@ -337,7 +339,9 @@ def play_file(self, filename): assert self.last_response_recieved while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for: continue - self.write_srt_file(self.srt_file_path) + + if self.server_backend == "faster_whisper": + self.write_srt_file(self.srt_file_path) self.stream.close() self.close_websocket() @@ -347,7 +351,8 @@ def play_file(self, filename): self.stream.close() self.p.terminate() self.close_websocket() - self.write_srt_file(self.srt_file_path) + if self.server_backend == "faster_whisper": + self.write_srt_file(self.srt_file_path) print("[INFO]: Keyboard interrupt.") def close_websocket(self): @@ -475,7 +480,8 @@ def record(self, out_file="output_recording.wav"): t.start() n_audio_file += 1 self.frames = b"" - self.write_srt_file(self.srt_file_path) + if self.server_backend == "faster_whisper": + self.write_srt_file(self.srt_file_path) except KeyboardInterrupt: if len(self.frames): @@ -489,7 +495,8 @@ def record(self, out_file="output_recording.wav"): self.close_websocket() self.write_output_recording(n_audio_file, out_file) - self.write_srt_file(self.srt_file_path) + if self.server_backend == "faster_whisper": + self.write_srt_file(self.srt_file_path) def write_output_recording(self, n_audio_file, out_file): """ @@ -530,6 +537,7 @@ def write_srt_file(self, output_path="output.srt"): self.transcript.append(self.last_segment) create_srt_file(self.transcript, output_path) + class TranscriptionClient: """ Client for handling audio transcription tasks via a WebSocket connection. diff --git a/whisper_live/server.py b/whisper_live/server.py index cf40d53d..b189576b 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -407,7 +407,8 @@ def __init__( json.dumps( { "uid": self.client_uid, - "message": self.SERVER_READY + "message": self.SERVER_READY, + "backend": "tensorrt" } ) ) @@ -615,7 +616,8 @@ def __init__( json.dumps( { "uid": self.client_uid, - "message": self.SERVER_READY + "message": self.SERVER_READY, + "backend": "faster_whisper" } ) )