From b3c60147795b32feac8564b29d8f818682345c63 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 5 Nov 2024 07:16:28 -0800 Subject: [PATCH] add the direction loss for flow matching, claimed to accelerate training from a research group out of Wuhan China --- README.md | 9 +++++++++ e2_tts_pytorch/e2_tts.py | 25 ++++++++++++++++++++++--- pyproject.toml | 2 +- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4ba5310..24818e0 100644 --- a/README.md +++ b/README.md @@ -155,3 +155,12 @@ sampled = e2tts.sample(mel[:, :5], text = text) url = {https://api.semanticscholar.org/CorpusID:273532030} } ``` + +```bibtex +@inproceedings{Yao2024FasterDiTTF, + title = {FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification}, + author = {Jingfeng Yao and Wang Cheng and Wenyu Liu and Xinggang Wang}, + year = {2024}, + url = {https://api.semanticscholar.org/CorpusID:273346237} +} +``` diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 5a6bc3e..4d003e0 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -63,7 +63,7 @@ def __getitem__(self, shapes: str): # named tuples -LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency']) +LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency', 'direction']) E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown']) @@ -909,7 +909,9 @@ def __init__( use_vocos = True, pretrained_vocos_path = 'charactr/vocos-mel-24khz', sampling_rate: int | None = None, + add_direction_loss = False, velocity_consistency_weight = 0., + direction_loss_weight = 1. ): super().__init__() @@ -986,6 +988,11 @@ def __init__( self.register_buffer('zero', torch.tensor(0.), persistent = False) self.velocity_consistency_weight = velocity_consistency_weight + # direction loss for flow matching + + self.add_direction_loss = add_direction_loss + self.direction_loss_weight = direction_loss_weight + # default vocos for mel -> audio self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None @@ -1317,10 +1324,22 @@ def forward( loss = loss[rand_span_mask].mean() + # maybe direction loss + + direction_loss = self.zero + + if self.add_direction_loss: + direction_loss = ((1. - F.cosine_similarity(pred, flow, dim = 1)) / 2).mean() # make direction loss at most 1. + # total loss and get breakdown - total_loss = loss + velocity_loss * self.velocity_consistency_weight - breakdown = LossBreakdown(loss, velocity_loss) + total_loss = ( + loss + + direction_loss * self.direction_loss_weight + + velocity_loss * self.velocity_consistency_weight + ) + + breakdown = LossBreakdown(loss, velocity_loss, direction_loss) # return total loss and bunch of intermediates diff --git a/pyproject.toml b/pyproject.toml index 58778e4..5ad98bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "1.4.1" +version = "1.5.0" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }