diff --git a/.gitignore b/.gitignore index 89f6739a..af37122a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ assets/*gif *.pyc debug cym_utils -/src \ No newline at end of file +/src +/tracking_results diff --git a/app.py b/app.py index f238906b..6b47871d 100644 --- a/app.py +++ b/app.py @@ -42,7 +42,7 @@ def get_click_prompt(click_stack, point): def get_meta_from_video(input_video): if input_video is None: return None, None, None, "" - + print("get meta information of input video") cap = cv2.VideoCapture(input_video) @@ -283,7 +283,7 @@ def seg_track_app(): input_img_seq = gr.File(label='Input Image-Seq').style(height=550) with gr.Column(scale=0.25): extract_button = gr.Button(value="extract") - fps = gr.Slider(label='fps', minimum=5, maximum=50, value=30, step=1) + fps = gr.Slider(label='fps', minimum=5, maximum=50, value=8, step=1) input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550) diff --git a/assets/840_iSXIa0hE8Ek.zip b/assets/840_iSXIa0hE8Ek.zip index f960f179..97b3ab5c 100644 Binary files a/assets/840_iSXIa0hE8Ek.zip and b/assets/840_iSXIa0hE8Ek.zip differ diff --git a/seg_track_anything.py b/seg_track_anything.py index d95d1832..44bb55d1 100644 --- a/seg_track_anything.py +++ b/seg_track_anything.py @@ -53,6 +53,12 @@ def draw_mask(img, mask, alpha=0.5, id_countour=False): return img_mask.astype(img.dtype) +def create_dir(dir_path): + if os.path.isdir(dir_path): + os.system(f"rm -r {dir_path}") + + os.makedirs(dir_path) + aot_model2ckpt = { "deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth", "deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV", @@ -63,28 +69,44 @@ def draw_mask(img, mask, alpha=0.5, id_countour=False): def tracking_objects_in_video(SegTracker, input_video, input_img_seq, fps): if input_video is not None: - return video_type_input_tracking(SegTracker, input_video) + video_name = os.path.basename(input_video).split('.')[0] elif input_img_seq is not None: - return img_seq_type_input_tracking(SegTracker, 
input_img_seq, fps) - - return None, None + file_name = input_img_seq.name.split('/')[-1].split('.')[0] + file_path = f'./assets/{file_name}' + imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)]) + video_name = file_name + else: + return None, None -def video_type_input_tracking(SegTracker, input_video): - video_name = os.path.basename(input_video).split('.')[0] + # create dir to save result + tracking_result_dir = f'{os.path.join(os.path.dirname(__file__), "tracking_results", f"{video_name}")}' + create_dir(tracking_result_dir) + io_args = { - 'input_video': f'{input_video}', - 'output_mask_dir': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_masks', - 'output_video': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.mp4', # keep same format as input video - 'output_gif': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.gif', + 'tracking_result_dir': tracking_result_dir, + 'output_mask_dir': f'{tracking_result_dir}/{video_name}_masks', + 'output_masked_frame_dir': f'{tracking_result_dir}/{video_name}_masked_frames', + 'output_video': f'{tracking_result_dir}/{video_name}_seg.mp4', # keep same format as input video + 'output_gif': f'{tracking_result_dir}/{video_name}_seg.gif', } + if input_video is not None: + return video_type_input_tracking(SegTracker, input_video, io_args, video_name) + elif input_img_seq is not None: + return img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps) + + +def video_type_input_tracking(SegTracker, input_video, io_args, video_name): + # source video to segment cap = cv2.VideoCapture(input_video) fps = cap.get(cv2.CAP_PROP_FPS) - # output masks - output_dir = io_args['output_mask_dir'] - if not os.path.exists(output_dir): - os.makedirs(output_dir) + + # create dir to save predicted mask and masked frame + output_mask_dir = io_args['output_mask_dir'] + create_dir(io_args['output_mask_dir']) + 
create_dir(io_args['output_masked_frame_dir']) + pred_list = [] masked_pred_list = [] @@ -111,7 +133,7 @@ def video_type_input_tracking(SegTracker, input_video): track_mask = SegTracker.track(frame) # find new objects, and update tracker with new objects new_obj_mask = SegTracker.find_new_objs(track_mask,seg_mask) - save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png') + save_prediction(new_obj_mask, output_mask_dir, str(frame_idx).zfill(5) + '_new.png') pred_mask = track_mask + new_obj_mask # segtracker.restart_tracker() SegTracker.add_reference(frame, pred_mask) @@ -120,7 +142,7 @@ def video_type_input_tracking(SegTracker, input_video): torch.cuda.empty_cache() gc.collect() - save_prediction(pred_mask,output_dir,str(frame_idx)+'.png') + save_prediction(pred_mask, output_mask_dir, str(frame_idx).zfill(5) + '.png') pred_list.append(pred_mask) print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()),end='\r') @@ -133,16 +155,16 @@ def video_type_input_tracking(SegTracker, input_video): ################## # draw pred mask on frame and save as a video - cap = cv2.VideoCapture(io_args['input_video']) + cap = cv2.VideoCapture(input_video) fps = cap.get(cv2.CAP_PROP_FPS) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fourcc = cv2.VideoWriter_fourcc(*"mp4v") - # if io_args['input_video'][-3:]=='mp4': + # if input_video[-3:]=='mp4': # fourcc = cv2.VideoWriter_fourcc(*"mp4v") - # elif io_args['input_video'][-3:] == 'avi': + # elif input_video[-3:] == 'avi': # fourcc = cv2.VideoWriter_fourcc(*"MJPG") # # fourcc = cv2.VideoWriter_fourcc(*"XVID") # else: @@ -158,8 +180,9 @@ def video_type_input_tracking(SegTracker, input_video): frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) pred_mask = pred_list[frame_idx] masked_frame = draw_mask(frame, pred_mask) - masked_pred_list.append(masked_frame) + 
cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_idx}.png", masked_frame[:, :, ::-1]) + masked_pred_list.append(masked_frame) masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR) out.write(masked_frame) print('frame {} writed'.format(frame_idx),end='\r') @@ -184,22 +207,13 @@ def video_type_input_tracking(SegTracker, input_video): return io_args['output_video'], f"./assets/{video_name}_pred_mask.zip" -def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): - file_name = input_img_seq.name.split('/')[-1].split('.')[0] - file_path = f'./assets/{file_name}' - imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)]) +def img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps): + + # create dir to save predicted mask and masked frame + output_mask_dir = io_args['output_mask_dir'] + create_dir(io_args['output_mask_dir']) + create_dir(io_args['output_masked_frame_dir']) - video_name = file_name - io_args = { - 'output_mask_dir': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_masks', - 'output_video': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.mp4', # keep same format as input video - 'output_gif': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.gif', - } - - # output masks - output_dir = io_args['output_mask_dir'] - if not os.path.exists(output_dir): - os.makedirs(output_dir) pred_list = [] masked_pred_list = [] @@ -210,6 +224,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): with torch.cuda.amp.autocast(): for img_path in imgs_path: + frame_name = os.path.basename(img_path).split('.')[0] frame = cv2.imread(img_path) frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) @@ -224,7 +239,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): track_mask = SegTracker.track(frame) # find new objects, and update tracker with new objects new_obj_mask = SegTracker.find_new_objs(track_mask,seg_mask) - 
save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png') + save_prediction(new_obj_mask, output_mask_dir, f'{frame_name}_new.png') pred_mask = track_mask + new_obj_mask # segtracker.restart_tracker() SegTracker.add_reference(frame, pred_mask) @@ -233,7 +248,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): torch.cuda.empty_cache() gc.collect() - save_prediction(pred_mask,output_dir,str(frame_idx)+'.png') + save_prediction(pred_mask, output_mask_dir, f'{frame_name}.png') pred_list.append(pred_mask) print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()),end='\r') @@ -248,13 +263,6 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): height, width = pred_list[0].shape fourcc = cv2.VideoWriter_fourcc(*"mp4v") - # if io_args['input_video'][-3:]=='mp4': - # fourcc = cv2.VideoWriter_fourcc(*"mp4v") - # elif io_args['input_video'][-3:] == 'avi': - # fourcc = cv2.VideoWriter_fourcc(*"MJPG") - # # fourcc = cv2.VideoWriter_fourcc(*"XVID") - # else: - # fourcc = int(cap.get(cv2.CAP_PROP_FOURCC)) out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height)) frame_idx = 0 @@ -265,6 +273,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): pred_mask = pred_list[frame_idx] masked_frame = draw_mask(frame, pred_mask) masked_pred_list.append(masked_frame) + cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_idx}.png", masked_frame[:, :, ::-1]) masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR) out.write(masked_frame) @@ -279,7 +288,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps): print("{} saved".format(io_args['output_gif'])) # zip predicted mask - os.system(f"zip -r ./assets/{video_name}_pred_mask.zip {io_args['output_mask_dir']}") + os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}") # manually release memory (after cuda out of memory) del SegTracker @@ -287,4 +296,4 @@ def 
img_seq_type_input_tracking(SegTracker, input_img_seq, fps): gc.collect() - return io_args['output_video'], f"./assets/{video_name}_pred_mask.zip" \ No newline at end of file + return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip" \ No newline at end of file diff --git a/tutorial/tutorial for Image-Sequence input.md b/tutorial/tutorial for Image-Sequence input.md index 86dedeea..32fcf782 100644 --- a/tutorial/tutorial for Image-Sequence input.md +++ b/tutorial/tutorial for Image-Sequence input.md @@ -4,13 +4,14 @@ **The structure of test-data-seq.zip must be like this. Please confirm that the image names are in ascending order.** ``` - test-data-seq - - 0.png - - 1.png - - 2.png - - 3.png + - 000000.png + - 000001.png + - 000002.png + - 000003.png .... - - x.png + - 0000xx.png ``` +**Note: Please ensure that the image file names are in ascending alphabetical order.** ## Use WebUI get test Image-Sequence data ### 1. Switch to the `Image-Seq type input` tab.