Commit 24dc805
make sure phonemizer can be used for duration predictor, and set the correct number of phoneme embeds
lucidrains committed Aug 19, 2024
1 parent c2bb6ce commit 24dc805
Showing 2 changed files with 27 additions and 7 deletions.
32 changes: 26 additions & 6 deletions e2_tts_pytorch/e2_tts.py
@@ -434,12 +434,13 @@ class DurationPredictor(Module):
     def __init__(
         self,
         transformer: dict | Transformer,
-        text_num_embeds = 256,
         num_channels = None,
         mel_spec_kwargs: dict = dict(),
         char_embed_kwargs: dict = dict(
             num_gateloop_layers = 2
-        )
+        ),
+        text_num_embeds = None,
+        tokenizer: str | Callable[[List[str]], Int['b nt']] = 'char_utf8'
     ):
         super().__init__()
 
@@ -460,8 +461,24 @@ def __init__(
 
         self.proj_in = Linear(self.num_channels, self.dim)
 
+        # tokenizer and text embed
+
+        if callable(tokenizer):
+            assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
+            self.tokenizer = tokenizer
+        elif tokenizer == 'char_utf8':
+            text_num_embeds = 256
+            self.tokenizer = list_str_to_tensor
+        elif tokenizer == 'phoneme_en':
+            text_num_embeds = 74
+            self.tokenizer = get_g2p_en_encode()
+        else:
+            raise ValueError(f'unknown tokenizer string {tokenizer}')
+
         self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, **char_embed_kwargs)
 
+        # to prediction
+
         self.to_pred = Sequential(
             Linear(dim, 1, bias = False),
             nn.Softplus(),
@@ -540,7 +557,6 @@ def __init__(
             rtol = 1e-5,
             method = 'midpoint'
         ),
-        text_num_embeds = 256,
         cond_drop_prob = 0.25,
         num_channels = None,
         mel_spec_module: Module | None = None,
@@ -550,6 +566,7 @@ def __init__(
         mel_spec_kwargs: dict = dict(),
         frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
         immiscible = False,
+        text_num_embeds = None,
         tokenizer: str | Callable[[List[str]], Int['b nt']] = 'char_utf8'
     ):
         super().__init__()
@@ -570,8 +587,6 @@ def __init__(
 
         self.frac_lengths_mask = frac_lengths_mask
 
-        self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob, **char_embed_kwargs)
-
         self.duration_predictor = duration_predictor
 
         # conditional flow related
@@ -593,17 +608,22 @@ def __init__(
         self.cond_proj_in = Linear(num_channels, dim)
         self.to_pred = Linear(dim, num_channels)
 
-        # tokenizer
+        # tokenizer and text embed
 
         if callable(tokenizer):
+            assert exists(text_num_embeds), '`text_num_embeds` must be given if supplying your own tokenizer encode function'
             self.tokenizer = tokenizer
         elif tokenizer == 'char_utf8':
+            text_num_embeds = 256
             self.tokenizer = list_str_to_tensor
         elif tokenizer == 'phoneme_en':
+            text_num_embeds = 74
             self.tokenizer = get_g2p_en_encode()
         else:
             raise ValueError(f'unknown tokenizer string {tokenizer}')
 
+        self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob, **char_embed_kwargs)
+
         # immiscible flow - https://arxiv.org/abs/2406.12303
 
         self.immiscible = immiscible
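For context, a minimal usage sketch of the change above (not part of the commit; the package-level imports and the transformer kwargs are assumptions based on the project README). Passing tokenizer = 'phoneme_en' routes text through the g2p_en phonemizer and sizes the text embedding table to 74 phoneme ids internally, so text_num_embeds no longer needs to be supplied:

from e2_tts_pytorch import DurationPredictor, E2TTS

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 512,
        depth = 6
    ),
    tokenizer = 'phoneme_en'    # text_num_embeds is set to 74 internally
)

e2tts = E2TTS(
    duration_predictor = duration_predictor,
    transformer = dict(
        dim = 512,
        depth = 12
    ),
    tokenizer = 'phoneme_en'    # keep consistent with the duration predictor
)
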
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.4.1"
+version = "0.4.2"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
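Finally, a tokenizer can still be supplied as your own callable taking a list of strings and returning an integer tensor of shape (batch, text length); the new assertion in both constructors then requires text_num_embeds to be passed explicitly. A hypothetical sketch (the vocabulary, padding value, and encode function below are illustrative only, not part of the library):

import torch
from e2_tts_pytorch import DurationPredictor

vocab = {ch: i for i, ch in enumerate('abcdefghijklmnopqrstuvwxyz ')}  # illustrative vocabulary

def encode(texts):
    # List[str] -> integer tensor of shape (batch, max text length)
    ids = [[vocab[ch] for ch in text.lower() if ch in vocab] for text in texts]
    max_len = max(len(seq) for seq in ids)
    padded = [seq + [-1] * (max_len - len(seq)) for seq in ids]  # -1 padding assumed here
    return torch.tensor(padded, dtype = torch.long)

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 512,
        depth = 6
    ),
    tokenizer = encode,
    text_num_embeds = len(vocab)    # required when passing a tokenizer callable
)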