add ability to use velocity consistency as a solution for straightening the flow
lucidrains committed Oct 13, 2024
1 parent c3ca775 commit 9599691
Showing 4 changed files with 102 additions and 23 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -124,3 +124,14 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:218674528}
}
```

+```bibtex
+@article{Yang2024ConsistencyFM,
+title = {Consistency Flow Matching: Defining Straight Flows with Velocity Consistency},
+author = {Ling Yang and Zixiang Zhang and Zhilong Zhang and Xingchao Liu and Minkai Xu and Wentao Zhang and Chenlin Meng and Stefano Ermon and Bin Cui},
+journal = {ArXiv},
+year = {2024},
+volume = {abs/2407.02398},
+url = {https://api.semanticscholar.org/CorpusID:270878436}
+}
+```
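The citation above is the basis for this commit: flows are straightened by penalizing the model when its predicted velocity at time `t` drifts from a slightly time-shifted prediction made by a stabilized (EMA) copy of itself. A minimal sketch of that objective, independent of this repository; `model` and `ema_model` are hypothetical stand-ins for any velocity predictor called as `f(x, t)`:

```python
import torch
import torch.nn.functional as F

def velocity_consistency_loss(model, ema_model, x0, x1, t, delta = 1e-5):
    # keep t + delta inside [0, 1]
    t = t * (1. - delta)

    # two nearby points on the same straight noise -> data path
    w       = (1. - t) * x0 + t * x1
    w_delta = (1. - (t + delta)) * x0 + (t + delta) * x1

    pred = model(w, t)

    with torch.no_grad():
        # frozen target from the EMA copy at the shifted time
        target = ema_model(w_delta, t + delta)

    return F.mse_loss(pred, target)
```

This is the simplified squared-error form the commit adopts; see the paper for the full objective.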
78 changes: 69 additions & 9 deletions e2_tts_pytorch/e2_tts.py
@@ -28,12 +28,12 @@
from torch.nn.utils.rnn import pad_sequence

import torchaudio
-from torchaudio.functional import DB_to_amplitude, resample
+from torchaudio.functional import DB_to_amplitude
from torchdiffeq import odeint

import einx
from einops.layers.torch import Rearrange
-from einops import einsum, rearrange, repeat, reduce, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack

from x_transformers import (
Attention,
@@ -61,7 +61,11 @@ def __getitem__(self, shapes: str):
Int = TorchTyping(jaxtyping.Int)
Bool = TorchTyping(jaxtyping.Bool)

-E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data'])
+# named tuples
+
+LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency'])
+
+E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown'])

# helpers

@@ -393,7 +397,6 @@ def forward(
) -> Float['b n d']:

device = text.device
-seq = torch.arange(max_seq_len, device = device)

mask = default(mask, (None,))

@@ -895,7 +898,8 @@ def __init__(
) = 'char_utf8',
use_vocos = True,
pretrained_vocos_path = 'charactr/vocos-mel-24khz',
-sampling_rate: int | None = None
+sampling_rate: int | None = None,
+velocity_consistency_weight = 0.,
):
super().__init__()

@@ -967,6 +971,11 @@ def __init__(

self.embed_text = text_embed_klass(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

+# weight for velocity consistency
+
+self.register_buffer('zero', torch.tensor(0.), persistent = False)
+self.velocity_consistency_weight = velocity_consistency_weight

# default vocos for mel -> audio

self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None
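One small idiom in the hunk above: registering `zero` as a non-persistent buffer gives the optional consistency loss a device- and dtype-correct default without writing anything into the checkpoint. A standalone sketch of the pattern (hypothetical module):

```python
import torch
from torch import nn

class WithOptionalLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # follows the module across .to(device) / .cuda(), but is
        # excluded from the state dict by persistent = False
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def forward(self, x, compute_aux_loss = False):
        aux_loss = self.zero  # safe default on the correct device

        if compute_aux_loss:
            aux_loss = x.pow(2).mean()

        return aux_loss
```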
@@ -982,7 +991,8 @@ def transformer_with_pred_head
times: Float['b'],
mask: Bool['b n'] | None = None,
text: Int['b nt'] | None = None,
-drop_text_cond: bool | None = None
+drop_text_cond: bool | None = None,
+return_drop_text_cond = False
):
seq_len = x.shape[-2]
drop_text_cond = default(drop_text_cond, self.training and random() < self.cond_drop_prob)
@@ -1015,7 +1025,12 @@ def transformer_with_pred_head
text_embed = text_embed
)

-return self.to_pred(attended)
+pred = self.to_pred(attended)
+
+if not return_drop_text_cond:
+    return pred
+
+return pred, drop_text_cond

def cfg_transformer_with_pred_head(
self,
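The new `return_drop_text_cond` flag exists so the caller can recover which conditioning-dropout decision was sampled: the online pass and the EMA consistency pass must drop the text conditioning identically, otherwise the consistency MSE would compare a conditional prediction against an unconditional target. The pairing in isolation (`f` and `ema_f` are hypothetical predictors that accept the flag):

```python
import torch
from random import random

def paired_predictions(f, ema_f, w, w_delta, cond_drop_prob = 0.2):
    # sample the classifier-free-guidance dropout once ...
    drop_text_cond = random() < cond_drop_prob

    pred = f(w, drop_text_cond = drop_text_cond)

    # ... and replay the identical decision for the consistency target
    with torch.no_grad():
        target = ema_f(w_delta, drop_text_cond = drop_text_cond)

    return pred, target
```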
@@ -1183,7 +1198,11 @@ def forward(
text: Int['b nt'] | list[str] | None = None,
times: Int['b'] | None = None,
lens: Int['b'] | None = None,
+velocity_consistency_model: E2TTS | None = None,
+velocity_consistency_delta = 1e-5
):
+need_velocity_loss = exists(velocity_consistency_model) and self.velocity_consistency_weight > 0.

# handle raw wave

if inp.ndim == 2:
@@ -1230,6 +1249,11 @@
times = torch.rand((batch,), dtype = dtype, device = self.device)
t = rearrange(times, 'b -> b 1 1')

+# if need velocity consistency, make sure time does not exceed 1.
+
+if need_velocity_loss:
+    t = t * (1. - velocity_consistency_delta)

# sample xt (w in the paper)

w = (1. - t) * x0 + t * x1
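For reference, the rectified flow bookkeeping above: `w` interpolates linearly between noise `x0` and data `x1`, the regression target is the constant velocity `x1 - x0`, and shifting time by `delta` moves `w` by exactly `delta` times that velocity, which is what lets the consistency branch reuse the same interpolation at `t + delta`. A quick numerical check:

```python
import torch

x0 = torch.randn(2, 128, 100)   # noise
x1 = torch.randn(2, 128, 100)   # data, e.g. a batch of mel spectrograms
t  = torch.rand(2, 1, 1)        # times, broadcast over sequence and channels

delta = 1e-3
t = t * (1. - delta)            # as in the hunk above: keep t + delta <= 1

w       = (1. - t) * x0 + t * x1
flow    = x1 - x0
w_delta = (1. - (t + delta)) * x0 + (t + delta) * x1

# the shifted point is one small Euler step along the target velocity
assert torch.allclose(w_delta, w + delta * flow, atol = 1e-5)
```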
@@ -1246,12 +1270,48 @@

# transformer and prediction head

-pred = self.transformer_with_pred_head(w, cond, times = times, text = text, mask = mask)
+pred, did_drop_text_cond = self.transformer_with_pred_head(
+    w,
+    cond,
+    times = times,
+    text = text,
+    mask = mask,
+    return_drop_text_cond = True
+)

+# maybe velocity consistency loss
+
+velocity_loss = self.zero
+
+if need_velocity_loss:
+
+    t_with_delta = t + velocity_consistency_delta
+    w_with_delta = (1. - t_with_delta) * x0 + t_with_delta * x1
+
+    with torch.no_grad():
+        ema_pred = velocity_consistency_model.transformer_with_pred_head(
+            w_with_delta,
+            cond,
+            times = times + velocity_consistency_delta,
+            text = text,
+            mask = mask,
+            drop_text_cond = did_drop_text_cond
+        )
+
+    velocity_loss = F.mse_loss(pred, ema_pred, reduction = 'none')
+    velocity_loss = velocity_loss[rand_span_mask].mean()

# flow matching loss

loss = F.mse_loss(pred, flow, reduction = 'none')

loss = loss[rand_span_mask].mean()

-return E2TTSReturn(loss, cond, pred, x0 + pred)
+# total loss and get breakdown
+
+total_loss = loss + velocity_loss * self.velocity_consistency_weight
+breakdown = LossBreakdown(loss, velocity_loss)
+
+# return total loss and bunch of intermediates
+
+return E2TTSReturn(total_loss, cond, pred, x0 + pred, breakdown)
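End to end, the new forward signature can be exercised as below. A sketch only: the constructor hyperparameters are illustrative, not recommended settings, and `ema-pytorch` is the same EMA wrapper the trainer already uses:

```python
import torch
from ema_pytorch import EMA
from e2_tts_pytorch import E2TTS

e2tts = E2TTS(
    transformer = dict(dim = 512, depth = 4),   # illustrative hyperparameters
    velocity_consistency_weight = 0.1,
    use_vocos = False                           # avoid downloading the pretrained vocoder
)

ema = EMA(e2tts, include_online_model = False)

mel  = torch.randn(2, 1024, 100)                # (batch, seq len, mel channels)
text = ['hello world', 'velocity consistency']

loss, cond, pred_flow, pred_data, breakdown = e2tts(
    mel,
    text = text,
    velocity_consistency_model = ema.ema_model
)

loss.backward()
ema.update()

print(breakdown.flow.item(), breakdown.velocity_consistency.item())
```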
34 changes: 21 additions & 13 deletions e2_tts_pytorch/trainer.py
@@ -160,23 +160,22 @@ def __init__(

self.model = model

-if self.is_main:
-    self.ema_model = EMA(
-        model,
-        include_online_model = False,
-        **ema_kwargs
-    )
-
-    self.ema_model.to(self.accelerator.device)
+self.need_velocity_consistent_loss = model.velocity_consistency_weight > 0.
+
+self.ema_model = EMA(
+    model,
+    include_online_model = False,
+    **ema_kwargs
+)

self.duration_predictor = duration_predictor
self.optimizer = optimizer
self.num_warmup_steps = num_warmup_steps
self.checkpoint_path = default(checkpoint_path, 'model.pth')
self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)

-self.model, self.optimizer = self.accelerator.prepare(
-    self.model, self.optimizer
+self.ema_model, self.model, self.optimizer = self.accelerator.prepare(
+    self.ema_model, self.model, self.optimizer
)
self.max_grad_norm = max_grad_norm

@@ -239,11 +238,21 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100
mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
mel_lengths = batch["mel_lengths"]

-if self.duration_predictor is not None:
+if exists(self.duration_predictor):
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
self.writer.add_scalar('duration loss', dur_loss.item(), global_step)

-loss, cond, pred, pred_data = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
+velocity_consistency_model = None
+if self.need_velocity_consistent_loss and self.ema_model.initted:
+    velocity_consistency_model = self.ema_model.ema_model
+
+loss, cond, pred, pred_data, loss_breakdown = self.model(
+    mel_spec,
+    text=text_inputs,
+    lens=mel_lengths,
+    velocity_consistency_model=velocity_consistency_model
+)

self.accelerator.backward(loss)

if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
Expand All @@ -253,8 +262,7 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100
self.scheduler.step()
self.optimizer.zero_grad()

-if self.is_main:
-    self.ema_model.update()
+self.ema_model.update()

if self.accelerator.is_local_main_process:
logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")
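Two details in the training step worth noting: the EMA copy is only trusted as a consistency teacher once `initted` flips true (after its first real update), and `include_online_model = False` keeps the wrapper from holding a second registered copy of the online model when `accelerator.prepare` wraps it. A small sketch of that lifecycle with `ema-pytorch` (toy module, made-up hyperparameters):

```python
import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Linear(8, 8)
ema = EMA(net, include_online_model = False, update_after_step = 2, update_every = 1)

for step in range(5):
    # stand-in for an optimizer step mutating the online weights
    with torch.no_grad():
        net.weight.add_(torch.randn_like(net.weight) * 1e-2)

    ema.update()
    print(step, bool(ema.initted))  # stays False until the EMA copy is initialized

teacher = ema.ema_model  # what gets passed as velocity_consistency_model
```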
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
-version = "1.2.0"
+version = "1.2.1"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
