# forked from PINTO0309/TPU-Posenet
# pose_camera_single_tpu.py — 164 lines (131 loc), 4.71 KB
# (GitHub page chrome and line-number gutter removed — scrape artifacts)
# coding=utf-8
import argparse
import numpy as np
import rclpy
import rclpy.node as node
import rclpy.qos as qos
import sensor_msgs.msg as msg
import std_msgs.msg as std_msg
import sys
sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import cv2
from cv_bridge import CvBridge, CvBridgeError
import time
from PIL import Image
from time import sleep
from edgetpu.basic import edgetpu_utils
from pose_engine import PoseEngine
lastresults = None
processes = []
frameBuffer = None
results = None
EDGES = (
('nose', 'left eye'),
('nose', 'right eye'),
('nose', 'left ear'),
('nose', 'right ear'),
('left ear', 'left eye'),
('right ear', 'right eye'),
('left eye', 'right eye'),
('left shoulder', 'right shoulder'),
('left shoulder', 'left elbow'),
('left shoulder', 'left hip'),
('right shoulder', 'right elbow'),
('right shoulder', 'right hip'),
('left elbow', 'left wrist'),
('right elbow', 'right wrist'),
('left hip', 'right hip'),
('left hip', 'left knee'),
('right hip', 'right knee'),
('left knee', 'left ankle'),
('right knee', 'right ankle'),
)
class TestDisplayNode(node.Node):
    """ROS 2 node that runs Edge TPU PoseNet inference on a camera stream.

    Subscribes to ``/image`` (sensor_msgs/Image), runs the PoseEngine on
    each frame, draws the detected keypoints/skeleton plus a detection-FPS
    overlay, and republishes the annotated frame on ``/recognized_image``.
    """

    def __init__(self, engine):
        """Set up pub/sub and FPS bookkeeping.

        engine: an initialized pose_engine.PoseEngine instance.
        """
        super().__init__('IProc_TestDisplayNode')
        self.__window_name = "ROS2 PoseNet"
        profile = qos.QoSProfile()
        self.sub = self.create_subscription(msg.Image, '/image', self.msg_callback, qos_profile=profile)
        self.pub = self.create_publisher(msg.Image, '/recognized_image')
        self.engine = engine
        # Rolling FPS statistics, refreshed every 15 frames in recognize().
        self.fps = ""            # playback FPS string
        self.detectfps = ""      # detection FPS string drawn onto the image
        self.framecount = 0      # frames seen in the current window
        self.detectframecount = 0  # frames with at least one detection
        self.time1 = 0           # sum of instantaneous FPS (1/elapsed) per frame
        self.time2 = 0           # sum of per-frame elapsed times
        self.t1 = time.perf_counter()
        self.t2 = time.perf_counter()
        self.msg = msg.Image()
        self.cv_bridge = CvBridge()

    def msg_callback(self, m: msg.Image):
        """Convert an incoming Image message to a numpy array and process it."""
        self.msg = m
        # NOTE(review): assumes the incoming image is 3-channel 8-bit
        # (e.g. bgr8); the reshape will raise otherwise — confirm publisher.
        np_img = np.reshape(m.data, (m.height, m.width, 3)).astype(np.uint8)
        self.recognize(np_img)

    def display(self, img: np.ndarray):
        """Show the annotated frame in an OpenCV window (debug helper)."""
        cv2.imshow(self.__window_name, img)
        cv2.waitKey(1)

    def draw_pose(self, img, pose, threshold=0.2):
        """Draw one pose's keypoints and skeleton edges onto img (in place).

        Keypoints scoring below ``threshold`` are skipped; an edge is drawn
        only when both of its endpoints survived the threshold.
        """
        xys = {}
        for label, keypoint in pose.keypoints.items():
            if keypoint.score < threshold:
                continue
            # keypoint.yx is (row, col); OpenCV point order is (x, y).
            xys[label] = (int(keypoint.yx[1]), int(keypoint.yx[0]))
            img = cv2.circle(img, (int(keypoint.yx[1]), int(keypoint.yx[0])), 5, (0, 255, 0), -1)
        for a, b in EDGES:
            if a not in xys or b not in xys:
                continue
            ax, ay = xys[a]
            bx, by = xys[b]
            img = cv2.line(img, (ax, ay), (bx, by), (0, 255, 255), 2)

    def overlay_on_image(self, frames, result, model_width, model_height):
        """Return a copy of ``frames`` with all poses and the detection-FPS
        text drawn on; return ``frames`` unchanged when result is None."""
        color_image = frames
        # fixed: idiomatic None check (was isinstance(result, type(None))).
        if result is None:
            return color_image
        img_cp = color_image.copy()
        for pose in result:
            self.draw_pose(img_cp, pose)
        cv2.putText(img_cp, self.detectfps, (model_width - 170, 15),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (38, 0, 255), 1, cv2.LINE_AA)
        return img_cp

    def recognize(self, img):
        """Run pose inference on one frame, publish the annotated result,
        and update the rolling FPS statistics."""
        self.t1 = time.perf_counter()
        model_width = 640
        model_height = 480
        # Resize to the model input size; the model wants RGB, frames are BGR.
        color_image = cv2.resize(img, (model_width, model_height))
        prepimg = color_image[:, :, ::-1].copy()
        res, inference_time = self.engine.DetectPosesInImage(prepimg)
        if res:
            self.detectframecount += 1
            imdraw = self.overlay_on_image(color_image, res, model_width, model_height)
        else:
            imdraw = color_image
        img_msg = self.cv_bridge.cv2_to_imgmsg(imdraw, encoding="bgr8")
        self.pub.publish(img_msg)

        # FPS calculation over a 15-frame window.
        self.framecount += 1
        if self.framecount >= 15:
            # fixed: self.fps was accumulated (time1) but never reported —
            # restore the playback-FPS string alongside the detection one.
            self.fps = "(Playback) {:.1f} FPS".format(self.time1 / 15)
            self.detectfps = "(Detection) {:.1f} FPS".format(self.detectframecount / self.time2)
            self.framecount = 0
            self.detectframecount = 0
            self.time1 = 0
            self.time2 = 0
        self.t2 = time.perf_counter()
        elapsedTime = self.t2 - self.t1
        # fixed: guard against ZeroDivisionError on coarse/zero timer deltas.
        if elapsedTime > 0:
            self.time1 += 1 / elapsedTime
        self.time2 += elapsedTime
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        default="models/posenet_mobilenet_v1_075_481_641_quant_decoder_edgetpu.tflite",
                        help="Path of the detection model.")
    args = parser.parse_args()

    engine = PoseEngine(args.model)
    # NOTE(review): presumably lets the Edge TPU / camera pipeline settle
    # before spinning — confirm whether the 5 s delay is still needed.
    sleep(5)

    rclpy.init()
    # fixed: renamed from `node`, which shadowed the imported rclpy.node module.
    display_node = TestDisplayNode(engine)
    try:
        rclpy.spin(display_node)
    finally:
        # fixed: always release the node and shut rclpy down, even when
        # spin() exits via an exception (e.g. KeyboardInterrupt).
        display_node.destroy_node()
        rclpy.shutdown()