Skip to content

Commit

Permalink
Finish tensor compatibility for MTCNN (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
timesler authored Sep 7, 2020
1 parent d16c225 commit 2763e34
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
13 changes: 11 additions & 2 deletions models/mtcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,11 @@ def forward(self, img, save_path=None, return_prob=False):

# Determine if a batch or single image was passed
batch_mode = True
if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
if (
not isinstance(img, (list, tuple)) and
not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
):
img = [img]
batch_boxes = [batch_boxes]
batch_probs = [batch_probs]
Expand Down Expand Up @@ -373,7 +377,11 @@ def detect(self, img, landmarks=False):
probs = np.array(probs)
points = np.array(points)

if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
if (
not isinstance(img, (list, tuple)) and
not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
):
boxes = boxes[0]
probs = probs[0]
points = points[0]
Expand All @@ -388,6 +396,7 @@ def fixed_image_standardization(image_tensor):
processed_tensor = (image_tensor - 127.5) / 128.0
return processed_tensor


def prewhiten(x):
mean = x.mean()
std = x.std()
Expand Down
11 changes: 9 additions & 2 deletions models/utils/detect_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,18 @@ def imresample(img, sz):

def crop_resize(img, box, image_size):
if isinstance(img, np.ndarray):
img = img[box[1]:box[3], box[0]:box[2]]
out = cv2.resize(
img[box[1]:box[3], box[0]:box[2]],
img,
(image_size, image_size),
interpolation=cv2.INTER_AREA
).copy()
elif isinstance(img, torch.Tensor):
img = img[box[1]:box[3], box[0]:box[2]]
out = imresample(
img.permute(2, 0, 1).unsqueeze(0).float(),
(image_size, image_size)
).byte().squeeze(0).permute(1, 2, 0)
else:
out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
return out
Expand All @@ -326,7 +333,7 @@ def save_img(img, path):


def get_size(img):
if isinstance(img, np.ndarray):
if isinstance(img, (np.ndarray, torch.Tensor)):
return img.shape[1::-1]
else:
return img.size
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import setuptools, os

PACKAGE_NAME = 'facenet-pytorch'
VERSION = '2.3.1'
VERSION = '2.4.1'
AUTHOR = 'Tim Esler'
EMAIL = '[email protected]'
DESCRIPTION = 'Pretrained Pytorch face detection and recognition models'
Expand Down

0 comments on commit 2763e34

Please sign in to comment.