Skip to content

Commit

Permalink
fade_in_out error when steam = True
Browse files Browse the repository at this point in the history
  • Loading branch information
Shuruthinaya committed Oct 18, 2024
1 parent 027e1cc commit d8cd87f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 94 deletions.
119 changes: 35 additions & 84 deletions cosyvoice/bin/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,138 +15,88 @@

from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed

from hyperpyyaml import load_hyperpyyaml

from torch.distributed.elastic.multiprocessing.errors import record

from copy import deepcopy
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)

wrap_cuda_model, check_modify_and_save_config
)
from torch.distributed.elastic.multiprocessing.errors import record

def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=60,
type=int,
help='timeout (in seconds) of cosyvoice_join.')
parser = argparse.ArgumentParser(description='Training your network')
parser.add_argument('--train_engine', default='torch_ddp', choices=['torch_ddp', 'deepspeed'], help='Engine for parallelized training')
parser.add_argument('--model', required=True, help='Model to be trained')
parser.add_argument('--config', required=True, help='Config file')
parser.add_argument('--train_data', required=True, help='Training data file')
parser.add_argument('--cv_data', required=True, help='CV data file')
parser.add_argument('--checkpoint', help='Checkpoint model path')
parser.add_argument('--model_dir', required=True, help='Directory to save the model')
parser.add_argument('--tensorboard_dir', default='tensorboard', help='Tensorboard log directory')
parser.add_argument('--ddp.dist_backend', dest='dist_backend', default='nccl', choices=['nccl', 'gloo'], help='Distributed backend')
parser.add_argument('--num_workers', default=0, type=int, help='Number of subprocess workers for reading')
parser.add_argument('--prefetch', default=100, type=int, help='Prefetch number')
parser.add_argument('--pin_memory', action='store_true', default=False, help='Use pinned memory buffers for reading')
parser.add_argument('--deepspeed.save_states', dest='save_states', default='model_only', choices=['model_only', 'model+optimizer'], help='Save model/optimizer states')
parser.add_argument('--timeout', default=60, type=int, help='Timeout (in seconds) for cosyvoice_join')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args

return parser.parse_args()

@record
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
# gan train has some special initialization logic
gan = True if args.model == 'hifigan' else False
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')

gan = True if args.model == 'hifigan' else False
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
if gan is True:
override_dict.pop('hift')
if gan: override_dict.pop('hift')

with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if gan is True:

if gan:
configs['train_conf'] = configs['train_conf_gan']
configs['train_conf'].update(vars(args))

# Init env for ddp
init_distributed(args)

# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs, gan)

# Do some sanity checks and save config to arsg.model_dir
train_dataset, cv_dataset, train_data_loader, cv_data_loader = init_dataset_and_dataloader(args, configs, gan)
configs = check_modify_and_save_config(args, configs)

# Tensorboard summary
writer = init_summarywriter(args)

# load checkpoint
model = configs[args.model]
if args.checkpoint is not None:
if os.path.exists(args.checkpoint):
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
else:
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))

# Dispatch model from cpu to gpu
if args.checkpoint and os.path.exists(args.checkpoint):
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)

model = wrap_cuda_model(args, model)

# Get optimizer & scheduler
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)

# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
save_model(model, 'init', info_dict)

# Get executor
executor = Executor(gan=gan)

# Start training loop
for epoch in range(info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
if gan is True:
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
writer, info_dict, group_join)

if gan:
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join)
else:
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)

dist.destroy_process_group(group_join)


if __name__ == '__main__':
main()
14 changes: 4 additions & 10 deletions cosyvoice/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def encode(
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
for i in range(len(text_token))]
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
return lm_input, lm_input_len
Expand All @@ -105,8 +104,7 @@ def forward(
embedding = batch['embedding'].to(device)

# 1. prepare llm_target
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
[self.speech_token_size]) for i in range(text_token.size(0))]
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

# 1. encode text_token
Expand All @@ -126,8 +124,7 @@ def forward(
speech_token = self.speech_embedding(speech_token)

# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
task_id_emb, speech_token, speech_token_len)
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)

# 6. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
Expand Down Expand Up @@ -197,10 +194,7 @@ def inference(
offset = 0
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
for i in range(max_len):
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
att_cache=att_cache, cnn_cache=cnn_cache,
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache, att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
Expand Down
1 change: 1 addition & 0 deletions cosyvoice/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
device = fade_in_mel.device
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel = fade_in_mel.clone() #clone function will do
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel.to(device)
Expand Down

0 comments on commit d8cd87f

Please sign in to comment.