diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index 3d4d6167..137b6324 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -589,7 +589,8 @@ def build( ), ) - fg_mask_np: Optional[np.ndarray] = None + fg_mask_np: np.ndarray | None = None + bbox_xywh: tuple[float, float, float, float] | None = None mask_annotation = frame_annotation.mask if mask_annotation is not None: if load_blobs and self.load_masks: @@ -598,10 +599,6 @@ def build( frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float) bbox_xywh = mask_annotation.bounding_box_xywh - if bbox_xywh is None and fg_mask_np is not None: - bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr) - - frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float) if frame_annotation.image is not None: image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long) @@ -618,11 +615,27 @@ def build( if image_path is None: raise ValueError("Image path is required to load images.") - image_np = load_image(self._local_path(image_path)) + no_mask = fg_mask_np is None # didn’t read the mask file + image_np = load_image( + self._local_path(image_path), try_read_alpha=no_mask + ) + if image_np.shape[0] == 4: # RGBA image + if no_mask: + fg_mask_np = image_np[3:] + frame_data.fg_probability = safe_as_tensor( + fg_mask_np, torch.float + ) + + image_np = image_np[:3] + frame_data.image_rgb = self._postprocess_image( image_np, frame_annotation.image.size, frame_data.fg_probability ) + if bbox_xywh is None and fg_mask_np is not None: + bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr) + frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float) + depth_annotation = frame_annotation.depth if ( load_blobs diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index b4443584..3e8ba35b 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -87,6 +87,15 @@ def is_train_frame( def get_bbox_from_mask( mask: np.ndarray, thr: float, decrease_quant: float = 0.05 ) -> Tuple[int, int, int, int]: + # these corner cases need to be handled in order to avoid an infinite loop + if mask.size == 0: + warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1) + return 0, 0, 1, 1 + + if not mask.min() >= 0.0: + warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1) + mask = mask.clip(min=0.0) + # bbox in xywh masks_for_box = np.zeros_like(mask) while masks_for_box.sum() <= 1.0: @@ -229,9 +238,20 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray: return im.astype(np.float32) / 255.0 -def load_image(path: str) -> np.ndarray: +def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray: + """ + Load an image from a path and return it as a numpy array. + If try_read_alpha is True, the image is read as RGBA and the alpha channel is + returned as the fourth channel. + Otherwise, the image is read as RGB and a three-channel image is returned. + """ + with Image.open(path) as pil_im: - im = np.array(pil_im.convert("RGB")) + # Check if the image has an alpha channel + if try_read_alpha and pil_im.mode == "RGBA": + im = np.array(pil_im) + else: + im = np.array(pil_im.convert("RGB")) return transpose_normalize_image(im)