diff --git a/efficientnet/preprocessing.py b/efficientnet/preprocessing.py index ea9ad56..b040756 100644 --- a/efficientnet/preprocessing.py +++ b/efficientnet/preprocessing.py @@ -30,21 +30,38 @@ def center_crop_and_resize(image, image_size, crop_padding=32, interpolation="bi assert image.ndim in {2, 3} assert interpolation in MAP_INTERPOLATION_TO_ORDER.keys() - h, w = image.shape[:2] + in_h, in_w = image.shape[:2] - padded_center_crop_size = int( - (image_size / (image_size + crop_padding)) * min(h, w) - ) - offset_height = ((h - padded_center_crop_size) + 1) // 2 - offset_width = ((w - padded_center_crop_size) + 1) // 2 + if isinstance(image_size, (int, float)): + out_h = out_w = image_size + else: + out_h, out_w = image_size + + if isinstance(crop_padding, (int, float)): + crop_padding_h = crop_padding_w = crop_padding + else: + crop_padding_h, crop_padding_w = crop_padding + + padded_center_crop_shape_post_scaling = (out_h + crop_padding_h, + out_w + crop_padding_w) + + inv_scale = min(in_h / padded_center_crop_shape_post_scaling[0], + in_w / padded_center_crop_shape_post_scaling[1]) + + unpadded_center_crop_size_pre_scaling = (round(out_h * inv_scale), + round(out_w * inv_scale)) + + offset_h = ((in_h - unpadded_center_crop_size_pre_scaling[0]) + 1) // 2 + offset_w = ((in_w - unpadded_center_crop_size_pre_scaling[1]) + 1) // 2 image_crop = image[ - offset_height : padded_center_crop_size + offset_height, - offset_width : padded_center_crop_size + offset_width, + offset_h : unpadded_center_crop_size_pre_scaling[0] + offset_h, + offset_w : unpadded_center_crop_size_pre_scaling[1] + offset_w, ] + resized_image = resize( image_crop, - (image_size, image_size), + (out_h, out_w), order=MAP_INTERPOLATION_TO_ORDER[interpolation], preserve_range=True, )