From 149de254384caaba9da1d1a99ee527398be3f4be Mon Sep 17 00:00:00 2001 From: Jennings Zhang Date: Mon, 27 Nov 2023 18:04:53 -0500 Subject: [PATCH] Fix resized image masking --- emerald/__init__.py | 2 +- emerald/emerald.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/emerald/__init__.py b/emerald/__init__.py index c8cf5f3..ea5e3f7 100644 --- a/emerald/__init__.py +++ b/emerald/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.2.0' +__version__ = '0.2.1' DISPLAY_TITLE = r""" _ _ _ diff --git a/emerald/emerald.py b/emerald/emerald.py index 4cd26cb..93ccefd 100644 --- a/emerald/emerald.py +++ b/emerald/emerald.py @@ -101,15 +101,16 @@ def emerald(model: Unet, input_path: str, mask_path: Optional[Path], brain_paths post_processing: bool, footprint: Optional[npt.NDArray]): img_path = str(input_path) - img, hdr = getImageData(img_path) + img_original, hdr = getImageData(img_path) + img_resized = img_original resizeNeeded = False - if img.shape[1] != 256 or img.shape[2] != 256: - original_shape = (img.shape[2], img.shape[1]) - img = __resizeData(img) + if img_original.shape[1] != 256 or img_original.shape[2] != 256: + original_shape = (img_original.shape[2], img_original.shape[1]) + img_resized = __resizeData(img_original) resizeNeeded = True - res = model.predict_mask(img) + res = model.predict_mask(img_resized) if post_processing: res = __postProcessing(res, no_dilation=(footprint is not None), footprint=footprint) @@ -130,12 +131,11 @@ def emerald(model: Unet, input_path: str, mask_path: Optional[Path], brain_paths if brain_paths: # for whatever reason, img.shape=(38, 256, 256, 1). - if len(img.shape) == 4 and img.shape[3] == 1: - img = np.squeeze(img) - img = np.moveaxis(img, 0, -1) + if len(img_original.shape) == 4 and img_original.shape[3] == 1: + img_original = np.squeeze(img_original) + img_original = np.moveaxis(img_original, 0, -1) # apply res mask to img for mult, brain_path in brain_paths: - print(f'img.shape={img.shape}, res.shape={res.shape}') - overlayed_data = np.clip(res, mult, 1.0) * img + overlayed_data = np.clip(res, mult, 1.0) * img_original save(overlayed_data, str(brain_path), hdr)