Skip to content

Commit

Permalink
Merge pull request #102 from makaveli10/change_model_size_param_name
Browse files Browse the repository at this point in the history
Server to control custom model usage.
  • Loading branch information
zoq authored Jan 18, 2024
2 parents 0c01d7b + 881fd55 commit 0942dc2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
14 changes: 5 additions & 9 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def __init__(
is_multilingual=False,
lang=None,
translate=False,
model_size="small",
use_custom_model=False
model="small",
):
"""
Initializes a Client instance for audio recording and streaming to a server.
Expand Down Expand Up @@ -88,9 +87,8 @@ def __init__(
self.disconnect_if_no_response_for = 15
self.multilingual = is_multilingual
self.language = lang
self.model_size = model_size
self.model = model
self.server_error = False
self.use_custom_model = use_custom_model

if translate:
self.task = "translate"
Expand Down Expand Up @@ -229,8 +227,7 @@ def on_open(self, ws):
"multilingual": self.multilingual,
"language": self.language,
"task": self.task,
"model_size": self.model_size,
"use_custom_model": self.use_custom_model # if runnning your own server with a custom model
"model": self.model,
}
)
)
Expand Down Expand Up @@ -521,10 +518,9 @@ def __init__(self,
is_multilingual=False,
lang=None,
translate=False,
model_size="small",
use_custom_model=False
model="small",
):
self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model)
self.client = Client(host, port, is_multilingual, lang, translate, model)

def __call__(self, audio=None, hls_url=None):
"""
Expand Down
37 changes: 19 additions & 18 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import textwrap

import logging
# logging.basicConfig(level = logging.INFO)
logging.basicConfig(level = logging.INFO)

from websockets.sync.server import serve

Expand Down Expand Up @@ -100,20 +100,19 @@ def recv_audio(self, websocket, custom_model_path=None):
return

# validate custom model
if options["use_custom_model"]:
if custom_model_path is None or not os.path.exists(custom_model_path):
options["use_custom_model"] = False
if custom_model_path is not None and os.path.exists(custom_model_path):
logging.info(f"Using custom model {custom_model_path}")
options["model"] = custom_model_path

client = ServeClient(
websocket,
multilingual=options["multilingual"],
language=options["language"],
task=options["task"],
client_uid=options["uid"],
model_size_or_path=custom_model_path if options["use_custom_model"] else options["model_size"],
model=options["model"],
initial_prompt=options.get("initial_prompt"),
vad_parameters=options.get("vad_parameters"),
use_custom_model=options["use_custom_model"]
)

self.clients[websocket] = client
Expand Down Expand Up @@ -206,10 +205,9 @@ def __init__(
multilingual=False,
language=None,
client_uid=None,
model_size_or_path="small",
model="small",
initial_prompt=None,
vad_parameters=None,
use_custom_model=False
):
"""
Initialize a ServeClient instance.
Expand All @@ -230,13 +228,15 @@ def __init__(
self.data = b""
self.frames = b""
self.model_sizes = [
"tiny", "base", "small", "medium", "large-v2", "large-v3"
"tiny", "tiny.en", "base", "base.en", "small", "small.en",
"medium", "medium.en", "large-v2", "large-v3",
]

self.multilingual = multilingual
if not use_custom_model:
self.model_size_or_path = self.get_model_size(model_size_or_path)
if not os.path.exists(model):
self.model_size_or_path = self.get_model_size(model)
else:
self.model_size_or_path = model_size_or_path
self.model_size_or_path = model

self.language = language if self.multilingual else "en"
self.task = task
Expand All @@ -246,7 +246,7 @@ def __init__(

device = "cuda" if torch.cuda.is_available() else "cpu"

if self.model_size_or_path == None:
if self.model_size_or_path is None:
return

self.transcriber = WhisperModel(
Expand Down Expand Up @@ -302,12 +302,13 @@ def get_model_size(self, model_size):
)
return None

if model_size in ["large-v2", "large-v3"]:
if model_size.endswith("en") and self.multilingual:
logging.info(f"Setting multilingual to false with {model_size} which is english only model.")
self.multilingual = False

if not model_size.endswith("en") and not self.multilingual:
logging.info(f"Setting multilingual to true with multilingual model {model_size}.")
self.multilingual = True
return model_size

if not self.multilingual:
model_size = model_size + ".en"

return model_size

Expand Down

0 comments on commit 0942dc2

Please sign in to comment.