From 5ed5bb15c8dc3cc00e47fb84f0ef31c0cadf13e4 Mon Sep 17 00:00:00 2001 From: "lyuxiang.lx" Date: Mon, 11 Nov 2024 22:14:52 +0800 Subject: [PATCH] use stream read to save memory --- cosyvoice/dataset/processor.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/cosyvoice/dataset/processor.py b/cosyvoice/dataset/processor.py index 35a1445..e0d3979 100644 --- a/cosyvoice/dataset/processor.py +++ b/cosyvoice/dataset/processor.py @@ -40,17 +40,18 @@ def parquet_opener(data, mode='train', tts_data={}): assert 'src' in sample url = sample['src'] try: - df = pq.read_table(url).to_pandas() - for i in range(len(df)): - if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: - continue - sample.update(dict(df.loc[i])) - if mode == 'train': - # NOTE do not return sample directly, must initialize a new dict - yield {**sample} - else: - for index, text in enumerate(tts_data[df.loc[i, 'utt']]): - yield {**sample, 'tts_index': index, 'tts_text': text} + for df in pq.ParquetFile(url).iter_batches(batch_size=64): + df = df.to_pandas() + for i in range(len(df)): + if mode == 'inference' and df.loc[i, 'utt'] not in tts_data: + continue + sample.update(dict(df.loc[i])) + if mode == 'train': + # NOTE do not return sample directly, must initialize a new dict + yield {**sample} + else: + for index, text in enumerate(tts_data[df.loc[i, 'utt']]): + yield {**sample, 'tts_index': index, 'tts_text': text} except Exception as ex: logging.warning('Failed to open {}, ex info {}'.format(url, ex))