-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathpredict.py
49 lines (29 loc) · 1.26 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
from glob import glob
import numpy as np
from models import transformer_classifier
from prepare_data import random_crop
from audio_processing import load_audio_file
def chunker(seq, size):
return (seq[pos : pos + size] for pos in range(0, len(seq), size))
if __name__ == "__main__":
transformer_v2_h5 = "transformer_v2.h5"
CLASS_MAPPING = json.load(open("/media/ml/data_ml/fma_metadata/mapping.json"))
base_path = "../audio"
files = sorted(list(glob(base_path + "/*.mp3")))
data = [load_audio_file(x, input_length=16000 * 120) for x in files]
transformer_v2_model = transformer_classifier(n_classes=len(CLASS_MAPPING))
transformer_v2_model.load_weights(transformer_v2_h5)
crop_size = np.random.randint(128, 512)
repeats = 8
transformer_v2_Y = 0
for _ in range(repeats):
X = np.array([random_crop(x, crop_size=crop_size) for x in data])
transformer_v2_Y += transformer_v2_model.predict(X) / repeats
transformer_v2_Y = transformer_v2_Y.tolist()
for path, pred in zip(files, transformer_v2_Y):
print(path)
pred_tup = [(k, pred[v]) for k, v in CLASS_MAPPING.items()]
pred_tup.sort(key=lambda x: x[1], reverse=True)
for a in pred_tup[:5]:
print(a)