From cb04833f8e6b48e599018ee44a175cde86fdc98d Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 20 Feb 2024 21:18:03 +0800
Subject: [PATCH] remove extra tokens

---
 egs/ljspeech/TTS/local/prepare_token_file.py | 19 +++-----------
 egs/ljspeech/TTS/prepare.sh                  |  6 +++--
 egs/ljspeech/TTS/vits/tokenizer.py           | 27 ++++++++++----------
 3 files changed, 20 insertions(+), 32 deletions(-)

diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py
index 29e4a50c9f..dd76c1565c 100755
--- a/egs/ljspeech/TTS/local/prepare_token_file.py
+++ b/egs/ljspeech/TTS/local/prepare_token_file.py
@@ -43,23 +43,10 @@ def get_args():
 
 def get_token2id(filename: Path) -> Dict[str, int]:
     """Get a dict that maps token to IDs, and save it to the given filename."""
-    extra_tokens = [
-        "<blk>",  # 0 for blank
-        "<sos>",  # 1 for sos
-        "<eos>",  # 2 for eos
-        "<unk>",  # 3 for OOV
-    ]
-
-    all_tokens = list(get_espeak_map().keys())
-
-    for t in extra_tokens:
-        assert t not in all_tokens, t
-
-    all_tokens = extra_tokens + all_tokens
-
+    all_tokens = get_espeak_map()
     with open(filename, "w", encoding="utf-8") as f:
-        for i, token in enumerate(all_tokens):
-            f.write(f"{token} {i}\n")
+        for token, token_id in all_tokens.items():
+            f.write(f"{token} {token_id[0]}\n")
 
 
 if __name__ == "__main__":
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index 890bc841f8..cbf27bd423 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -82,7 +82,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Prepare phoneme tokens for LJSpeech"
   # We assume you have installed piper_phonemize and espnet_tts_frontend.
   # If not, please install them with:
-  #   - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize
+  #   - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #     could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
   #   - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
     ./local/prepare_tokens_ljspeech.py
@@ -119,7 +120,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Generate token file"
   # We assume you have installed piper_phonemize and espnet_tts_frontend.
   # If not, please install them with:
-  #   - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize
+  #   - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #     could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
   #   - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/tokens.txt ]; then
     ./local/prepare_token_file.py --tokens data/tokens.txt
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index 64530fa335..e005fc1845 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -38,12 +38,15 @@ def __init__(self, tokens: str):
                     id = int(info[0])
                 else:
                     token, id = info[0], int(info[1])
+                assert token not in self.token2id, token
                 self.token2id[token] = id
 
-        self.blank_id = self.token2id["<blk>"]
-        self.sos_id = self.token2id["<sos>"]
-        self.eos_id = self.token2id["<eos>"]
-        self.oov_id = self.token2id["<unk>"]
+        # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
+        self.pad_id = self.token2id["_"]  # padding
+        self.sos_id = self.token2id["^"]  # beginning of an utterance (bos)
+        self.eos_id = self.token2id["$"]  # end of an utterance (eos)
+        self.space_id = self.token2id[" "]  # word separator (whitespace)
+
         self.vocab_size = len(self.token2id)
 
     def texts_to_token_ids(
@@ -80,13 +83,11 @@ def texts_to_token_ids(
 
             token_ids = []
             for t in tokens:
-                if t in self.token2id:
-                    token_ids.append(self.token2id[t])
-                else:
-                    token_ids.append(self.oov_id)
+                assert t in self.token2id, t
+                token_ids.append(self.token2id[t])
 
             if intersperse_blank:
-                token_ids = intersperse(token_ids, self.blank_id)
+                token_ids = intersperse(token_ids, self.pad_id)
             if add_sos:
                 token_ids = [self.sos_id] + token_ids
             if add_eos:
@@ -122,13 +123,11 @@ def tokens_to_token_ids(
         for tokens in tokens_list:
             token_ids = []
             for t in tokens:
-                if t in self.token2id:
-                    token_ids.append(self.token2id[t])
-                else:
-                    token_ids.append(self.oov_id)
+                assert t in self.token2id, t
+                token_ids.append(self.token2id[t])
 
             if intersperse_blank:
-                token_ids = intersperse(token_ids, self.blank_id)
+                token_ids = intersperse(token_ids, self.pad_id)
             if add_sos:
                 token_ids = [self.sos_id] + token_ids
             if add_eos:
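
The script below is not part of the patch; it is a minimal sanity-check sketch one could run after prepare.sh stage 5. It assumes the token file was written to data/tokens.txt (the path used in prepare.sh), parses it the same way vits/tokenizer.py does, and verifies that every ID matches piper_phonemize's get_espeak_map() and that the special tokens the updated Tokenizer relies on ("_", "^", "$", and the space token) are present.

#!/usr/bin/env python3
# Sanity-check sketch (not part of the patch): compare data/tokens.txt
# against piper_phonemize's espeak map and confirm the special tokens
# used by egs/ljspeech/TTS/vits/tokenizer.py are present.
from typing import Dict

from piper_phonemize import get_espeak_map


def load_tokens(filename: str) -> Dict[str, int]:
    """Parse tokens.txt the same way vits/tokenizer.py does."""
    token2id: Dict[str, int] = {}
    with open(filename, "r", encoding="utf-8") as f:
        for line in f.readlines():
            info = line.rstrip().split()
            if len(info) == 1:
                # A single field means the token itself is the space character.
                token, id = " ", int(info[0])
            else:
                token, id = info[0], int(info[1])
            assert token not in token2id, token
            token2id[token] = id
    return token2id


if __name__ == "__main__":
    token2id = load_tokens("data/tokens.txt")  # path assumed from prepare.sh
    espeak_map = get_espeak_map()  # maps token -> list of IDs

    # Every token should keep the ID assigned by piper_phonemize.
    assert len(token2id) == len(espeak_map), (len(token2id), len(espeak_map))
    for token, ids in espeak_map.items():
        assert token2id[token] == ids[0], (token, token2id[token], ids)

    # Special tokens used by the updated Tokenizer
    # (see https://github.com/rhasspy/piper/blob/master/TRAINING.md).
    for token in ["_", "^", "$", " "]:
        assert token in token2id, token

    print(f"OK: {len(token2id)} tokens")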