Skip to content

Commit

Permalink
Fixed CenterCropPad incorrectly using fill value for the mask (#2195)
Browse files Browse the repository at this point in the history
  • Loading branch information
iRyoka authored Dec 13, 2024
1 parent 8e30be0 commit 852dfef
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
20 changes: 20 additions & 0 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,26 @@ def apply(
)
return super().apply(img, crop_coords, **params)

def apply_to_mask(
self,
mask: np.ndarray,
crop_coords: Any,
**params: Any,
) -> np.ndarray:
pad_params = params.get("pad_params")
if pad_params is not None:
mask = fgeometric.pad_with_params(
mask,
pad_params["pad_top"],
pad_params["pad_bottom"],
pad_params["pad_left"],
pad_params["pad_right"],
border_mode=self.border_mode,
value=self.fill_mask,
)
# Note' that super().apply would apply the padding twice as it is looped to this.apply
return BaseCrop.apply(self, mask, crop_coords=crop_coords, **params)

def apply_to_bboxes(
self,
bboxes: np.ndarray,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,27 @@ def test_pad_position_equivalence(
result2["keypoints"],
err_msg=f"Keypoints don't match for position {pad_position}"
)

def test_base_crop_and_pad_fill():
# tests whether BaseCropAndPad usues correct values for constant borders
c = A.CenterCrop(4, 4, pad_if_needed=True, fill=100, fill_mask=200)
c1 = A.CenterCrop(4, 4, pad_if_needed=True, fill=201)

im = np.zeros((2, 6, 3)).astype(np.float32)
msk = np.zeros((2, 6)).astype(np.uint8)

out = c(image=im, mask=msk)
out1 = c1(image=im, mask=msk)

expected_img = np.ones((4, 4, 3)).astype(np.float32)
expected_img[1:3, ...] = 0

expected_msk = np.ones((4, 4)).astype(np.uint8)
expected_msk[1:3, ...] = 0

assert np.all(out["image"] == expected_img * 100)
assert np.all(out["mask"] == expected_msk * 200)


assert np.all(out1["image"] == expected_img * 201)
assert np.all(out1["mask"] == expected_msk * 0) # 0 is the default for fill_mask

0 comments on commit 852dfef

Please sign in to comment.