#!/usr/bin/env python3
from prefigure.prefigure import get_all_args, push_wandb_config
from contextlib import contextmanager
from copy import deepcopy
import math
from pathlib import Path
import sys, re
import random
import torch
from torch import optim, nn
from torch.nn import functional as F
from torchaudio import transforms as T
from torch.utils import data
from tqdm import trange
from einops import rearrange
import numpy as np
import torchaudio
from functools import partial
import wandb
from dataset.dataset import get_all_s3_urls, get_s3_contents, get_wds_loader, wds_preprocess, log_and_continue, is_valid_sample
import webdataset as wds
import time


def base_plus_ext(path):
    """Split off all file extensions.

    Returns base, allext.

    :param path: path with extensions
    :returns: (base path with all extensions removed, the extensions), or (None, None) if the path has no extension
    """
    match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path)
    if not match:
        return None, None
    return match.group(1), match.group(2)
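

# Quick sanity check of the key-splitting helper above; the example paths are
# illustrative, not taken from the repo's data:
#   base_plus_ext("shard/000123.flac") -> ("shard/000123", "flac")
#   base_plus_ext("shard/000123.json") -> ("shard/000123", "json")
#   base_plus_ext("no_extension")      -> (None, None)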


# Local copy of webdataset's tariterators.group_by_keys (patched in inside main()
# below) so the sample-grouping behavior can be inspected while debugging the loader.
def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return an iterator that groups (key, value) pairs from a tar stream into samples.

    :param keys: function that splits a file name into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only keep these suffixes
    :param handler: unused; kept for signature compatibility with webdataset
    """
    print("Running new function")
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if wds.tariterators.trace:
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None,
            )
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        if current_sample is None or prefix != current_sample["__key__"]:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffix in current_sample:
            raise ValueError(
                f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}"
            )
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    # Yield the final (partially built) sample once the tar stream is exhausted
    if valid_sample(current_sample):
        yield current_sample


def valid_sample(sample):
    """Check whether a sample is valid.

    :param sample: sample to be checked
    """
    return (
        sample is not None
        and isinstance(sample, dict)
        and len(list(sample.keys())) > 0
        and not sample.get("__bad__", False)
    )


# Creates and returns a text prompt given a metadata object (stub: prints the
# metadata and returns an empty prompt)
def get_prompt_from_metadata(metadata):
    print(metadata)
    return ""


def main():
    args = get_all_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    torch.manual_seed(args.seed)

    print("Creating data loader")

    preprocess_fn = partial(
        wds_preprocess,
        sample_size=args.sample_size,
        sample_rate=args.sample_rate,
        random_crop=args.random_crop,
        verbose=True,
        normalize_lufs=-12.0,
        metadata_prompt_funcs={"FMA_stereo": get_prompt_from_metadata},
    )

    # Dataset names to pull from S3; left empty in the original
    names = [
    ]

    urls = get_all_s3_urls(
        names=names,
        #s3_url_prefix="",
        recursive=True,
    )

    #urls = ["s3://s-harmonai/datasets/"]

    def print_inputs(inputs):
        print(f"Sample: {inputs}")
        return inputs

    # Monkey-patch webdataset to use the local group_by_keys defined above
    wds.tariterators.group_by_keys = group_by_keys

    dataset = wds.DataPipeline(
        wds.ResampledShards(urls),  # Yields a single .tar URL
        wds.split_by_worker,
        wds.map(print_inputs),
        wds.tarfile_to_samples(handler=log_and_continue),  # Opens up a stream to the TAR file, yields files grouped by keys
        wds.decode(wds.torch_audio, handler=log_and_continue),
        wds.map(preprocess_fn, handler=log_and_continue),
        wds.select(is_valid_sample),
        wds.to_tuple("audio", "json", "timestamps", handler=log_and_continue),
        wds.batched(args.batch_size, partial=False),
    )

    train_dl = wds.WebLoader(dataset, num_workers=args.num_workers)

    print("Iterating over data loader")

    #for json in train_dl:
    for epoch_num in range(1):
        train_iter = iter(train_dl)
        print(f"Starting epoch {epoch_num}")
        start_time = time.time()
        for i, sample in enumerate(train_iter):
            #json = next(train_dl)
            audio, json, timestamps = sample
            print(f"Epoch {epoch_num} Batch {i}")
            print(audio.shape)
            samples_per_sec = ((i + 1) * args.batch_size) / (time.time() - start_time)
            print(f"Samples/sec this epoch: {samples_per_sec}")
            #time.sleep(5.0)


if __name__ == '__main__':
    main()
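

# A minimal sketch of how this script might be invoked. get_all_args() reads its
# defaults from this repo's prefigure config, so the exact flag names and values
# below are assumptions, not the repo's documented interface:
#
#   python test_encodec.py --batch-size 4 --num-workers 8 \
#       --sample-size 65536 --sample-rate 48000 --random-crop True --seed 42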