From 636e67b34729983b2d90591c01f2e820fe3955ae Mon Sep 17 00:00:00 2001 From: Wannaphong Phatthiyaphaibun Date: Thu, 11 Aug 2022 16:43:50 +0700 Subject: [PATCH] PyThaiASR v1.1.0 - remove tokenized --- README.md | 3 +-- pythaiasr/__init__.py | 14 +++++--------- setup.py | 2 +- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 32a259f..c13358a 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,10 @@ print(asr(file)) ### API ```python -asr(file: str, show_pad: bool = False, model: str = "airesearch/wav2vec2-large-xlsr-53-th") +asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th") ``` - file: path of sound file -- show_pad: show [PAD] in output - model: The ASR model - return: thai text from ASR diff --git a/pythaiasr/__init__.py b/pythaiasr/__init__.py index fce4242..9bac88c 100644 --- a/pythaiasr/__init__.py +++ b/pythaiasr/__init__.py @@ -19,6 +19,7 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=Non * *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model """ self.processor = AutoProcessor.from_pretrained(model) + self.model_name = model self.model = AutoModelForCTC.from_pretrained(model) if device!=None: self.device = torch.device(device) @@ -40,10 +41,9 @@ def prepare_dataset(self, batch: dict) -> dict: batch["input_values"] = self.processor(batch["speech"], sampling_rate=batch["sampling_rate"]).input_values return batch - def __call__(self, file: str, tokenized: bool = False) -> str: + def __call__(self, file: str) -> str: """ :param str file: path of sound file - :param bool show_pad: show [PAD] in output :param str model: The ASR model """ b = {} @@ -53,20 +53,16 @@ def __call__(self, file: str, tokenized: bool = False) -> str: logits = self.model(input_dict.input_values).logits pred_ids = torch.argmax(logits, dim=-1)[0] - if tokenized: - txt = self.processor.batch_decode(logits.detach().numpy()).text - else: - txt = self.processor.batch_decode(logits.detach().numpy()).text.replace(' ','') + txt = self.processor.batch_decode(logits.detach().numpy()).text[0] return txt _model_name = "airesearch/wav2vec2-large-xlsr-53-th" _model = None -def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str: +def asr(file: str, model: str = _model_name) -> str: """ :param str file: path of sound file - :param bool show_pad: show [PAD] in output :param str model: The ASR model :return: thai text from ASR :rtype: str @@ -81,4 +77,4 @@ def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str: _model = ASR(model) _model_name = model - return _model(file=file, tokenized=tokenized) + return _model(file=file) diff --git a/setup.py b/setup.py index db394ea..50e533e 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def read(*paths): setup( name='pythaiasr', - version='1.0.1', + version='1.1.0', packages=['pythaiasr'], url='https://github.com/pythainlp/pythaiasr', license='Apache Software License 2.0',