From 92cc275260dcae7cd4465fc7c39eb2c1b83796b5 Mon Sep 17 00:00:00 2001 From: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com> Date: Fri, 26 Jul 2024 14:46:58 +1000 Subject: [PATCH 1/4] Update prompt.py parameter to change mask alpha --- fastsam/prompt.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/fastsam/prompt.py b/fastsam/prompt.py index 4a2b900..4c0354d 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -4,19 +4,19 @@ import matplotlib.pyplot as plt import numpy as np import torch -from .utils import image_to_np_ndarray +from utils import image_to_np_ndarray from PIL import Image class FastSAMPrompt: - def __init__(self, image, results, device='cuda'): + def __init__(self, image, results, device='cuda') -> None: if isinstance(image, str) or isinstance(image, Image.Image): image = image_to_np_ndarray(image) self.device = device self.results = results self.img = image - + def _segment_image(self, image, bbox): if isinstance(image, Image.Image): image_array = np.array(image) @@ -91,7 +91,8 @@ def plot_to_result(self, mask_random_color=True, better_quality=True, retina=False, - withContours=True) -> np.ndarray: + withContours=True, + mask_alpha=0.6) -> np.ndarray: if isinstance(annotations[0], dict): annotations = [annotation['segmentation'] for annotation in annotations] image = self.img @@ -126,6 +127,7 @@ def plot_to_result(self, retinamask=retina, target_height=original_h, target_width=original_w, + mask_alpha=mask_alpha ) else: if isinstance(annotations[0], np.ndarray): @@ -140,6 +142,7 @@ def plot_to_result(self, retinamask=retina, target_height=original_h, target_width=original_w, + mask_alpha=mask_alpha ) if isinstance(annotations, torch.Tensor): annotations = annotations.cpu().numpy() @@ -189,9 +192,8 @@ def plot(self, mask_random_color=True, better_quality=True, retina=False, - withContours=True): - if len(annotations) == 0: - return None + withContours=True, + mask_alpha=0.6): result = self.plot_to_result( annotations, bboxes, @@ -201,6 +203,7 @@ def plot(self, better_quality, retina, withContours, + mask_alpha ) path = os.path.dirname(os.path.abspath(output_path)) @@ -221,6 +224,7 @@ def fast_show_mask( retinamask=True, target_height=960, target_width=960, + mask_alpha=0.6, ): msak_sum = annotation.shape[0] height = annotation.shape[1] @@ -235,7 +239,7 @@ def fast_show_mask( color = np.random.random((msak_sum, 1, 1, 3)) else: color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) - transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 + transparency = np.ones((msak_sum, 1, 1, 1)) * mask_alpha visual = np.concatenate([color, transparency], axis=-1) mask_image = np.expand_dims(annotation, -1) * visual @@ -278,6 +282,7 @@ def fast_show_mask_gpu( retinamask=True, target_height=960, target_width=960, + mask_alpha=0.6 ): msak_sum = annotation.shape[0] height = annotation.shape[1] @@ -292,15 +297,12 @@ def fast_show_mask_gpu( else: color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([ 30 / 255, 144 / 255, 255 / 255]).to(annotation.device) - transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 + transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * mask_alpha visual = torch.cat([color, transparency], dim=-1) mask_image = torch.unsqueeze(annotation, -1) * visual # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form. show = torch.zeros((height, weight, 4)).to(annotation.device) - try: - h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') - except: - h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight)) + h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) # Use vectorized indexing to update the values of 'show'. show[h_indices, w_indices, :] = mask_image[indices] @@ -453,4 +455,3 @@ def everything_prompt(self): if self.results == None: return [] return self.results[0].masks.data - From 302f485d4ebbe6955a3b508e206979e26403e8c2 Mon Sep 17 00:00:00 2001 From: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com> Date: Fri, 26 Jul 2024 14:48:59 +1000 Subject: [PATCH 2/4] Update prompt.py Import Bug --- fastsam/prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastsam/prompt.py b/fastsam/prompt.py index 4c0354d..dadfdf9 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import numpy as np import torch -from utils import image_to_np_ndarray +from .utils import image_to_np_ndarray from PIL import Image From 681960222d50807d105dd19c5eefb6361cfb65bd Mon Sep 17 00:00:00 2001 From: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com> Date: Fri, 26 Jul 2024 14:54:45 +1000 Subject: [PATCH 3/4] Update prompt.py bug --- fastsam/prompt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fastsam/prompt.py b/fastsam/prompt.py index dadfdf9..d7eeaf7 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -302,7 +302,10 @@ def fast_show_mask_gpu( mask_image = torch.unsqueeze(annotation, -1) * visual # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form. show = torch.zeros((height, weight, 4)).to(annotation.device) - h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') + try: + h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') + except: + h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight)) indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) # Use vectorized indexing to update the values of 'show'. show[h_indices, w_indices, :] = mask_image[indices] From 7a278d4d4dab21dbf42ea9e5743e2fcabc09facf Mon Sep 17 00:00:00 2001 From: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:22:48 +1000 Subject: [PATCH 4/4] Update prompt.py Rollback changes --- fastsam/prompt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastsam/prompt.py b/fastsam/prompt.py index d7eeaf7..a14be68 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -194,6 +194,8 @@ def plot(self, retina=False, withContours=True, mask_alpha=0.6): + if len(annotations) == 0: + return None result = self.plot_to_result( annotations, bboxes,