Skip to content

Commit

Permalink
fix mistake in multispeaker training
Browse files Browse the repository at this point in the history
prevent style overfitting
  • Loading branch information
yl4579 authored Nov 3, 2023
1 parent d8c49b8 commit 11cb0cb
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion train_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def main(config_path):
}

model_params = recursive_munch(config['model_params'])
multispeaker = model_params.multispeaker
model = build_model(model_params, text_aligner, pitch_extractor, plbert)

best_loss = float('inf') # best test loss
Expand Down Expand Up @@ -249,7 +250,7 @@ def main(config_path):
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))

s = model.style_encoder(gt.unsqueeze(1))
s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))

y_rec = model.decoder(en, F0_real, real_norm, s)

Expand Down

0 comments on commit 11cb0cb

Please sign in to comment.