-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreateNPY.py
80 lines (56 loc) · 2.19 KB
/
createNPY.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import glob
import base64
import numpy as np
from ast import literal_eval
# Dtype used when decoding the base64-encoded vectors, given as the NumPy
# type string '>f4':
#   '>'  : big-endian byte order ('<' would mean little-endian)
#   'f4' : 4-byte (32-bit) floating point, i.e. float32
float32 = np.dtype('>f4')
def decode_float_list(base64_string, dtype=None):
    """Decode a base64 string of packed floats into a list of Python floats.

    Args:
        base64_string: Base64 text (``str`` or ``bytes``) whose decoded
            payload is a contiguous run of fixed-size floats.
        dtype: Element type of the payload, in any form ``np.frombuffer``
            accepts. Defaults to the module-level ``float32`` dtype
            (big-endian 4-byte floats).

    Returns:
        list: The decoded values, in payload order.

    Raises:
        binascii.Error: If ``base64_string`` is not valid base64.
        ValueError: If the payload size is not a multiple of the element size.
    """
    if dtype is None:
        dtype = float32
    # Named ``raw`` rather than ``bytes`` so the builtin is not shadowed.
    raw = base64.b64decode(base64_string)
    return np.frombuffer(raw, dtype=dtype).tolist()
def _gallery_npy_path(stem, kind):
    """Return the gallery ``.npy`` path for ``stem``, labelled with ``kind``."""
    tag = "resnet1024_"
    if tag in stem:
        name = stem.replace(tag, kind + "_")
    else:
        # No tag to swap: prefix instead, so the vector and id files can
        # never resolve to the same path.
        name = kind + "_" + stem
    return "pirsData/gallery/" + name + ".npy"


def galleryCreate(fileName):
    """Convert one gallery dump file into a vector ``.npy`` and an id ``.npy``.

    Every non-blank line of ``fileName`` is parsed with ``literal_eval`` and
    must yield a mapping with a ``'resnet_vector'`` entry (base64 string,
    decoded via ``decode_float_list``) and an ``'id'`` entry.

    The arrays are written under ``pirsData/gallery/``. The file stem (the
    base name up to its first ``.``) has its ``resnet1024_`` tag replaced by
    ``vector_`` for the vectors and ``id_`` for the ids.

    NOTE(review): the original code called ``replace("resnet1024_*", ...)``.
    ``str.replace`` matches literally, so that was a no-op and both arrays
    were saved to the same ``<stem>.npy``, the ids overwriting the vectors.
    The ``vector_`` / ``id_`` naming used here is the presumed intent;
    confirm against whatever reads these files.

    Args:
        fileName: Path of the input file, one Python-literal record per line.
    """
    vectors = []
    ids = []
    with open(fileName, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank or trailing-empty lines are not records.
                continue
            record = literal_eval(line)
            vectors.append(np.asarray(decode_float_list(record['resnet_vector'])))
            ids.append(np.asarray(record['id']))

    stem = os.path.basename(fileName).split(".")[0]
    os.makedirs("pirsData/gallery", exist_ok=True)
    np.save(_gallery_npy_path(stem, "vector"), np.array(vectors))
    np.save(_gallery_npy_path(stem, "id"), np.array(ids))
    print(">>>", "pirsData/gallery/" + stem, "saved!")
def queryCreate(fileName):
    """Split a query dump file into per-category vector and id ``.npy`` files.

    Every non-blank line of ``fileName`` is parsed with ``literal_eval`` and
    must yield a mapping with ``'resnet_vector'`` (base64 string, decoded via
    ``decode_float_list``), ``'id'`` and ``'cat_key'`` entries. Records are
    grouped by ``'cat_key'``.

    For each category two files are written under ``pirsData/query/``:
    ``vector_<k>`` and ``id_<k>``, where ``<k>`` is the category key with its
    first character lower-cased, followed by ``_``, the rest of the key and
    ``.npy``.

    Args:
        fileName: Path of the input file, one Python-literal record per line.
    """
    vecs_by_cat = {}
    ids_by_cat = {}
    with open(fileName, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank or trailing-empty lines are not records.
                continue
            record = literal_eval(line)
            cate = record["cat_key"]
            vec = np.asarray(decode_float_list(record['resnet_vector']))
            vecs_by_cat.setdefault(cate, []).append(vec)
            ids_by_cat.setdefault(cate, []).append(np.asarray(record['id']))

    os.makedirs("pirsData/query", exist_ok=True)
    for key, vecs in vecs_by_cat.items():
        # Slicing (not indexing) keeps an empty key from raising IndexError;
        # for non-empty keys the name is the same as key.lower()[0] + ...
        k = key[:1].lower() + "_" + key[1:] + ".npy"
        np.save("pirsData/query/vector_" + k, np.array(vecs))
        np.save("pirsData/query/id_" + k, np.array(ids_by_cat[key]))
        print("query :", k, "saved!")
if __name__ == "__main__":
    # Build the per-category query .npy files.
    queryN = "/data/github/elastic/testset_pirs.json"
    queryCreate(queryN)

    # Build the gallery .npy files, one input file at a time.
    fileList = glob.glob("/data/github/elastic/resnet/*.*")
    for fileN in fileList:
        stem = os.path.basename(fileN).split(".")[0]
        # Skip inputs whose gallery vector file (vector_<name>.npy) is already
        # present. str.replace matches literally, so the tag is written
        # without a '*': a pattern such as "resnet1024_*" never matches.
        if "resnet1024_" in stem:
            vectorName = stem.replace("resnet1024_", "vector_")
        else:
            vectorName = "vector_" + stem
        saveCheck = "pirsData/gallery/" + vectorName + ".npy"
        if os.path.isfile(saveCheck):
            print(saveCheck, "already saved..!")
            continue
        galleryCreate(fileN)