-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_frame.py
107 lines (82 loc) · 3.54 KB
/
test_frame.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import cv2
import time
import utils
import pickle
import datetime
import numpy as np
import tensorflow as tf
from modules.models import ArcFaceModel
from scipy.spatial.distance import euclidean
from modules.utils import set_memory_growth, load_yaml, l2_norm
from preprocess import prepare_facebank, load_facebank, align_multi
# Every resource (face bank, config, checkpoints) is resolved relative to the
# directory the script is launched from, not the directory the file lives in.
path = os.getcwd()

# Face bank: the keys are used as display names and the values are the
# reference embeddings that live embeddings are compared against.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only ever point this at an embds_dict.pkl you generated yourself.
# The context manager closes the handle as soon as the dict is loaded (the
# original left it open for the whole lifetime of the process).
with open(os.path.join(path, "embds_dict.pkl"), 'rb') as file:
    data = pickle.load(file)

# Index -> name lookup; positions follow the dict's iteration order, which is
# the same order data.values() yields the reference embeddings in.
names = list(data.keys())
def _load_model(cfg):
    """Build the ArcFace model described by ``cfg`` and restore its weights.

    Args:
        cfg: Configuration mapping; ``input_size``, ``backbone_type`` and
            ``sub_name`` are read from it.

    Returns:
        The model with the latest checkpoint found under
        ``<path>/checkpoints/<sub_name>`` loaded.

    Raises:
        SystemExit: If no checkpoint exists in that directory.
    """
    model = ArcFaceModel(size=cfg['input_size'],
                         backbone_type=cfg['backbone_type'],
                         training=False)
    ckpt_dir = path + '/checkpoints/' + cfg['sub_name']
    ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_path is None:
        # Report the directory that was searched (ckpt_path itself is None
        # here) and exit with a non-zero status so callers can see the failure.
        print("[*] Cannot find ckpt from {}.".format(ckpt_dir))
        raise SystemExit(1)
    print("[*] load ckpt from {}".format(ckpt_path))
    model.load_weights(ckpt_path)
    return model


def _embed_faces(model, faces):
    """Return one flattened embedding per face crop in ``faces``.

    Each crop is given a leading batch axis if it is 3-D, scaled to [0, 1] as
    float32, run through ``model`` and passed through ``l2_norm``.
    """
    embs = []
    for face in faces:
        if len(face.shape) == 3:
            face = np.expand_dims(face, 0)
        face = face.astype(np.float32) / 255.
        # ravel(): scipy's euclidean() rejects anything that is not 1-D, and
        # the model output carries the batch axis added above.
        embs.append(np.ravel(l2_norm(model(face)).numpy()))
    return embs


def _nearest_identities(embs, targets):
    """Find the closest reference embedding for every live embedding.

    Args:
        embs: List of 1-D embeddings for the faces in the current frame.
        targets: List of 1-D reference embeddings, in the same order as the
            module-level ``names``.

    Returns:
        ``(indices, distances)`` as two NumPy arrays with one entry per face.
        With an empty face bank every index is -1 and every distance is inf.
    """
    if not targets:
        # np.argmin([]) would raise; report "no match" for every face instead.
        return np.full(len(embs), -1), np.full(len(embs), np.inf)
    indices = []
    distances = []
    for emb in embs:
        dist = [euclidean(emb, target) for target in targets]
        best = int(np.argmin(dist))
        indices.append(best)
        distances.append(dist[best])
    return np.array(indices), np.array(distances)


def main():
    """Run live face recognition on the default camera (device 0).

    Loads the ArcFace model, then reads frames until the capture closes or the
    user presses ``q`` in the preview window. Every ``skip_frame``-th frame is
    passed to ``align_multi``; each returned face is embedded and compared
    with the face bank in the module-level ``data``. Frames that contain at
    least one face are annotated, written to ``./save_frame/`` and displayed.

    Raises:
        SystemExit: If no model checkpoint can be found.
    """
    cfg = load_yaml(path + '/configs/arc_res50.yaml')
    model = _load_model(cfg)

    # Loop-invariant: flatten the reference embeddings once up front.
    # Assumes every stored embedding flattens to the same length as the
    # model output - TODO confirm against the face-bank builder.
    targets = [np.ravel(target) for target in data.values()]

    # cv2.imwrite fails silently when the directory is missing.
    os.makedirs('./save_frame', exist_ok=True)

    cap = cv2.VideoCapture(0)
    count = 0
    skip_frame = 10
    try:
        while cap.isOpened():
            print(count)
            # Read on every iteration so that frames really are skipped; the
            # original only called read() on every skip_frame-th iteration,
            # which made the other iterations empty spins.
            is_success, frame = cap.read()
            due = count % skip_frame == 0
            # Increment before any `continue` so the counter can never stall.
            count += 1
            if not (is_success and due):
                continue

            strat_time = datetime.datetime.now()
            # presumably three parallel per-face sequences - verify against
            # preprocess.align_multi
            bboxes, landmarks, faces = align_multi(cfg, frame, min_confidence=0.97, limits=5)
            embs = _embed_faces(model, faces)
            list_min_idx, list_score = _nearest_identities(embs, targets)
            print(list_min_idx)

            # No face in this frame. (The original tested `.any() == False`,
            # which also discarded frames whose distances were all exactly 0.)
            if list_score.size == 0:
                continue
            bboxes = bboxes.astype(int)

            # NOTE(review): the frame-level gate only looks at the FIRST
            # face's distance; if it is >= 1 every face in the frame is
            # labelled "unknown". Kept as in the original - confirm intent.
            if list_score[0] < 1:
                list_min_idx[list_score > 1.2] = -1
                print("############# Face matched #############\n\n")
                for idx, box in enumerate(bboxes):
                    # -1 marks "no match". It must not be used as an index:
                    # names[-1] would silently pick the last enrolled person.
                    if list_min_idx[idx] >= 0:
                        label = names[list_min_idx[idx]]
                    else:
                        label = "unknown"
                    print("Name of the Detected Person:\n\n", label)
                    print('\n\n')
                    print("***********************************")
                    frame = utils.draw_box_name(box,
                                                landmarks[idx],
                                                label,
                                                frame)
            else:
                print("########### Warning ##########")
                for idx, box in enumerate(bboxes):
                    frame = utils.draw_box_name(box,
                                                landmarks[idx],
                                                "unknown",
                                                frame)

            # NOTE(review): str(datetime) contains ':' characters, which are
            # not valid in Windows file names - confirm target platforms.
            cv2.imwrite('./save_frame/'+str(strat_time)+'.png', frame)
            cv2.imshow('face Capture', frame)
            # HighGUI only repaints the window inside waitKey(); it also
            # gives the user a clean way to stop the loop.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and windows even on Ctrl-C or an exception.
        cap.release()
        cv2.destroyAllWindows()
# Start the capture loop only when the file is executed as a script, so that
# importing it (e.g. a test runner collecting test_*.py files) does not open
# the camera.
if __name__ == "__main__":
    main()