Skip to content

Commit

Permalink
Fix resized image masking
Browse files Browse the repository at this point in the history
  • Loading branch information
jennydaman committed Nov 27, 2023
1 parent fa4a47f commit 149de25
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion emerald/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.0'
__version__ = '0.2.1'

DISPLAY_TITLE = r"""
_ _ _
Expand Down
20 changes: 10 additions & 10 deletions emerald/emerald.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ def emerald(model: Unet, input_path: str, mask_path: Optional[Path], brain_paths
post_processing: bool, footprint: Optional[npt.NDArray]):
img_path = str(input_path)

img, hdr = getImageData(img_path)
img_original, hdr = getImageData(img_path)
img_resized = img_original
resizeNeeded = False

if img.shape[1] != 256 or img.shape[2] != 256:
original_shape = (img.shape[2], img.shape[1])
img = __resizeData(img)
if img_original.shape[1] != 256 or img_original.shape[2] != 256:
original_shape = (img_original.shape[2], img_original.shape[1])
img_resized = __resizeData(img_original)
resizeNeeded = True

res = model.predict_mask(img)
res = model.predict_mask(img_resized)

if post_processing:
res = __postProcessing(res, no_dilation=(footprint is not None), footprint=footprint)
Expand All @@ -130,12 +131,11 @@ def emerald(model: Unet, input_path: str, mask_path: Optional[Path], brain_paths

if brain_paths:
# for whatever reason, img.shape=(38, 256, 256, 1).
if len(img.shape) == 4 and img.shape[3] == 1:
img = np.squeeze(img)
img = np.moveaxis(img, 0, -1)
if len(img_original.shape) == 4 and img_original.shape[3] == 1:
img_original = np.squeeze(img_original)
img_original = np.moveaxis(img_original, 0, -1)

# apply res mask to img
for mult, brain_path in brain_paths:
print(f'img.shape={img.shape}, res.shape={res.shape}')
overlayed_data = np.clip(res, mult, 1.0) * img
overlayed_data = np.clip(res, mult, 1.0) * img_original
save(overlayed_data, str(brain_path), hdr)

0 comments on commit 149de25

Please sign in to comment.