Skip to content

Commit

Permalink
convenience method for saving mel directly to audio files after sampling
Browse files: browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 10, 2024
1 parent d5cf973 commit d07a255
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
23 changes: 21 additions & 2 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"""

from __future__ import annotations

from pathlib import Path
from random import random
from functools import partial
from itertools import zip_longest
Expand Down Expand Up @@ -204,6 +206,7 @@ def __init__(
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate

self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate = sampling_rate,
Expand Down Expand Up @@ -730,6 +733,7 @@ def __init__(
) = 'char_utf8',
use_vocos = True,
pretrained_vocos_path = 'charactr/vocos-mel-24khz',
sampling_rate: int | None = None
):
super().__init__()

Expand Down Expand Up @@ -764,6 +768,7 @@ def __init__(
num_channels = default(num_channels, self.mel_spec.n_mel_channels)

self.num_channels = num_channels
self.sampling_rate = default(sampling_rate, getattr(self.mel_spec, 'sampling_rate', None))

# whether to concat condition and project rather than project both and sum

Expand Down Expand Up @@ -876,7 +881,8 @@ def sample(
cfg_strength = 1., # they used a classifier free guidance strength of 1.
max_duration = 4096, # in case the duration predictor goes haywire
vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None,
return_raw_output: bool | None = None
return_raw_output: bool | None = None,
save_to_filename: str | None = None
) -> (
Float['b n d'],
list[Float['_']]
Expand Down Expand Up @@ -977,11 +983,24 @@ def fn(t, x):

one_out = rearrange(one_out, 'n d -> 1 d n')
one_audio = self.vocos.decode(one_out)
one_audio = rearrange(one_audio, '1 nt -> nt')
one_audio = rearrange(one_audio, '1 nw -> nw')
audio.append(one_audio)

out = audio

if exists(save_to_filename):
assert exists(vocoder) or exists(self.vocos)
assert exists(self.sampling_rate)

path = Path(save_to_filename)
parent_path = path.parents[0]
parent_path.mkdir(exist_ok = True, parents = True)

for ind, one_audio in enumerate(out):
one_audio = rearrange(one_audio, 'nw -> 1 nw')
save_path = str(parent_path / f'{ind + 1}.{path.name}')
torchaudio.save(save_path, one_audio.detach().cpu(), sample_rate = self.sampling_rate)

return out

def forward(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.0.2"
version = "1.0.4"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit d07a255

Please sign in to comment.