diff --git a/whisper_live/client.py b/whisper_live/client.py index 979e3254..8b08d159 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -56,8 +56,7 @@ def __init__( is_multilingual=False, lang=None, translate=False, - model_size="small", - use_custom_model=False + model="small", ): """ Initializes a Client instance for audio recording and streaming to a server. @@ -88,9 +87,8 @@ def __init__( self.disconnect_if_no_response_for = 15 self.multilingual = is_multilingual self.language = lang - self.model_size = model_size + self.model = model self.server_error = False - self.use_custom_model = use_custom_model if translate: self.task = "translate" @@ -229,8 +227,7 @@ def on_open(self, ws): "multilingual": self.multilingual, "language": self.language, "task": self.task, - "model_size": self.model_size, - "use_custom_model": self.use_custom_model # if runnning your own server with a custom model + "model": self.model, } ) ) @@ -521,10 +518,9 @@ def __init__(self, is_multilingual=False, lang=None, translate=False, - model_size="small", - use_custom_model=False + model="small", ): - self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model) + self.client = Client(host, port, is_multilingual, lang, translate, model) def __call__(self, audio=None, hls_url=None): """ diff --git a/whisper_live/server.py b/whisper_live/server.py index ddeaee3e..56e8a64d 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -6,7 +6,7 @@ import textwrap import logging -# logging.basicConfig(level = logging.INFO) +logging.basicConfig(level = logging.INFO) from websockets.sync.server import serve @@ -100,9 +100,9 @@ def recv_audio(self, websocket, custom_model_path=None): return # validate custom model - if options["use_custom_model"]: - if custom_model_path is None or not os.path.exists(custom_model_path): - options["use_custom_model"] = False + if custom_model_path is not None and os.path.exists(custom_model_path): + logging.info(f"Using custom model {custom_model_path}") + options["model"] = custom_model_path client = ServeClient( websocket, @@ -110,10 +110,9 @@ def recv_audio(self, websocket, custom_model_path=None): language=options["language"], task=options["task"], client_uid=options["uid"], - model_size_or_path=custom_model_path if options["use_custom_model"] else options["model_size"], + model=options["model"], initial_prompt=options.get("initial_prompt"), vad_parameters=options.get("vad_parameters"), - use_custom_model=options["use_custom_model"] ) self.clients[websocket] = client @@ -206,10 +205,9 @@ def __init__( multilingual=False, language=None, client_uid=None, - model_size_or_path="small", + model="small", initial_prompt=None, vad_parameters=None, - use_custom_model=False ): """ Initialize a ServeClient instance. @@ -230,13 +228,15 @@ def __init__( self.data = b"" self.frames = b"" self.model_sizes = [ - "tiny", "base", "small", "medium", "large-v2", "large-v3" + "tiny", "tiny.en", "base", "base.en", "small", "small.en", + "medium", "medium.en", "large-v2", "large-v3", ] + self.multilingual = multilingual - if not use_custom_model: - self.model_size_or_path = self.get_model_size(model_size_or_path) + if not os.path.exists(model): + self.model_size_or_path = self.get_model_size(model) else: - self.model_size_or_path = model_size_or_path + self.model_size_or_path = model self.language = language if self.multilingual else "en" self.task = task @@ -246,7 +246,7 @@ def __init__( device = "cuda" if torch.cuda.is_available() else "cpu" - if self.model_size_or_path == None: + if self.model_size_or_path is None: return self.transcriber = WhisperModel( @@ -302,12 +302,13 @@ def get_model_size(self, model_size): ) return None - if model_size in ["large-v2", "large-v3"]: + if model_size.endswith("en") and self.multilingual: + logging.info(f"Setting multilingual to false with {model_size} which is english only model.") + self.multilingual = False + + if not model_size.endswith("en") and not self.multilingual: + logging.info(f"Setting multilingual to true with multilingual model {model_size}.") self.multilingual = True - return model_size - - if not self.multilingual: - model_size = model_size + ".en" return model_size