Commit e7e1cfc
also give text transformer an associative scan based layer before attention, release new minor version
lucidrains committed Aug 28, 2024
1 parent b39fdbe commit e7e1cfc
Showing 2 changed files with 16 additions and 10 deletions.
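
For orientation before the diff: the new layer is wired in as a residual branch ahead of the text transformer's attention. A minimal sketch of that pattern, assuming SimpleGateLoopLayer is imported from lucidrains' gateloop-transformer package (the import itself is not part of this diff) and using an arbitrary dim_text of 256:

import torch
from gateloop_transformer import SimpleGateLoopLayer  # assumed source of the class

dim_text = 256  # arbitrary, for illustration only

# construction, as added to Transformer.__init__ below
text_gateloop = SimpleGateLoopLayer(dim = dim_text)

# residual application before the text attention, as added to Transformer.forward below
text_embed = torch.randn(2, 128, dim_text)  # (batch, text length, dim)
text_embed = text_gateloop(text_embed) + text_embed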
24 changes: 15 additions & 9 deletions e2_tts_pytorch/e2_tts.py
@@ -12,7 +12,7 @@
from random import random
from functools import partial
from collections import namedtuple
-from typing import Literal, List, Callable
+from typing import Literal, Callable

import torch
from torch import nn, tensor, from_numpy
@@ -68,7 +68,7 @@ def forward(self, x, **kwargs):
# simple utf-8 tokenizer, since paper went character based

def list_str_to_tensor(
-text: List[str],
+text: list[str],
padding_value = -1
) -> Int['b nt']:

@@ -84,7 +84,7 @@ def get_g2p_en_encode():
g2p = G2p()

def encode(
-text: List[str],
+text: list[str],
padding_value = -1
) -> Int['b nt']:

@@ -367,6 +367,8 @@ def __init__(

# text related

+text_gateloop = SimpleGateLoopLayer(dim = dim_text)
+
text_attn_norm = RMSNorm(dim_text)
text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs)

@@ -386,6 +388,7 @@ def __init__(
ff_norm,
ff,
ff_adaln_zero,
+text_gateloop,
text_attn_norm,
text_attn,
text_ff_norm,
@@ -461,6 +464,7 @@ def forward(
ff_norm,
ff,
maybe_ff_adaln_zero,
+text_gateloop,
text_attn_norm,
text_attn,
text_ff_norm,
@@ -473,6 +477,8 @@ def forward(
# smaller text transformer

if exists(text_embed):
+text_embed = text_gateloop(text_embed) + text_embed
+
text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed

text_embed = text_ff(text_ff_norm(text_embed)) + text_embed
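
The commit message's "associative scan based layer" refers to the gateloop family of layers, which compute a data-controlled linear recurrence. A toy sketch (not the library's implementation) of that recurrence, and of why it admits a parallel scan:

import torch

def linear_recurrence(gates, inputs):
    # computes h_t = gates_t * h_{t-1} + inputs_t, naively and sequentially here;
    # the step is associative, with combine (a1, x1) o (a2, x2) = (a1 * a2, a2 * x1 + x2),
    # so it can instead be evaluated with a parallel (associative) scan in O(log n) depth
    h = torch.zeros_like(inputs[:, 0])
    outs = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        outs.append(h)
    return torch.stack(outs, dim = 1)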
@@ -528,7 +534,7 @@ def __init__(
mel_spec_kwargs: dict = dict(),
char_embed_kwargs: dict = dict(),
text_num_embeds = None,
-tokenizer: str | Callable[[List[str]], Int['b nt']] = 'char_utf8'
+tokenizer: str | Callable[[list[str]], Int['b nt']] = 'char_utf8'
):
super().__init__()

@@ -577,7 +583,7 @@ def forward(
self,
x: Float['b n d'] | Float['b nw'],
*,
-text: Int['b nt'] | List[str] | None = None,
+text: Int['b nt'] | list[str] | None = None,
lens: Int['b'] | None = None,
return_loss = True
):
@@ -656,10 +662,10 @@ def __init__(
mel_spec_module: Module | None = None,
char_embed_kwargs: dict = dict(),
mel_spec_kwargs: dict = dict(),
-frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
+frac_lengths_mask: tuple[float, float] = (0.7, 1.),
immiscible = False,
text_num_embeds = None,
-tokenizer: str | Callable[[List[str]], Int['b nt']] = 'char_utf8'
+tokenizer: str | Callable[[list[str]], Int['b nt']] = 'char_utf8'
):
super().__init__()

@@ -786,7 +792,7 @@ def sample(
self,
cond: Float['b n d'] | Float['b nw'],
*,
-text: Int['b nt'] | List[str] | None = None,
+text: Int['b nt'] | list[str] | None = None,
lens: Int['b'] | None = None,
duration: int | Int['b'] | None = None,
steps = 32,
@@ -877,7 +883,7 @@ def forward(
self,
inp: Float['b n d'] | Float['b nw'], # mel or raw wave
*,
-text: Int['b nt'] | List[str] | None = None,
+text: Int['b nt'] | list[str] | None = None,
times: Int['b'] | None = None,
lens: Int['b'] | None = None,
):
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
-version = "0.6.3"
+version = "0.7.0"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }