From 2763e34b556d071a072e568892d2bdfd389c37b0 Mon Sep 17 00:00:00 2001 From: Tim Esler Date: Mon, 7 Sep 2020 16:50:12 -0700 Subject: [PATCH] Finish tensor compatibility for MTCNN (#117) --- models/mtcnn.py | 13 +++++++++++-- models/utils/detect_face.py | 11 +++++++++-- setup.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/models/mtcnn.py b/models/mtcnn.py index 06fd63cd..bbe3ba12 100644 --- a/models/mtcnn.py +++ b/models/mtcnn.py @@ -248,7 +248,11 @@ def forward(self, img, save_path=None, return_prob=False): # Determine if a batch or single image was passed batch_mode = True - if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4): + if ( + not isinstance(img, (list, tuple)) and + not (isinstance(img, np.ndarray) and len(img.shape) == 4) and + not (isinstance(img, torch.Tensor) and len(img.shape) == 4) + ): img = [img] batch_boxes = [batch_boxes] batch_probs = [batch_probs] @@ -373,7 +377,11 @@ def detect(self, img, landmarks=False): probs = np.array(probs) points = np.array(points) - if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4): + if ( + not isinstance(img, (list, tuple)) and + not (isinstance(img, np.ndarray) and len(img.shape) == 4) and + not (isinstance(img, torch.Tensor) and len(img.shape) == 4) + ): boxes = boxes[0] probs = probs[0] points = points[0] @@ -388,6 +396,7 @@ def fixed_image_standardization(image_tensor): processed_tensor = (image_tensor - 127.5) / 128.0 return processed_tensor + def prewhiten(x): mean = x.mean() std = x.std() diff --git a/models/utils/detect_face.py b/models/utils/detect_face.py index 2ab6cf18..d0671104 100644 --- a/models/utils/detect_face.py +++ b/models/utils/detect_face.py @@ -308,11 +308,18 @@ def imresample(img, sz): def crop_resize(img, box, image_size): if isinstance(img, np.ndarray): + img = img[box[1]:box[3], box[0]:box[2]] out = cv2.resize( - img[box[1]:box[3], box[0]:box[2]], + img, (image_size, image_size), interpolation=cv2.INTER_AREA ).copy() + elif isinstance(img, torch.Tensor): + img = img[box[1]:box[3], box[0]:box[2]] + out = imresample( + img.permute(2, 0, 1).unsqueeze(0).float(), + (image_size, image_size) + ).byte().squeeze(0).permute(1, 2, 0) else: out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR) return out @@ -326,7 +333,7 @@ def save_img(img, path): def get_size(img): - if isinstance(img, np.ndarray): + if isinstance(img, (np.ndarray, torch.Tensor)): return img.shape[1::-1] else: return img.size diff --git a/setup.py b/setup.py index c640fa05..1d2b8c18 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import setuptools, os PACKAGE_NAME = 'facenet-pytorch' -VERSION = '2.3.1' +VERSION = '2.4.1' AUTHOR = 'Tim Esler' EMAIL = 'tim.esler@gmail.com' DESCRIPTION = 'Pretrained Pytorch face detection and recognition models'