Skip to content

Commit

Permalink
PyThaiASR v1.1.0
Browse files Browse the repository at this point in the history
- remove tokenized
  • Loading branch information
wannaphong committed Aug 11, 2022
1 parent 30abe5f commit 636e67b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ print(asr(file))
### API

```python
asr(file: str, show_pad: bool = False, model: str = "airesearch/wav2vec2-large-xlsr-53-th")
asr(file: str, model: str = "airesearch/wav2vec2-large-xlsr-53-th")
```

- file: path of sound file
- show_pad: show [PAD] in output
- model: The ASR model
- return: thai text from ASR

Expand Down
14 changes: 5 additions & 9 deletions pythaiasr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, model: str="airesearch/wav2vec2-large-xlsr-53-th", device=Non
* *wannaphong/wav2vec2-large-xlsr-53-th-cv8-deepcut* - Thai Wav2Vec2 with CommonVoice V8 (deepcut tokenizer) + language model
"""
self.processor = AutoProcessor.from_pretrained(model)
self.model_name = model
self.model = AutoModelForCTC.from_pretrained(model)
if device!=None:
self.device = torch.device(device)
Expand All @@ -40,10 +41,9 @@ def prepare_dataset(self, batch: dict) -> dict:
batch["input_values"] = self.processor(batch["speech"], sampling_rate=batch["sampling_rate"]).input_values
return batch

def __call__(self, file: str, tokenized: bool = False) -> str:
def __call__(self, file: str) -> str:
"""
:param str file: path of sound file
:param bool show_pad: show [PAD] in output
:param str model: The ASR model
"""
b = {}
Expand All @@ -53,20 +53,16 @@ def __call__(self, file: str, tokenized: bool = False) -> str:
logits = self.model(input_dict.input_values).logits
pred_ids = torch.argmax(logits, dim=-1)[0]

if tokenized:
txt = self.processor.batch_decode(logits.detach().numpy()).text
else:
txt = self.processor.batch_decode(logits.detach().numpy()).text.replace(' ','')
txt = self.processor.batch_decode(logits.detach().numpy()).text[0]
return txt

_model_name = "airesearch/wav2vec2-large-xlsr-53-th"
_model = None


def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str:
def asr(file: str, model: str = _model_name) -> str:
"""
:param str file: path of sound file
:param bool show_pad: show [PAD] in output
:param str model: The ASR model
:return: thai text from ASR
:rtype: str
Expand All @@ -81,4 +77,4 @@ def asr(file: str, tokenized: bool = False, model: str = _model_name) -> str:
_model = ASR(model)
_model_name = model

return _model(file=file, tokenized=tokenized)
return _model(file=file)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def read(*paths):

setup(
name='pythaiasr',
version='1.0.1',
version='1.1.0',
packages=['pythaiasr'],
url='https://github.com/pythainlp/pythaiasr',
license='Apache Software License 2.0',
Expand Down

0 comments on commit 636e67b

Please sign in to comment.