def format_time(s):
    """Convert a time offset in seconds (float) to the SRT timestamp
    format ``HH:MM:SS,mmm``.

    Rounds at millisecond precision, so a value such as 2.9996 becomes
    ``00:00:03,000`` rather than truncating to ``00:00:02,999`` (the
    original ``int((s - int(s)) * 1000)`` dropped the fractional
    millisecond and was vulnerable to float-representation artifacts).
    """
    # Work entirely in integer milliseconds to avoid float truncation.
    total_ms = round(s * 1000)
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, milliseconds = divmod(rem, 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"


def create_srt_file(segments, output_file):
    """Write transcription segments to *output_file* in SubRip (SRT) format.

    Each segment is a mapping with ``start`` and ``end`` keys (seconds,
    possibly stringified floats — they are coerced with ``float``) and a
    ``text`` key. Entries are numbered consecutively starting at 1, each
    followed by a blank separator line, per the SRT convention.
    """
    with open(output_file, 'w', encoding='utf-8') as srt_file:
        # enumerate(..., start=1) replaces the hand-rolled counter.
        for segment_number, segment in enumerate(segments, start=1):
            start_time = format_time(float(segment['start']))
            end_time = format_time(float(segment['end']))
            srt_file.write(f"{segment_number}\n")
            srt_file.write(f"{start_time} --> {end_time}\n")
            srt_file.write(f"{segment['text']}\n\n")
@@ -89,6 +113,7 @@ def __init__( self.language = lang self.model = model self.server_error = False + self.srt_file_path = srt_file_path if translate: self.task = "translate" @@ -127,6 +152,7 @@ def __init__( self.ws_thread.start() self.frames = b"" + self.transcript = [] print("[INFO]: * recording") def on_message(self, ws, message): @@ -166,6 +192,8 @@ def on_message(self, ws, message): if "message" in message.keys() and message["message"] == "SERVER_READY": self.recording = True + self.server_backend = message["backend"] + print(f"[INFO]: Server Running with backend {self.server_backend}") return if "language" in message.keys(): @@ -181,12 +209,21 @@ def on_message(self, ws, message): message = message["segments"] text = [] - if len(message): - for seg in message: + n_segments = len(message) + + if n_segments: + for i, seg in enumerate(message): if text and text[-1] == seg["text"]: # already got it continue text.append(seg["text"]) + + if i == n_segments-1: + self.last_segment = seg + elif self.server_backend == "faster_whisper": + if not len(self.transcript) or float(seg['start']) >= float(self.transcript[-1]['end']): + self.transcript.append(seg) + # keep only last 3 if len(text) > 3: text = text[-3:] @@ -302,6 +339,9 @@ def play_file(self, filename): assert self.last_response_recieved while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for: continue + + if self.server_backend == "faster_whisper": + self.write_srt_file(self.srt_file_path) self.stream.close() self.close_websocket() @@ -311,6 +351,8 @@ def play_file(self, filename): self.stream.close() self.p.terminate() self.close_websocket() + if self.server_backend == "faster_whisper": + self.write_srt_file(self.srt_file_path) print("[INFO]: Keyboard interrupt.") def close_websocket(self): @@ -438,6 +480,8 @@ def record(self, out_file="output_recording.wav"): t.start() n_audio_file += 1 self.frames = b"" + if self.server_backend == "faster_whisper": + 
def write_srt_file(self, output_path="output.srt"):
    """Dump the accumulated transcript (plus the in-flight last segment,
    if any) to *output_path* as an SRT file.

    Fixes two defects in the original:
    - ``self.last_segment`` is only assigned once the server has sent at
      least one segment; on an early keyboard interrupt the attribute does
      not exist and the original raised ``AttributeError``. Guarded with
      ``getattr``.
    - ``record()`` invokes this method once per audio chunk, and the
      original permanently appended ``last_segment`` to
      ``self.transcript`` on every call, accumulating duplicate trailing
      entries in the SRT. We now build a local list and leave
      ``self.transcript`` untouched.
    """
    segments = list(self.transcript)
    last = getattr(self, "last_segment", None)
    # Skip the tail segment if it is missing or already the final entry.
    if last is not None and (not segments or segments[-1] is not last):
        segments.append(last)
    create_srt_file(segments, output_path)
if time.time() - self.t_start > self.add_pause_thresh: self.text.append('') + if not len(segments): continue try: self.websocket.send( json.dumps({ @@ -780,6 +784,10 @@ def update_segments(self, segments, duration): text_ = s.text self.text.append(text_) start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end) + + if start >= end: continue + if s.no_speech_prob > self.no_speech_thresh: continue + self.transcript.append(self.format_segment(start, end, text_)) offset = min(duration, s.end)