Skip to content

Commit

Permalink
add a time token
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 17, 2024
1 parent a71e4f8 commit 4d7d836
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,14 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:263134283}
}
```

```bibtex
@article{Bao2022AllAW,
title = {All are Worth Words: A ViT Backbone for Diffusion Models},
author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022},
pages = {22669-22679},
url = {https://api.semanticscholar.org/CorpusID:253581703}
}
```
13 changes: 13 additions & 0 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ def __init__(
nn.SiLU()
)

self.to_time_token = nn.Linear(dim, dim, bias = False)

for ind in range(depth):
is_later_half = ind >= (depth // 2)

Expand Down Expand Up @@ -343,6 +345,14 @@ def forward(
times = self.time_cond_mlp(times)
norm_kwargs.update(condition = times)

# the U-ViT paper claims that prepending a time token gives better time conditioning https://arxiv.org/abs/2209.12152

time_token = self.to_time_token(times)
x, time_packed_shape = pack((time_token, x), 'b * d')

if exists(mask):
mask = F.pad(mask, (1, 0), value = True)

# register tokens

registers = repeat(self.registers, 'r d -> b r d', b = batch)
Expand Down Expand Up @@ -394,6 +404,9 @@ def forward(

_, x = unpack(x, registers_packed_shape, 'b * d')

if exists(times):
_, x = unpack(x, time_packed_shape, 'b * d')

return self.final_norm(x, **norm_kwargs)

# main classes
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.3.1"
version = "0.4.0"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 4d7d836

Please sign in to comment.