diff --git a/deploy/python/det_keypoint_unite_infer.py b/deploy/python/det_keypoint_unite_infer.py index 7b57714d18..445685c0a7 100644 --- a/deploy/python/det_keypoint_unite_infer.py +++ b/deploy/python/det_keypoint_unite_infer.py @@ -24,14 +24,15 @@ from preprocess import decode_image from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint -from visualize import visualize_pose +from visualize import visualize_pose, visualize_pose_point131 from benchmark_utils import PaddleInferBenchmark from utils import get_current_memory_mb from keypoint_postprocess import translate_to_ori_images KEYPOINT_SUPPORT_MODELS = { 'HigherHRNet': 'keypoint_bottomup', - 'HRNet': 'keypoint_topdown' + 'HRNet': 'keypoint_topdown', + 'VitPose_TopDown_WholeBody': 'keypoint_topdown_wholebody' } @@ -178,7 +179,7 @@ def topdown_unite_predict_video(detector, keypoint_res['keypoint'][0][0] = smooth_keypoints.tolist() - im = visualize_pose( + im = visualize_pose_point131( frame, keypoint_res, visual_thresh=FLAGS.keypoint_threshold, @@ -329,8 +330,7 @@ def main(): enable_mkldnn=FLAGS.enable_mkldnn, use_dark=FLAGS.use_dark) keypoint_arch = topdown_keypoint_detector.pred_config.arch - assert KEYPOINT_SUPPORT_MODELS[ - keypoint_arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.' + assert KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown' or KEYPOINT_SUPPORT_MODELS[keypoint_arch] == 'keypoint_topdown_wholebody', 'Detection-Keypoint unite inference only supports topdown models.' # predict from video file or camera video stream if FLAGS.video_file is not None or FLAGS.camera_id != -1: diff --git a/deploy/python/infer.py b/deploy/python/infer.py index d1790e6a7d..ec7c7591eb 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -34,7 +34,7 @@ from benchmark_utils import PaddleInferBenchmark from picodet_postprocess import PicoDetPostProcess from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize -from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop +from keypoint_preprocess import EvalAffine, TopDownEvalAffine, TopDownAffineImage, expand_crop from clrnet_postprocess import CLRNetPostProcess from visualize import visualize_box_mask, imshow_lanes from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid diff --git a/deploy/python/keypoint_infer.py b/deploy/python/keypoint_infer.py index 39e195bf56..d5dc139e68 100644 --- a/deploy/python/keypoint_infer.py +++ b/deploy/python/keypoint_infer.py @@ -42,10 +42,40 @@ # Global dictionary KEYPOINT_SUPPORT_MODELS = { 'HigherHRNet': 'keypoint_bottomup', - 'HRNet': 'keypoint_topdown' + 'HRNet': 'keypoint_topdown', + 'VitPose_TopDown_WholeBody': 'keypoint_topdown_wholebody' } +def _box2cs(image_size, box): + """This encodes bbox(x,y,w,h) into (center, scale) + + Args: + x, y, w, h + + Returns: + tuple: A tuple containing center and scale. + + - np.ndarray[float32](2,): Center of the bbox (x, y). + - np.ndarray[float32](2,): Scale of the bbox w & h. + """ + + x, y, w, h = box[:4] + input_size = image_size + aspect_ratio = input_size[0] / input_size[1] + center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32) + + if w > aspect_ratio * h: + h = w * 1.0 / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + + # pixel std is 200.0 + scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) + scale = scale * 1.25 + + return center, scale + class KeyPointDetector(Detector): """ Args: @@ -137,6 +167,23 @@ def postprocess(self, inputs, result): imshape = inputs['im_shape'][:, ::-1] center = np.round(imshape / 2.) scale = imshape / 200. + keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark) + kpts, scores = keypoint_postprocess(np_heatmap, center, scale) + results['keypoint'] = kpts + results['score'] = scores + return results + elif KEYPOINT_SUPPORT_MODELS[ + self.pred_config.arch] == 'keypoint_topdown_wholebody': + results = {} + imshape = inputs['im_shape'][:, ::-1] + center = [] + scale = [] + for i in range(len(inputs['im_shape'])): + transize = np.shape(inputs["image"]) + tmp_center, tmp_scale = _box2cs([np.shape(inputs["image"])[-1],np.shape(inputs["image"])[-2]], [0,0,inputs['im_shape'][i][1],inputs['im_shape'][i][0]] ) + center.append(tmp_center) + scale.append(tmp_scale) + keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark) kpts, scores = keypoint_postprocess(np_heatmap, center, scale) results['keypoint'] = kpts diff --git a/deploy/python/keypoint_preprocess.py b/deploy/python/keypoint_preprocess.py index b4e50e887a..ddd5f5f5e5 100644 --- a/deploy/python/keypoint_preprocess.py +++ b/deploy/python/keypoint_preprocess.py @@ -18,6 +18,83 @@ import numpy as np +def _box2cs(image_size, box): + """This encodes bbox(x,y,w,h) into (center, scale) + + Args: + x, y, w, h + + Returns: + tuple: A tuple containing center and scale. + + - np.ndarray[float32](2,): Center of the bbox (x, y). + - np.ndarray[float32](2,): Scale of the bbox w & h. + """ + + x, y, w, h = box[:4] + input_size = image_size + aspect_ratio = input_size[0] / input_size[1] + center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32) + + if w > aspect_ratio * h: + h = w * 1.0 / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + + # pixel std is 200.0 + scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) + scale = scale * 1.25 + + return center, scale + +class TopDownAffineImage(object): + """apply affine transform to image and coords + + Args: + trainsize (list): [w, h], the standard size used to train + use_udp (bool): whether to use Unbiased Data Processing. + records(dict): the dict contained the image and coords + + Returns: + records (dict): contain the image and coords after tranformed + + """ + + def __init__(self, trainsize, use_udp=False, use_box2cs=True): + self.trainsize = trainsize + self.use_udp = use_udp + self.use_box2cs = use_box2cs + + def __call__(self, records, im_info): + if self.use_box2cs: + center, scale = _box2cs(self.trainsize, [0,0,im_info['im_shape'][1],im_info['im_shape'][0]]) + else: + imshape = im_info['im_shape'][::-1] + center = im_info['center'] if 'center' in im_info else imshape / 2. + scale = im_info['scale'] if 'scale' in im_info else imshape + + image = records + rot = records['rotate'] if "rotate" in records else 0 + if self.use_udp: + trans = get_warp_matrix( + rot, center * 2.0, + [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], + scale * 200.0) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + joints[:, 0:2] = warp_affine_joints(joints[:, 0:2].copy(), trans) + else: + trans = get_affine_transform(center, scale * + 200, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + return image, im_info + + class EvalAffine(object): def __init__(self, size, stride=64): super(EvalAffine, self).__init__() diff --git a/deploy/python/visualize.py b/deploy/python/visualize.py index 9e96c29df0..aa640a9173 100644 --- a/deploy/python/visualize.py +++ b/deploy/python/visualize.py @@ -20,6 +20,15 @@ import numpy as np import PIL from PIL import Image, ImageDraw, ImageFile +import json + + +from mmengine.structures import InstanceData +from mmpose.structures import PoseDataSample +from mmpose.visualization import PoseLocalVisualizer + +from mmpose.structures import merge_data_samples, split_instances + ImageFile.LOAD_TRUNCATED_IMAGES = True def imagedraw_textsize_c(draw, text): @@ -235,6 +244,48 @@ def get_color(idx): return color +def visualize_pose_point131(imgfile, + results, + visual_thresh=0.3, + save_name='pose.jpg', + save_dir='output', + returnimg=False, + ids=None): + pose_local_visualizer = PoseLocalVisualizer(vis_backends= [{'type': 'LocalVisBackend'}], name= 'visualizer', radius= 3, alpha= 0.8, line_width= 1) + # with open("/paddle/mmpose-dev-1.x/dataset_meta.json", 'r') as f: + with open("deploy/python/dataset_meta.json", 'r') as f: + meta_data = json.load(f) + + pred_instances = InstanceData() + pose_local_visualizer.set_dataset_meta(meta_data, skeleton_style="mmpose") + image = cv2.imread(imgfile) if type(imgfile) == str else imgfile + skeletons, score = results['keypoint'] + keypoints = [] + scores = [] + for i in range(len(skeletons[0])): + keypoints.append([skeletons[0][i][0], skeletons[0][i][1]]) + scores.append(skeletons[0][i][2]) + keypoints = [keypoints] + skeletons = np.array(skeletons) + scores = np.array(scores) + pred_instances.keypoints = skeletons + + pred_pose_data_sample = PoseDataSample() + pred_pose_data_sample.pred_instances = pred_instances + + blank_image = np.zeros(image.shape, dtype=np.uint8) + pose_local_visualizer.add_datasample('image', blank_image, data_sample=pred_pose_data_sample, + draw_gt=False, + draw_heatmap=False, + draw_bbox=True, + show_kpt_idx=False, + skeleton_style='mmpose', + show=False, + wait_time=0, + kpt_thr=visual_thresh) + + return pose_local_visualizer.get_image() + def visualize_pose(imgfile, results, visual_thresh=0.6, diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 62acd1db67..58c1ea8bef 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -55,10 +55,11 @@ 'YOLOF': 40, 'METRO_Body': 3, 'DETR': 3, - 'CLRNet': 3 + 'CLRNet': 3, + 'VitPose_TopDown_WholeBody': 3 } -KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet'] +KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet', 'VitPose_TopDown_WholeBody'] MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack'] LANE_ARCH = ['CLRNet'] diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index d22df32d85..278e701b95 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -26,6 +26,7 @@ from . import keypoint_hrhrnet from . import keypoint_hrnet from . import keypoint_vitpose +from . import keypoint_vitpose_wholebody from . import jde from . import deepsort from . import fairmot @@ -60,6 +61,7 @@ from .keypoint_hrhrnet import * from .keypoint_hrnet import * from .keypoint_vitpose import * +from .keypoint_vitpose_wholebody import * from .jde import * from .deepsort import * from .fairmot import * diff --git a/ppdet/modeling/architectures/keypoint_vitpose_wholebody.py b/ppdet/modeling/architectures/keypoint_vitpose_wholebody.py new file mode 100644 index 0000000000..97811f35f6 --- /dev/null +++ b/ppdet/modeling/architectures/keypoint_vitpose_wholebody.py @@ -0,0 +1,565 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from typing import Tuple +import paddle +import numpy as np +import math +import cv2 +from ppdet.core.workspace import register, create, serializable +from .meta_arch import BaseArch +from ..keypoint_utils import transform_preds +from .. import layers as L + +__all__ = ['VitPose_TopDown_WholeBody', 'VitPoseWholeBodyPostProcess'] + + +def _get_max_preds(heatmaps): + """Get keypoint predictions from score maps. + + Note: + batch_size: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + assert isinstance(heatmaps, + np.ndarray), ('heatmaps should be numpy.ndarray') + assert heatmaps.ndim == 4, 'batch_images should be 4-ndim' + + N, K, _, W = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((N, K, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) + maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + preds[:, :, 0] = preds[:, :, 0] % W + preds[:, :, 1] = preds[:, :, 1] // W + + preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1) + return preds, maxvals + +def _taylor(heatmap, coord): + """Distribution aware coordinate decoding method. + + Note: + - heatmap height: H + - heatmap width: W + + Args: + heatmap (np.ndarray[H, W]): Heatmap of a particular joint type. + coord (np.ndarray[2,]): Coordinates of the predicted keypoints. + + Returns: + np.ndarray[2,]: Updated coordinates. + """ + H, W = heatmap.shape[:2] + px, py = int(coord[0]), int(coord[1]) + if 1 < px < W - 2 and 1 < py < H - 2: + dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1]) + dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px]) + dxx = 0.25 * ( + heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2]) + dxy = 0.25 * ( + heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] - + heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1]) + dyy = 0.25 * ( + heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] + + heatmap[py - 2 * 1][px]) + derivative = np.array([[dx], [dy]]) + hessian = np.array([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = np.linalg.inv(hessian) + offset = -hessianinv @ derivative + offset = np.squeeze(np.array(offset.T), axis=0) + coord += offset + return coord + +def _box2cs(image_size, box): + """This encodes bbox(x,y,w,h) into (center, scale) + + Args: + x, y, w, h + + Returns: + tuple: A tuple containing center and scale. + + - np.ndarray[float32](2,): Center of the bbox (x, y). + - np.ndarray[float32](2,): Scale of the bbox w & h. + """ + + x, y, w, h = box[:4] + aspect_ratio = image_size[0] / image_size[1] + center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32) + + if w > aspect_ratio * h: + h = w * 1.0 / aspect_ratio + elif w < aspect_ratio * h: + w = h * aspect_ratio + + # pixel std is 200.0 + scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) + scale = scale * 1.25 + + return center, scale + +def _gaussian_blur(heatmaps, kernel=11): + """Modulate heatmap distribution with Gaussian. + sigma = 0.3*((kernel_size-1)*0.5-1)+0.8 + sigma~=3 if k=17 + sigma=2 if k=11; + sigma~=1.5 if k=7; + sigma~=1 if k=3; + + Note: + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([N, K, H, W]): Modulated heatmap distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + height = heatmaps.shape[2] + width = heatmaps.shape[3] + for i in range(batch_size): + for j in range(num_joints): + origin_max = np.max(heatmaps[i, j]) + dr = np.zeros((height + 2 * border, width + 2 * border), + dtype=np.float32) + dr[border:-border, border:-border] = heatmaps[i, j].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmaps[i, j] = dr[border:-border, border:-border].copy() + heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j]) + return heatmaps + +def keypoints_from_heatmaps(heatmaps, + center, + scale, + unbiased=False, + post_process='default', + kernel=11, + valid_radius_factor=0.0546875, + use_udp=False, + target_type='GaussianHeatmap'): + """Get final keypoint predictions from heatmaps and transform them back to + the image. + + Note: + - batch size: N + - num keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. + center (np.ndarray[N, 2]): Center of the bounding box (x, y). + scale (np.ndarray[N, 2]): Scale of the bounding box + wrt height/width. + post_process (str/None): Choice of methods to post-process + heatmaps. Currently supported: None, 'default', 'unbiased', + 'megvii'. + unbiased (bool): Option to use unbiased decoding. Mutually + exclusive with megvii. + Note: this arg is deprecated and unbiased=True can be replaced + by post_process='unbiased' + Paper ref: Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + valid_radius_factor (float): The radius factor of the positive area + in classification heatmap for UDP. + use_udp (bool): Use unbiased data processing. + target_type (str): 'GaussianHeatmap' or 'CombinedTarget'. + GaussianHeatmap: Classification target with gaussian distribution. + CombinedTarget: The combination of classification target + (response map) and regression target (offset map). + Paper ref: Huang et al. The Devil is in the Details: Delving into + Unbiased Data Processing for Human Pose Estimation (CVPR 2020). + + Returns: + tuple: A tuple containing keypoint predictions and scores. + + - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. + - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. + """ + # Avoid being affected + # heatmaps = heatmaps.copy() + + # detect conflicts + # if unbiased: + # assert post_process not in [False, None, 'megvii'] + if post_process in ['megvii', 'unbiased']: + assert kernel > 0 + if use_udp: + assert not post_process == 'megvii' + + # normalize configs + if post_process is False: + warnings.warn( + 'post_process=False is deprecated, ' + 'please use post_process=None instead', DeprecationWarning) + post_process = None + elif post_process is True: + if unbiased is True: + warnings.warn( + 'post_process=True, unbiased=True is deprecated,' + " please use post_process='unbiased' instead", + DeprecationWarning) + post_process = 'unbiased' + else: + warnings.warn( + 'post_process=True, unbiased=False is deprecated, ' + "please use post_process='default' instead", + DeprecationWarning) + post_process = 'default' + elif post_process == 'default': + if unbiased is True: + warnings.warn( + 'unbiased=True is deprecated, please use ' + "post_process='unbiased' instead", DeprecationWarning) + post_process = 'unbiased' + + # start processing + if post_process == 'megvii': + heatmaps = _gaussian_blur(heatmaps, kernel=kernel) + + N, K, H, W = heatmaps.shape + if use_udp: + if target_type.lower() == 'GaussianHeatMap'.lower(): + preds, maxvals = _get_max_preds(heatmaps) + preds = post_dark_udp(preds, heatmaps, kernel=kernel) + elif target_type.lower() == 'CombinedTarget'.lower(): + for person_heatmaps in heatmaps: + for i, heatmap in enumerate(person_heatmaps): + kt = 2 * kernel + 1 if i % 3 == 0 else kernel + cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap) + # valid radius is in direct proportion to the height of heatmap. + valid_radius = valid_radius_factor * H + offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius + offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius + heatmaps = heatmaps[:, ::3, :] + preds, maxvals = _get_max_preds(heatmaps) + index = preds[..., 0] + preds[..., 1] * W + index += W * H * np.arange(0, N * K / 3) + index = index.astype(int).reshape(N, K // 3, 1) + preds += np.concatenate((offset_x[index], offset_y[index]), axis=2) + else: + raise ValueError('target_type should be either ' + "'GaussianHeatmap' or 'CombinedTarget'") + else: + preds, maxvals = _get_max_preds(heatmaps) + if post_process == 'unbiased': # alleviate biased coordinate + # apply Gaussian distribution modulation. + heatmaps = np.log( + np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10)) + for n in range(N): + for k in range(K): + preds[n][k] = _taylor(heatmaps[n][k], preds[n][k]) + elif post_process is not None: + # add +/-0.25 shift to the predicted locations for higher acc. + for n in range(N): + for k in range(K): + heatmap = heatmaps[n][k] + px = int(preds[n][k][0]) + py = int(preds[n][k][1]) + if 1 < px < W - 1 and 1 < py < H - 1: + diff = np.array([ + heatmap[py][px + 1] - heatmap[py][px - 1], + heatmap[py + 1][px] - heatmap[py - 1][px] + ]) + preds[n][k] += np.sign(diff) * .25 + if post_process == 'megvii': + preds[n][k] += 0.5 + + # Transform back to the image + for i in range(N): + preds[i] = transform_preds( + preds[i], center[i], scale[i], [W, H]) + + if post_process == 'megvii': + maxvals = maxvals / 255.0 + 0.5 + + return preds, maxvals + + +# def post_process_vitpose + +@register +@serializable +class VitPoseWholeBodyPostProcess(object): + def __call__(self, img_metas, output, **kwargs): + """Decode keypoints from heatmaps. + + Args: + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + output (np.ndarray[N, K, H, W]): model predicted heatmaps. + """ + img_metas = [{'center': img_metas['center'], 'scale': img_metas['scale'], 'rotation': 0, 'bbox_score': 1, 'flip_pairs': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16], [17, 20], [18, 21], [19, 22], [23, 39], [24, 38], [25, 37], [26, 36], [27, 35], [28, 34], [29, 33], [30, 32], [40, 49], [41, 48], [42, 47], [43, 46], [44, 45], [54, 58], [55, 57], [59, 68], [60, 67], [61, 66], [62, 65], [63, 70], [64, 69], [71, 77], [72, 76], [73, 75], [78, 82], [79, 81], [83, 87], [84, 86], [88, 90], [91, 112], [92, 113], [93, 114], [94, 115], [95, 116], [96, 117], [97, 118], [98, 119], [99, 120], [100, 121], [101, 122], [102, 123], [103, 124], [104, 125], [105, 126], [106, 127], [107, 128], [108, 129], [109, 130], [110, 131], [111, 132]], 'bbox_id': 0}] + + + batch_size = len(img_metas) + + if 'bbox_id' in img_metas[0]: + bbox_ids = [] + else: + bbox_ids = None + + c = np.zeros((batch_size, 2), dtype=np.float32) + s = np.zeros((batch_size, 2), dtype=np.float32) + image_paths = [] + score = np.ones(batch_size) + for i in range(batch_size): + c[i, :] = img_metas[i]['center'] + s[i, :] = img_metas[i]['scale'] + + if 'bbox_score' in img_metas[i]: + score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1) + if bbox_ids is not None: + bbox_ids.append(img_metas[i]['bbox_id']) + + preds, maxvals = keypoints_from_heatmaps( + output, + c, + s, + unbiased=False, + post_process='unbiased', + kernel=17, + valid_radius_factor=0.0546875, + use_udp=False, + target_type='GaussianHeatmap' + ) + + all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32) + all_boxes = np.zeros((batch_size, 6), dtype=np.float32) + all_preds[:, :, 0:2] = preds[:, :, 0:2] + all_preds[:, :, 2:3] = maxvals + all_boxes[:, 0:2] = c[:, 0:2] + all_boxes[:, 2:4] = s[:, 0:2] + all_boxes[:, 4] = np.prod(s * 200.0, axis=1) + all_boxes[:, 5] = score + + result = {} + + result['preds'] = all_preds + result['boxes'] = all_boxes + result['bbox_ids'] = bbox_ids + + return result + +@register +class VitPose_TopDown_WholeBody(BaseArch): + __category__ = 'architecture' + __inject__ = ['loss'] + + def __init__(self, backbone, head, loss, flip_test): + """ + VitPose network, see https://arxiv.org/pdf/2204.12484v2.pdf + + Args: + backbone (nn.Layer): backbone instance + post_process (object): `HRNetPostProcess` instance + + """ + super(VitPose_TopDown_WholeBody, self).__init__() + self.backbone = backbone + self.head = head + self.loss = loss + self.flip_test = flip_test + + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + head = create(cfg['head']) + + return { + 'backbone': backbone, + 'head': head, + } + + + + + def _forward_train(self): + + feats = self.backbone.forward_features(self.inputs['image']) + vitpost_output = self.head(feats) + return self.loss(vitpost_output, self.inputs) + + def _forward_test(self,bbox=None): + feats = self.backbone.forward_features(self.inputs['image']) + print("feats") + print(feats) + output_heatmap = self.head(feats) + + if self.flip_test: + img_flipped = self.inputs['image'].flip(3) + features_flipped = self.backbone.forward_features(img_flipped) + output_flipped_heatmap = self.head.inference_model(features_flipped, + self.flip_test) + + output_heatmap = (output_heatmap + output_flipped_heatmap) * 0.5 + + imshape = (self.inputs['im_shape'].numpy() + )[:, ::-1] if 'im_shape' in self.inputs else none + return output_heatmap + + def get_loss(self): + return self._forward_train() + + def get_pred(self): + res_lst = self._forward_test() + outputs = {'keypoint': res_lst} + return outputs + +def _get_3rd_point(a, b): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): point(x,y) + b (np.ndarray): point(x,y) + + Returns: + np.ndarray: The 3rd point. + """ + assert len(a) == 2 + assert len(b) == 2 + direction = a - b + third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) + + return third_pt + +def rotate_point(pt, angle_rad): + """Rotate a point by an angle. + + Args: + pt (list[float]): 2 dimensional point to be rotated + angle_rad (float): rotation angle by radian + + Returns: + list[float]: Rotated point. + """ + assert len(pt) == 2 + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + new_x = pt[0] * cs - pt[1] * sn + new_y = pt[0] * sn + pt[1] * cs + rotated_pt = [new_x, new_y] + + return rotated_pt + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=(0., 0.), + inv=False): + """Get the affine transform matrix, given the center/scale/rot/output_size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: The transform matrix. + """ + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + + # pixel_std is 200. + scale_tmp = scale * 200.0 + + shift = np.array(shift) + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = rotate_point([0., src_w * -0.5], rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + + + +@register +@serializable +class VitPosePreProcess(object): + def __init__(self, use_dark=False): + self.use_dark = use_dark + + def __call__(self, input=None): + trans = get_affine_transform(np.array([124. , 180.5]), np.array([1.6921875, 2.25625]), 0, np.array([288, 384])) + return input diff --git a/ppdet/modeling/backbones/hrnet.py b/ppdet/modeling/backbones/hrnet.py index 977edd69e9..00a071c0c7 100644 --- a/ppdet/modeling/backbones/hrnet.py +++ b/ppdet/modeling/backbones/hrnet.py @@ -858,6 +858,46 @@ def forward(self, inputs): return res + def forward_features(self, inputs): + x = inputs + conv1 = self.conv_layer1_1(x) + conv2 = self.conv_layer1_2(conv1) + + la1 = self.la1(conv2) + tr1 = self.tr1([la1]) + st2 = self.st2(tr1) + tr2 = self.tr2(st2) + + st3 = self.st3(tr2) + tr3 = self.tr3(st3) + + st4 = self.st4(tr3) + + if self.upsample: + # Upsampling + x0_h, x0_w = st4[0].shape[2:4] + x1 = F.upsample(st4[1], size=(x0_h, x0_w), mode='bilinear') + x2 = F.upsample(st4[2], size=(x0_h, x0_w), mode='bilinear') + x3 = F.upsample(st4[3], size=(x0_h, x0_w), mode='bilinear') + x = paddle.concat([st4[0], x1, x2, x3], 1) + return x + + if self.downsample: + y = self.incre_modules[0](st4[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](st4[i+1]) + \ + self.downsamp_modules[i](y) + y = self.final_layer(y) + return y + + res = [] + for i, layer in enumerate(st4): + if i == self.freeze_at: + layer.stop_gradient = True + if i in self.return_idx: + res.append(layer) + return res + @property def out_shape(self): if self.upsample: diff --git a/ppdet/modeling/heads/vitpose_head.py b/ppdet/modeling/heads/vitpose_head.py index 43908ed57b..3f1ce2f68d 100644 --- a/ppdet/modeling/heads/vitpose_head.py +++ b/ppdet/modeling/heads/vitpose_head.py @@ -198,7 +198,6 @@ def forward(self, x): x = self._transform_inputs(x) x = self.deconv_layers(x) x = self.final_layer(x) - return x def inference_model(self, x, flip_pairs=None):