-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathget_genre.py
61 lines (52 loc) · 1.93 KB
/
get_genre.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import torch
import sys
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from librosa.core import load
from librosa.feature import melspectrogram
from librosa import power_to_db
from model import genreNet
from config import MODELPATH
from config import GENRES
import warnings
warnings.filterwarnings("ignore")
def main(argv):
    """Classify the musical genre of an audio file with a trained genreNet model.

    Args:
        argv: command-line arguments (excluding the program name); must contain
            exactly one element, the path to the audio file to classify.

    Side effects:
        Prints each detected genre with the percentage of confident 128-frame
        spectrogram chunks that voted for it, sorted descending. Exits with
        status 1 on incorrect usage.
    """
    if len(argv) != 1:
        print("Usage: python3 get_genre.py audiopath")
        sys.exit(1)  # usage error: signal failure instead of exit code 0

    le = LabelEncoder().fit(GENRES)

    # ------------------------------- #
    ## LOAD TRAINED GENRENET MODEL
    net = genreNet()
    net.load_state_dict(torch.load(MODELPATH, map_location='cpu'))
    net.eval()  # inference mode: freeze dropout/batch-norm behavior

    # ------------------------------- #
    ## LOAD AUDIO
    audio_path = argv[0]
    y, sr = load(audio_path, mono=True, sr=22050)

    # ------------------------------- #
    ## GET CHUNKS OF AUDIO SPECTROGRAMS
    # Keyword args: modern librosa makes y/sr keyword-only; this form works
    # on both old and new versions.
    S = melspectrogram(y=y, sr=sr).T
    # Trim the tail so the frame count is an exact multiple of 128.
    # BUG FIX: the original `S[:-1 * (S.shape[0] % 128)]` produced an EMPTY
    # array whenever the length was already a multiple of 128 (S[:-0] == S[:0]).
    remainder = S.shape[0] % 128
    if remainder:
        S = S[:-remainder]
    num_chunk = S.shape[0] // 128  # integer count of 128-frame chunks
    if num_chunk == 0:
        # Audio shorter than one chunk: np.split(S, 0) would crash.
        print("Audio is too short to classify (needs at least 128 spectrogram frames).")
        return
    data_chunks = np.split(S, num_chunk)

    # ------------------------------- #
    ## CLASSIFY SPECTROGRAMS
    genres = list()
    with torch.no_grad():  # no gradients needed for inference
        for data in data_chunks:
            data = torch.FloatTensor(data).view(1, 1, 128, 128)
            preds = net(data)
            pred_val, pred_index = preds.max(1)
            pred_index = pred_index.data.numpy()
            # net outputs log-probabilities (hence exp() to recover probability)
            # -- presumably a log-softmax head; matches original exp() usage.
            pred_val = np.exp(pred_val.data.numpy()[0])
            pred_genre = le.inverse_transform(pred_index).item()
            # Keep only chunk-level votes the model is confident about.
            if pred_val >= 0.5:
                genres.append(pred_genre)

    # ------------------------------- #
    if not genres:
        # Original printed nothing here, which was indistinguishable from a crash.
        print("No genre detected with sufficient confidence.")
        return
    counts = Counter(genres)  # build the vote histogram once, not twice
    total = float(sum(counts.values()))
    pos_genre = sorted(
        ((genre, votes / total * 100) for genre, votes in counts.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    for genre, pos in pos_genre:
        print("%10s: \t%.2f\t%%" % (genre, pos))
    return
# Script entry point: forward the command-line args (minus the program name)
# to main() only when executed directly, not when imported as a module.
if __name__ == '__main__':
    main(sys.argv[1:])