diff --git a/ovos_stt_plugin_citrinet/engine.py b/ovos_stt_plugin_citrinet/engine.py index f0e6cff..18333f2 100644 --- a/ovos_stt_plugin_citrinet/engine.py +++ b/ovos_stt_plugin_citrinet/engine.py @@ -35,10 +35,9 @@ import numpy as np import onnxruntime as ort import sentencepiece as spm -import soxr -import torch +import torch # TODO - try to drop dependency if we can convert preprocessor to onnx, currently not possible from huggingface_hub import hf_hub_download -from pydub import AudioSegment +from ovos_utils.log import LOG languages = { "en": { @@ -152,7 +151,6 @@ def _run_tokenizer(self, logits: np.array): @staticmethod def _ctc_decode(logits: np.array, blank_id: int): labels = logits.argmax(axis=1).tolist() - previous = blank_id decoded_prediction = [] for p in labels: @@ -172,19 +170,6 @@ def stt(self, audio_buffer: np.array, sr: int): self._trim_memory() return current_hypotheses - def stt_file(self, file_path: str): - audio_buffer, sr = self.read_file(file_path) - current_hypotheses = self.stt(audio_buffer, sr) - return current_hypotheses - - def read_file(self, file_path: str): - audio_file = AudioSegment.from_file(file_path) - sr = audio_file.frame_rate - - samples = audio_file.get_array_of_samples() - audio_buffer = np.array(samples) - return audio_buffer, sr - @staticmethod def _trim_memory(): """ @@ -195,10 +180,18 @@ def _trim_memory(): gc.collect() def _resample(self, audio_fp32: np.array, sr: int): + if sr == self.sample_rate: + return audio_fp32 + try: + import soxr + except ImportError: + LOG.error("Either provide audio at 16000 sample rate or install soxr for automatic resampling") + raise audio_16k = soxr.resample(audio_fp32, sr, self.sample_rate) return audio_16k - def _to_float32(self, audio_buffer: np.array): + @staticmethod + def _to_float32(audio_buffer: np.array): audio_fp32 = np.divide(audio_buffer, np.iinfo(audio_buffer.dtype).max, dtype=np.float32) return audio_fp32 diff --git a/requirements.txt b/requirements.txt index 7d706c8..8afa2bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,9 +5,6 @@ SpeechRecognition~=3.8 torch>=1.13.1 onnxruntime sentencepiece -# resampling -soxr -pydub # huggingface huggingface-hub numpy<2.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 592daac..39b97cd 100755 --- a/setup.py +++ b/setup.py @@ -62,6 +62,9 @@ def required(requirements_file): license='Apache-2.0', packages=['ovos_stt_plugin_citrinet'], install_requires=required("requirements.txt"), + extras_require={ + 'extra': ["soxr"] + }, zip_safe=True, classifiers=[ 'Development Status :: 3 - Alpha',