diff --git a/models/utils/detect_face.py b/models/utils/detect_face.py
index 5d14486..6c0c794 100644
--- a/models/utils/detect_face.py
+++ b/models/utils/detect_face.py
@@ -6,6 +6,7 @@
 import numpy as np
 import os
 import math
+from collections import defaultdict
 
 # OpenCV is optional, but required if using numpy arrays instead of PIL
 try:
@@ -108,9 +109,10 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
         im_data = []
         for k in range(len(y)):
             if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
-                img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
-                im_data.append(imresample(img_k, (24, 24)))
-        im_data = torch.cat(im_data, dim=0)
+                img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]]
+                im_data.append(img_k)
+
+        im_data = batch_resample_by_size(im_data, (24, 24), device)
         im_data = (im_data - 127.5) * 0.0078125
 
         # This is equivalent to out = rnet(im_data) to avoid GPU out of memory.
@@ -137,9 +139,10 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
         im_data = []
         for k in range(len(y)):
             if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
-                img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
-                im_data.append(imresample(img_k, (48, 48)))
-        im_data = torch.cat(im_data, dim=0)
+                img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]]
+                im_data.append(img_k)
+
+        im_data = batch_resample_by_size(im_data, (48, 48), device)
         im_data = (im_data - 127.5) * 0.0078125
 
         # This is equivalent to out = onet(im_data) to avoid GPU out of memory.
@@ -306,6 +309,52 @@ def imresample(img, sz):
     return im_data
 
 
+def batch_resample_by_size(imgs, target_size, device):
+    """
+    Resample variably-sized image crops to one size, batching crops by shape.
+
+    Crops sharing the same (height, width) are stacked and resized with a
+    single ``interpolate`` call; each result is written back at the position
+    its crop had in ``imgs``, so the output order matches the input order.
+
+    Args:
+        imgs (list of torch.Tensor): Image tensors, each (channels, height, width)
+        target_size (tuple): Target size for resampling (height, width)
+        device (torch.device): Device to perform computation on
+
+    Returns:
+        torch.Tensor: Resampled images in original order, shaped
+            (len(imgs), channels, height, width) with the dtype of the inputs.
+            An empty ``imgs`` yields a default-dtype tensor of shape (0, 3, height, width).
+    """
+    out_h, out_w = target_size[0], target_size[1]
+    if not imgs:
+        # No crops: keep the 3-channel empty batch callers already receive.
+        return torch.zeros((0, 3, out_h, out_w), device=device)
+
+    # Map each distinct (height, width) to the positions of the crops having it.
+    indices_by_size = defaultdict(list)
+    for idx, img in enumerate(imgs):
+        indices_by_size[tuple(img.shape[1:])].append(idx)
+
+    # Take channel count and dtype from the crops themselves so non-RGB or
+    # non-float32 (e.g. half-precision) inputs are not reshaped or up-cast.
+    # torch.empty is safe: every index belongs to exactly one size group.
+    resampled_imgs = torch.empty(
+        (len(imgs), imgs[0].shape[0], out_h, out_w),
+        dtype=imgs[0].dtype,
+        device=device,
+    )
+
+    for indices in indices_by_size.values():
+        # One interpolate call per distinct crop size.
+        batch = torch.stack([imgs[idx] for idx in indices]).to(device)
+        # Indexed assignment scatters the whole group back in a single operation.
+        resampled_imgs[indices] = interpolate(batch, size=target_size, mode='area')
+
+    return resampled_imgs
+
+
 def crop_resize(img, box, image_size):
     if isinstance(img, np.ndarray):
         img = img[box[1]:box[3], box[0]:box[2]]
diff --git a/setup.py b/setup.py
index 1864989..72d3692 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 import setuptools, os
 
 PACKAGE_NAME = 'facenet-pytorch'
-VERSION = '2.5.2'
+VERSION = '2.5.4'
 AUTHOR = 'Tim Esler'
 EMAIL = 'tim.esler@gmail.com'
 DESCRIPTION = 'Pretrained Pytorch face detection and recognition models'
@@ -39,8 +39,8 @@
         'numpy>=1.24.0,<2.0.0',
         'Pillow>=10.2.0,<10.3.0',
         'requests>=2.0.0,<3.0.0',
-        'torch>=2.2.0,<=2.3.0',
-        'torchvision>=0.17.0,<=0.18.0',
+        'torch>=2.2.0,<=2.4.0',
+        'torchvision>=0.17.0,<=0.19.0',
         'tqdm>=4.0.0,<5.0.0',
     ],
 )
diff --git a/tests/actions_requirements.txt b/tests/actions_requirements.txt
index b74924a..4e66bdb 100644
--- a/tests/actions_requirements.txt
+++ b/tests/actions_requirements.txt
@@ -1,7 +1,7 @@
 numpy>=1.24.0,<2.0.0
 requests>=2.0.0,<3.0.0
-torch>=2.2.0,<2.3.0
-torchvision>=0.17.0,<0.18.0
+torch>=2.2.0,<=2.4.0
+torchvision>=0.17.0,<=0.19.0
 Pillow>=10.2.0,<10.3.0
 opencv-python>=4.9.0
 scipy>=1.10.0,<2.0.0