
Commit

just improvise a variant with dual text and audio transformers, with cross text-audio conditioning every layer
lucidrains committed Aug 23, 2024
1 parent f98283e commit 6d51e28
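In short: this commit replaces the old single-stream text conditioning (CharacterEmbed concatenating text embeddings with the audio and predicting a gamma/beta modulation) with a second, narrower transformer over the character embeddings, and lets the two streams exchange information at every layer through a pair of zero-initialized linear projections. A minimal sketch of that per-layer exchange, with illustrative sizes (the real module is TextAudioCrossCondition in the diff below):

import torch
import torch.nn as nn

dim, dim_text = 512, 256                 # audio width; text width defaults to dim // 2

audio_to_text = nn.Linear(dim, dim_text, bias = False)
text_to_audio = nn.Linear(dim_text, dim, bias = False)

# zero init: training starts from two independent transformers,
# with the cross-talk learned gradually from identity
nn.init.zeros_(audio_to_text.weight)
nn.init.zeros_(text_to_audio.weight)

audio = torch.randn(2, 1024, dim)        # 'b n d'
text = torch.randn(2, 1024, dim_text)    # 'b n dt', padded / curtailed to the audio length

audio, text = audio + text_to_audio(text), text + audio_to_text(audio)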
Showing 2 changed files with 125 additions and 69 deletions.
192 changes: 124 additions & 68 deletions e2_tts_pytorch/e2_tts.py
@@ -5,6 +5,7 @@
nt - text sequence
nw - raw wave length
d - dimension
+dt - dimension text
"""

from __future__ import annotations
@@ -221,52 +222,46 @@ def __init__(
self,
dim,
num_embeds = 256,
-        cond_drop_prob = 0.,
-        num_gateloop_layers = 0
):
super().__init__()
self.dim = dim
-        self.cond_drop_prob = cond_drop_prob

self.embed = nn.Embedding(num_embeds + 1, dim) # will just use 0 as the 'filler token'

-        self.gateloops = ModuleList([Sequential(Linear(dim * 3, dim * 3, bias = False), SimpleGateLoopLayer(dim = dim * 3)) for _ in range(num_gateloop_layers)])
-
-        self.to_cond_gamma_beta = Linear(dim * 3, dim * 2)
-
-        nn.init.zeros_(self.to_cond_gamma_beta.weight)
-        nn.init.zeros_(self.to_cond_gamma_beta.bias)

def forward(
self,
-        x: Float['b n d'],
-        cond: Float['b n d'],
text: Int['b nt'],
-        drop_text_cond = None
+        max_seq_len: int,
) -> Float['b n d']:
-        drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)

-        if drop_text_cond:
-            return x
-
-        max_seq_len = x.shape[1]
-
text = text + 1 # shift all other token ids up by 1 and use 0 as filler token

text = text[:, :max_seq_len] # just curtail if character tokens are more than the mel spec tokens, one of the edge cases the paper did not address
text = F.pad(text, (0, max_seq_len - text.shape[1]), value = 0)

-        text_embed = self.embed(text)
+        return self.embed(text)

-        concatted = torch.cat((x, cond, text_embed), dim = -1)
-
-        for gateloop in self.gateloops:
-            concatted = gateloop(concatted) + concatted
-
-        assert x.shape[-1] == text_embed.shape[-1] == self.dim, f'expected {self.dim} but received ({x.shape[-1]}, {text_embed.shape[-1]})'
-
-        gamma, beta = self.to_cond_gamma_beta(concatted).chunk(2, dim = -1)
-        return x * (gamma + 1.) + beta

+class TextAudioCrossCondition(Module):
+    def __init__(
+        self,
+        dim,
+        dim_text,
+    ):
+        super().__init__()
+        self.audio_to_text = nn.Linear(dim, dim_text, bias = False)
+        self.text_to_audio = nn.Linear(dim_text, dim, bias = False)
+
+        nn.init.zeros_(self.audio_to_text.weight)
+        nn.init.zeros_(self.text_to_audio.weight)
+
+    def forward(
+        self,
+        audio: Float['b n d'],
+        text: Float['b n dt']
+    ):
+        text_cond = self.text_to_audio(text)
+        audio_cond = self.audio_to_text(audio)
+
+        return audio + text_cond, text + audio_cond

# attention and transformer backbone
# for use in both e2tts as well as duration module
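For orientation, one layer of the new dual-stream Transformer below wires up roughly as in this sketch: the smaller text block runs first, the two streams cross-condition, then the audio block proceeds. Stand-in modules (nn.MultiheadAttention, a plain MLP) replace the repo's Attention and FeedForward, and norms, registers, skip connections, the gateloop and time conditioning are omitted; only the wiring is taken from the diff.

import torch
import torch.nn as nn

dim, dim_text, heads = 512, 256, 8

text_attn = nn.MultiheadAttention(dim_text, heads, batch_first = True)
text_ff = nn.Sequential(nn.Linear(dim_text, dim_text * 4), nn.GELU(), nn.Linear(dim_text * 4, dim_text))

audio_attn = nn.MultiheadAttention(dim, heads, batch_first = True)
audio_ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

audio_to_text = nn.Linear(dim, dim_text, bias = False)
text_to_audio = nn.Linear(dim_text, dim, bias = False)
nn.init.zeros_(audio_to_text.weight)
nn.init.zeros_(text_to_audio.weight)

x = torch.randn(2, 1024, dim)            # audio stream
text = torch.randn(2, 1024, dim_text)    # text stream, same sequence length

# smaller text transformer block
text = text_attn(text, text, text)[0] + text
text = text_ff(text) + text

# cross condition, both directions
x, text = x + text_to_audio(text), text + audio_to_text(x)

# audio transformer block
x = audio_attn(x, x, x)[0] + x
x = audio_ff(x) + x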
@@ -276,14 +271,14 @@ def __init__(
self,
*,
dim,
+        dim_text = None, # will default to half of audio dimension
depth = 8,
+        heads = 8,
+        dim_head = 64,
cond_on_time = True,
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
abs_pos_emb = True,
max_seq_len = 8192,
-        heads = 8,
-        dim_head = 64,
-        num_gateloop_layers = 1,
dropout = 0.1,
num_registers = 32,
attn_kwargs: dict = dict(
@@ -301,6 +296,10 @@ def __init__(
self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None

self.dim = dim

+        dim_text = default(dim_text, dim // 2)
+        self.dim_text = dim_text

self.skip_connect_type = skip_connect_type
needs_skip_proj = skip_connect_type == 'concat'

@@ -313,13 +312,13 @@ def __init__(
self.registers = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(self.registers, std = 0.02)

+        self.text_registers = nn.Parameter(torch.zeros(num_registers, dim_text))
+        nn.init.normal_(self.text_registers, std = 0.02)
+
# rotary embedding

self.rotary_emb = RotaryEmbedding(dim_head)

-        # gateloops
-
-        self.gateloops = ModuleList([SimpleGateLoopLayer(dim = dim) for _ in range(num_gateloop_layers)])
+        self.text_rotary_emb = RotaryEmbedding(dim_head)

# time conditioning
# will use adaptive rmsnorm
@@ -337,11 +336,13 @@ def __init__(
nn.SiLU()
)

-        self.to_time_token = nn.Linear(dim, dim, bias = False)
-
for ind in range(depth):
is_later_half = ind >= (depth // 2)

+            # speech related
+
+            gateloop = SimpleGateLoopLayer(dim = dim)
+
attn_norm = rmsnorm_klass(dim)
attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
attn_adaln_zero = postbranch_klass()
@@ -352,14 +353,32 @@ def __init__(

skip_proj = Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None

+            # text related
+
+            text_attn_norm = RMSNorm(dim_text)
+            text_attn = Attention(dim = dim_text, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)
+
+            text_ff_norm = RMSNorm(dim_text)
+            text_ff = FeedForward(dim = dim_text, glu = True, dropout = dropout, **ff_kwargs)
+
+            # cross condition
+
+            cross_condition = TextAudioCrossCondition(dim = dim, dim_text = dim_text)
+
self.layers.append(ModuleList([
+            gateloop,
skip_proj,
attn_norm,
attn,
attn_adaln_zero,
ff_norm,
ff,
-            ff_adaln_zero
+            ff_adaln_zero,
+            text_attn_norm,
+            text_attn,
+            text_ff_norm,
+            text_ff,
+            cross_condition
]))

self.final_norm = RMSNorm(dim)
@@ -368,17 +387,13 @@ def forward(
self,
x: Float['b n d'],
times: Float['b'] | Float[''] | None = None,
-        mask: Bool['b n'] | None = None
+        mask: Bool['b n'] | None = None,
+        text_embed: Float['b n dt'] | None = None,
):
batch, seq_len, device = *x.shape[:2], x.device

assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'

-        # gateloop layers
-
-        for gateloop in self.gateloops:
-            x = gateloop(x) + x
-
# handle absolute positions if needed

if exists(self.abs_pos_emb):
@@ -397,14 +412,6 @@ def forward(
times = self.time_cond_mlp(times)
norm_kwargs.update(condition = times)

-        # u-vit paper claims using a time token helps better condition https://arxiv.org/abs/2209.12152
-
-        time_token = self.to_time_token(times)
-        x, time_packed_shape = pack((time_token, x), 'b * d')
-
-        if exists(mask):
-            mask = F.pad(mask, (1, 0), value = True)
-
# register tokens

registers = repeat(self.registers, 'r d -> b r d', b = batch)
@@ -417,6 +424,14 @@

rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])

+        # text related
+
+        if exists(text_embed):
+            text_rotary_pos_emb = self.text_rotary_emb.forward_from_seq_len(x.shape[-2])
+
+            text_registers = repeat(self.text_registers, 'r d -> b r d', b = batch)
+            text_embed, _ = pack((text_registers, text_embed), 'b * d')

# skip connection related stuff

skip_connect_type = self.skip_connect_type
@@ -425,9 +440,33 @@

# go through the layers

-        for ind, (maybe_skip_proj, attn_norm, attn, maybe_attn_adaln_zero, ff_norm, ff, maybe_ff_adaln_zero) in enumerate(self.layers):
+        for ind, (
+            gateloop,
+            maybe_skip_proj,
+            attn_norm,
+            attn,
+            maybe_attn_adaln_zero,
+            ff_norm,
+            ff,
+            maybe_ff_adaln_zero,
+            text_attn_norm,
+            text_attn,
+            text_ff_norm,
+            text_ff,
+            cross_condition
+        ) in enumerate(self.layers):

layer = ind + 1

+            # smaller text transformer
+
+            if exists(text_embed):
+                text_embed = text_attn(text_attn_norm(text_embed), rotary_pos_emb = text_rotary_pos_emb, mask = mask) + text_embed
+
+                text_embed = text_ff(text_ff_norm(text_embed)) + text_embed
+
+                x, text_embed = cross_condition(x, text_embed)

# skip connection logic

is_first_half = layer <= (self.depth // 2)
@@ -447,6 +486,10 @@ def forward(
# additive
x = x + skip

+            # associative scan
+
+            x = gateloop(x) + x
+
# attention and feedforward blocks

attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb = rotary_pos_emb, mask = mask)
@@ -461,9 +504,6 @@ def forward(

_, x = unpack(x, registers_packed_shape, 'b * d')

-        if exists(times):
-            _, x = unpack(x, time_packed_shape, 'b * d')
-
return self.final_norm(x)

# main classes
Expand All @@ -474,9 +514,7 @@ def __init__(
transformer: dict | Transformer,
num_channels = None,
mel_spec_kwargs: dict = dict(),
-        char_embed_kwargs: dict = dict(
-            num_gateloop_layers = 2
-        ),
+        char_embed_kwargs: dict = dict(),
text_num_embeds = None,
tokenizer: str | Callable[[List[str]], Int['b nt']] = 'char_utf8'
):
@@ -513,7 +551,7 @@ def __init__(
else:
raise ValueError(f'unknown tokenizer string {tokenizer}')

-        self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, **char_embed_kwargs)
+        self.embed_text = CharacterEmbed(transformer.dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

# to prediction

Expand Down Expand Up @@ -544,12 +582,14 @@ def forward(

# text

+        text_embed = None
+
if exists(text):
if isinstance(text, list):
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch

-            x = self.embed_text(x, x, text)
+            text_embed = self.embed_text(text, seq_len)

# handle lengths (duration)

@@ -569,7 +609,11 @@ def forward(

# attending

-        x = self.transformer(x, mask = mask)
+        x = self.transformer(
+            x,
+            mask = mask,
+            text_embed = text_embed,
+        )

x = maybe_masked_mean(x, mask)

@@ -598,9 +642,7 @@ def __init__(
cond_drop_prob = 0.25,
num_channels = None,
mel_spec_module: Module | None = None,
-        char_embed_kwargs: dict = dict(
-            num_gateloop_layers = 2
-        ),
+        char_embed_kwargs: dict = dict(),
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
immiscible = False,
@@ -621,7 +663,10 @@ def __init__(
self.transformer = transformer

dim = transformer.dim
+        dim_text = transformer.dim_text

self.dim = dim
+        self.dim_text = dim_text

self.frac_lengths_mask = frac_lengths_mask

@@ -660,7 +705,9 @@ def __init__(
else:
raise ValueError(f'unknown tokenizer string {tokenizer}')

-        self.embed_text = CharacterEmbed(dim, num_embeds = text_num_embeds, cond_drop_prob = cond_drop_prob, **char_embed_kwargs)
+        self.cond_drop_prob = cond_drop_prob
+
+        self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

# immiscible flow - https://arxiv.org/abs/2406.12303

@@ -679,16 +726,25 @@ def transformer_with_pred_head(
text: Int['b nt'] | None = None,
drop_text_cond: bool | None = None
):
+        seq_len = x.shape[-2]
+        drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
+
x = self.proj_in(x)
cond = self.cond_proj_in(cond)

-        if exists(text):
-            x = self.embed_text(x, cond, text, drop_text_cond = drop_text_cond)
+        # whether to use a text embedding
+
+        text_embed = None
+        if exists(text) and not drop_text_cond:
+            text_embed = self.embed_text(text, seq_len)

+        # attend
+
attended = self.transformer(
x,
times = times,
-            mask = mask
+            mask = mask,
+            text_embed = text_embed
)

return self.to_pred(attended)
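Finally, a hypothetical wiring of the new pieces in isolation, mirroring how DurationPredictor and E2TTS call them after this commit; sizes and hyperparameters here are illustrative, not the repo's defaults:

import torch
from e2_tts_pytorch.e2_tts import Transformer, CharacterEmbed

transformer = Transformer(dim = 512, depth = 8, cond_on_time = False)
embed_text = CharacterEmbed(transformer.dim_text, num_embeds = 256)

x = torch.randn(1, 1024, 512)                 # audio features 'b n d'
text = torch.randint(0, 256, (1, 16))         # character ids 'b nt'

text_embed = embed_text(text, max_seq_len = x.shape[1])   # 'b n dt'
out = transformer(x, text_embed = text_embed)             # 'b n d'

Note that CharacterEmbed pads or curtails the character sequence to the audio length, so the two streams stay aligned frame for frame before the per-layer cross conditioning.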