Skip to content

Commit

Permalink
ENH: add model config for Whisper (#2755)
Browse files Browse the repository at this point in the history
Co-authored-by: codingl2k1 <[email protected]>
  • Loading branch information
fonsc and codingl2k1 authored Jan 14, 2025
1 parent 7c6249a commit d0dff35
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions xinference/model/audio/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.
import logging
import os
import typing
from glob import glob
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from typing_extensions import TypedDict

from ...device_utils import (
get_available_device,
get_device_preferred_dtype,
Expand All @@ -28,21 +31,43 @@
logger = logging.getLogger(__name__)


class WhisperModelConfig(TypedDict, total=False):
chunk_length_s: Optional[float]
stride_length_s: Optional[float]
return_timestamps: Optional[bool]
batch_size: Optional[int]


class WhisperModel:
def __init__(
self,
model_uid: str,
model_path: str,
model_spec: "AudioModelFamilyV1",
device: Optional[str] = None,
max_new_tokens: Optional[int] = 128,
**kwargs,
):
self._model_uid = model_uid
self._model_path = model_path
self._model_spec = model_spec
self._device = device
self._model = None
self._kwargs = kwargs
self._max_new_tokens = max_new_tokens
self._model_config: WhisperModelConfig = self._sanitize_model_config(
typing.cast(WhisperModelConfig, kwargs)
)

def _sanitize_model_config(
self, model_config: Optional[WhisperModelConfig]
) -> WhisperModelConfig:
if model_config is None:
model_config = WhisperModelConfig()
model_config.setdefault("chunk_length_s", 30)
model_config.setdefault("stride_length_s", None)
model_config.setdefault("return_timestamps", False)
model_config.setdefault("batch_size", 16)
return model_config

@property
def model_ability(self):
Expand Down Expand Up @@ -75,10 +100,10 @@ def load(self):
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=False,
chunk_length_s=self._model_config.get("chunk_length_s"),
stride_length_s=self._model_config.get("stride_length_s"),
return_timestamps=self._model_config.get("return_timestamps"),
batch_size=self._model_config.get("batch_size"),
torch_dtype=torch_dtype,
device=self._device,
)
Expand Down Expand Up @@ -185,13 +210,13 @@ def transcriptions(
logger.warning(
"Prompt for whisper transcriptions will be ignored: %s", prompt
)
generate_kwargs = {"max_new_tokens": self._max_new_tokens, "task": "transcribe"}
if language is not None:
generate_kwargs["language"] = language

return self._call_model(
audio=audio,
generate_kwargs=(
{"language": language, "task": "transcribe"}
if language is not None
else {"task": "transcribe"}
),
generate_kwargs=generate_kwargs,
response_format=response_format,
temperature=temperature,
timestamp_granularities=timestamp_granularities,
Expand Down

0 comments on commit d0dff35

Please sign in to comment.