diff --git a/.gitignore b/.gitignore index 5cb5580..a8e339e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ *.pyd .DS_Store .idea -weights +weights/*.pt build/ *.egg-info/ gradio_cached_examples \ No newline at end of file diff --git a/Inference.py b/Inference.py index 61b70fc..083839b 100644 --- a/Inference.py +++ b/Inference.py @@ -4,6 +4,9 @@ import torch from PIL import Image from utils.tools import convert_box_xywh_to_xyxy +import glob +from tqdm import tqdm +import os def parse_args(): @@ -12,7 +15,7 @@ def parse_args(): "--model_path", type=str, default="./weights/FastSAM.pt", help="model" ) parser.add_argument( - "--img_path", type=str, default="./images/dogs.jpg", help="path to image file" + "--img_path", type=str, default="./images/", help="This can be a folder or just path to one image (single inference)" ) parser.add_argument("--imgsz", type=int, default=1024, help="image size") parser.add_argument( @@ -28,7 +31,7 @@ def parse_args(): "--conf", type=float, default=0.4, help="object confidence threshold" ) parser.add_argument( - "--output", type=str, default="./output/", help="image save path" + "--output", type=str, default="output", help="folder for saving outputs" ) parser.add_argument( "--randomcolor", type=bool, default=True, help="mask random color" @@ -71,13 +74,13 @@ def parse_args(): return parser.parse_args() -def main(args): - # load model - model = FastSAM(args.model_path) - args.point_prompt = ast.literal_eval(args.point_prompt) - args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt)) - args.point_label = ast.literal_eval(args.point_label) - input = Image.open(args.img_path) +def single_infer(img_path, model): + image_name = img_path.split("/")[-1] + image_dir = "/".join(img_path.split("/")[:-2]) + output_dir = os.path.join(image_dir,args.output) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + input = Image.open(img_path) input = input.convert("RGB") everything_results = model( input, @@ -86,7 +89,7 @@ def main(args): imgsz=args.imgsz, conf=args.conf, iou=args.iou - ) + ) bboxes = None points = None point_label = None @@ -106,7 +109,7 @@ def main(args): ann = prompt_process.everything_prompt() prompt_process.plot( annotations=ann, - output_path=args.output+args.img_path.split("/")[-1], + output_path=os.path.join(output_dir, image_name), bboxes = bboxes, points = points, point_label = point_label, @@ -115,6 +118,17 @@ def main(args): ) +def main(args): + # load model + model = FastSAM(args.model_path) + args.point_prompt = ast.literal_eval(args.point_prompt) + args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt)) + args.point_label = ast.literal_eval(args.point_label) + if os.path.isdir(args.img_path): + for img_path in tqdm(glob.glob(os.path.join(args.img_path,"*.jpg"))): + single_infer(img_path, model) + else: + single_infer(args.img_path, model) if __name__ == "__main__": diff --git a/fastsam/prompt.py b/fastsam/prompt.py index dde50ac..9ba7853 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -101,10 +101,16 @@ def plot_to_result(self, better_quality=True, retina=False, withContours=True) -> np.ndarray: - if isinstance(annotations[0], dict): - annotations = [annotation['segmentation'] for annotation in annotations] + image = self.img image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if len(annotations) == 0: + return cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # if cannot detect anything, return original image + + if isinstance(annotations[0], dict): + annotations = [annotation['segmentation'] for annotation in annotations] + original_h = image.shape[0] original_w = image.shape[1] if sys.platform == "darwin": diff --git a/output/cat.jpg b/output/cat.jpg deleted file mode 100644 index e4764de..0000000 Binary files a/output/cat.jpg and /dev/null differ diff --git a/output/dogs.jpg b/output/dogs.jpg deleted file mode 100644 index a026bce..0000000 Binary files a/output/dogs.jpg and /dev/null differ