able to use an English phonemizer instead of character-level utf8
lucidrains committed Aug 19, 2024
1 parent 17f7477 commit c2bb6ce
Showing 2 changed files with 41 additions and 9 deletions.
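In short: the constructor gains a `tokenizer` argument that accepts 'char_utf8' (the previous byte-level behavior), 'phoneme_en' (the new g2p_en phonemizer), or any compatible callable. A minimal usage sketch, assuming the patched `__init__` below belongs to `E2TTS`, this repo's main class:

    from e2_tts_pytorch import E2TTS

    model = E2TTS(tokenizer = 'phoneme_en')    # new: English phonemes via g2p_en
    # model = E2TTS(tokenizer = 'char_utf8')   # previous default: raw utf-8 bytes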
47 changes: 39 additions & 8 deletions e2_tts_pytorch/e2_tts.py
@@ -13,7 +13,7 @@
 from random import random

 import torch
-from torch import nn, from_numpy
+from torch import nn, tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Sequential, Linear
 from torch.nn.utils.rnn import pad_sequence
@@ -66,9 +66,28 @@ def list_str_to_tensor(
     padding_value = -1
 ) -> Int['b nt']:

-    list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text]
-    text = pad_sequence(list_tensors, padding_value = -1, batch_first = True)
-    return text
+    list_tensors = [tensor([*bytes(t, 'UTF-8')]) for t in text]
+    padded_tensor = pad_sequence(list_tensors, padding_value = -1, batch_first = True)
+    return padded_tensor
+
+# simple english phoneme-based tokenizer
+
+from g2p_en import G2p
+
+def get_g2p_en_encode():
+    g2p = G2p()
+
+    def encode(
+        text: List[str],
+        padding_value = -1
+    ) -> Int['b nt']:
+
+        phonemes = [g2p(t) for t in text]
+        list_tensors = [tensor([g2p.p2idx[p] for p in one_phoneme]) for one_phoneme in phonemes]
+        padded_tensor = pad_sequence(list_tensors, padding_value = -1, batch_first = True)
+        return padded_tensor
+
+    return encode

 # tensor helpers

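For contrast, a self-contained sketch of the two encoders defined above (the byte shape is exact; the phoneme tensor's width depends on g2p_en's output). Returning `encode` as a closure is a deliberate design choice: the G2p model is loaded once and reused across calls.

    texts = ['hello world', 'hi']

    byte_ids = list_str_to_tensor(texts)   # shape (2, 11): 'hello world' is 11 utf-8 bytes, 'hi' is padded with -1
    encode = get_g2p_en_encode()           # loads one shared G2p instance
    phoneme_ids = encode(texts)            # phoneme indices from g2p.p2idx, padded with -1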
@@ -152,7 +171,7 @@ def __init__(
             norm = norm,
         )

-        self.register_buffer('dummy', torch.tensor(0), persistent = False)
+        self.register_buffer('dummy', tensor(0), persistent = False)

     def forward(self, inp):
         if len(inp.shape) == 3:
@@ -530,7 +549,8 @@ def __init__(
         ),
         mel_spec_kwargs: dict = dict(),
         frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
-        immiscible = False
+        immiscible = False,
+        tokenizer: str | Callable[[List[str]], Int['b nt']] = 'char_utf8'
     ):
         super().__init__()

@@ -573,6 +593,17 @@ def __init__(
         self.cond_proj_in = Linear(num_channels, dim)
         self.to_pred = Linear(dim, num_channels)

+        # tokenizer
+
+        if callable(tokenizer):
+            self.tokenizer = tokenizer
+        elif tokenizer == 'char_utf8':
+            self.tokenizer = list_str_to_tensor
+        elif tokenizer == 'phoneme_en':
+            self.tokenizer = get_g2p_en_encode()
+        else:
+            raise ValueError(f'unknown tokenizer string {tokenizer}')
+
         # immiscible flow - https://arxiv.org/abs/2406.12303

         self.immiscible = immiscible
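Because `callable(tokenizer)` is checked before the string branches, any function mapping a list of strings to a padded integer tensor can be dropped in. A hypothetical example (`ascii_lower_encode` is illustrative, not part of this commit):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def ascii_lower_encode(texts):
        # hypothetical tokenizer: a-z only, indices 0..25, padded with -1 like the built-ins
        ids = [torch.tensor([ord(c) - ord('a') for c in t.lower() if 'a' <= c <= 'z']) for t in texts]
        return pad_sequence(ids, padding_value = -1, batch_first = True)

    # model = E2TTS(tokenizer = ascii_lower_encode)   # E2TTS assumed as in the sketch above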
@@ -650,7 +681,7 @@ def sample(
         # text

         if isinstance(text, list):
-            text = list_str_to_tensor(text).to(device)
+            text = self.tokenizer(text).to(device)
             assert text.shape[0] == batch

         if exists(text):
@@ -732,7 +763,7 @@ def forward(
         # handle text as string

         if isinstance(text, list):
-            text = list_str_to_tensor(text).to(device)
+            text = self.tokenizer(text).to(device)
             assert text.shape[0] == batch

         # lens and mask
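With both call sites switched to `self.tokenizer`, training (`forward`) and inference (`sample`) keep accepting raw strings, now routed through whichever tokenizer was configured. A hedged sketch (the audio shape and return value are assumptions, not confirmed by this diff):

    audio = torch.randn(2, 44100)                      # assumed raw-audio batch
    out = model(audio, text = ['hello world', 'hi'])   # strings are tokenized by self.tokenizer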
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.4.0"
+version = "0.4.1"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -29,6 +29,7 @@ dependencies = [
     'einx>=0.3.0',
     'ema-pytorch>=0.5.2',
     'gateloop-transformer>=0.2.2',
+    'g2p-en',
     'jaxtyping',
     'loguru',
     'scipy',
