Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files: browse the repository at this point in the history
  • Loading branch information
eamonn-zh authored Jan 16, 2025
2 parents 436ed5a + 49cf5a0 commit 5cf34d2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
25 changes: 19 additions & 6 deletions pytorch3d/implicitron/dataset/frame_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ def build(
),
)

fg_mask_np: Optional[np.ndarray] = None
fg_mask_np: np.ndarray | None = None
bbox_xywh: tuple[float, float, float, float] | None = None
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
if load_blobs and self.load_masks:
Expand All @@ -598,10 +599,6 @@ def build(
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)

bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)

frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
Expand All @@ -618,11 +615,27 @@ def build(
if image_path is None:
raise ValueError("Image path is required to load images.")

image_np = load_image(self._local_path(image_path))
no_mask = fg_mask_np is None # didn’t read the mask file
image_np = load_image(
self._local_path(image_path), try_read_alpha=no_mask
)
if image_np.shape[0] == 4: # RGBA image
if no_mask:
fg_mask_np = image_np[3:]
frame_data.fg_probability = safe_as_tensor(
fg_mask_np, torch.float
)

image_np = image_np[:3]

frame_data.image_rgb = self._postprocess_image(
image_np, frame_annotation.image.size, frame_data.fg_probability
)

if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

depth_annotation = frame_annotation.depth
if (
load_blobs
Expand Down
24 changes: 22 additions & 2 deletions pytorch3d/implicitron/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ def is_train_frame(
def get_bbox_from_mask(
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# these corner cases need to be handled in order to avoid an infinite loop
if mask.size == 0:
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
return 0, 0, 1, 1

if not mask.min() >= 0.0:
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
mask = mask.clip(min=0.0)

# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
Expand Down Expand Up @@ -229,9 +238,20 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
return im.astype(np.float32) / 255.0


def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
    """
    Load an image from a path and return it as a numpy array.

    Args:
        path: Location of the image file; passed to `Image.open`.
        try_read_alpha: If True and the file decodes in "RGBA" mode, the
            image is read as RGBA and the alpha channel is kept as the
            fourth channel. Otherwise the image is converted to RGB and a
            three-channel image is returned.

    Returns:
        The decoded pixel array after `transpose_normalize_image` has been
        applied to it.
    """
    with Image.open(path) as pil_im:
        # Alpha is kept only when the caller asked for it AND the decoded mode
        # is exactly "RGBA"; every other case goes through an RGB conversion,
        # so the RGB conversion is never done needlessly for RGBA inputs.
        keep_alpha = try_read_alpha and pil_im.mode == "RGBA"
        # Materialise the pixels while the file is still open.
        im = np.array(pil_im if keep_alpha else pil_im.convert("RGB"))

    return transpose_normalize_image(im)

Expand Down

0 comments on commit 5cf34d2

Please sign in to comment.