diff --git a/fastsam/prompt.py b/fastsam/prompt.py
index 4a2b900..a14be68 100644
--- a/fastsam/prompt.py
+++ b/fastsam/prompt.py
@@ -10,13 +10,13 @@ class FastSAMPrompt:
-    def __init__(self, image, results, device='cuda'):
+    def __init__(self, image, results, device='cuda') -> None:
         if isinstance(image, str) or isinstance(image, Image.Image):
             image = image_to_np_ndarray(image)
         self.device = device
         self.results = results
         self.img = image
-    
+
     def _segment_image(self, image, bbox):
         if isinstance(image, Image.Image):
             image_array = np.array(image)
@@ -91,7 +91,8 @@ def plot_to_result(self,
                        mask_random_color=True,
                        better_quality=True,
                        retina=False,
-                       withContours=True) -> np.ndarray:
+                       withContours=True,
+                       mask_alpha=0.6) -> np.ndarray:
         if isinstance(annotations[0], dict):
             annotations = [annotation['segmentation'] for annotation in annotations]
         image = self.img
@@ -126,6 +127,7 @@ def plot_to_result(self,
                 retinamask=retina,
                 target_height=original_h,
                 target_width=original_w,
+                mask_alpha=mask_alpha
             )
         else:
             if isinstance(annotations[0], np.ndarray):
@@ -140,6 +142,7 @@ def plot_to_result(self,
                 retinamask=retina,
                 target_height=original_h,
                 target_width=original_w,
+                mask_alpha=mask_alpha
             )
         if isinstance(annotations, torch.Tensor):
             annotations = annotations.cpu().numpy()
@@ -189,7 +192,8 @@ def plot(self,
              mask_random_color=True,
              better_quality=True,
              retina=False,
-             withContours=True):
+             withContours=True,
+             mask_alpha=0.6):
         if len(annotations) == 0:
             return None
         result = self.plot_to_result(
@@ -201,6 +205,7 @@ def plot(self,
             better_quality,
             retina,
             withContours,
+            mask_alpha
         )

         path = os.path.dirname(os.path.abspath(output_path))
@@ -221,6 +226,7 @@ def fast_show_mask(
         retinamask=True,
         target_height=960,
         target_width=960,
+        mask_alpha=0.6,
     ):
         msak_sum = annotation.shape[0]
         height = annotation.shape[1]
@@ -235,7 +241,7 @@ def fast_show_mask(
             color = np.random.random((msak_sum, 1, 1, 3))
         else:
             color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
-        transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
+        transparency = np.ones((msak_sum, 1, 1, 1)) * mask_alpha
         visual = np.concatenate([color, transparency], axis=-1)
         mask_image = np.expand_dims(annotation, -1) * visual
@@ -278,6 +284,7 @@ def fast_show_mask_gpu(
         retinamask=True,
         target_height=960,
         target_width=960,
+        mask_alpha=0.6
     ):
         msak_sum = annotation.shape[0]
         height = annotation.shape[1]
@@ -292,7 +299,7 @@ def fast_show_mask_gpu(
         else:
             color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
                 30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
-        transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
+        transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * mask_alpha
         visual = torch.cat([color, transparency], dim=-1)
         mask_image = torch.unsqueeze(annotation, -1) * visual
         # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
@@ -453,4 +460,3 @@ def everything_prompt(self):
         if self.results == None:
             return []
         return self.results[0].masks.data
-