
Commit 02bec3b

Update tutorial and saved format for Image-Seq input
yamy-cheng committed Apr 28, 2023
1 parent 02d319c commit 02bec3b
Showing 5 changed files with 65 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -7,4 +7,5 @@ assets/*gif
*.pyc
debug
cym_utils
/src
/src
/tracking_results
4 changes: 2 additions & 2 deletions app.py
@@ -42,7 +42,7 @@ def get_click_prompt(click_stack, point):
def get_meta_from_video(input_video):
if input_video is None:
return None, None, None, ""

print("get meta information of input video")
cap = cv2.VideoCapture(input_video)

@@ -283,7 +283,7 @@ def seg_track_app():
input_img_seq = gr.File(label='Input Image-Seq').style(height=550)
with gr.Column(scale=0.25):
extract_button = gr.Button(value="extract")
fps = gr.Slider(label='fps', minimum=5, maximum=50, value=30, step=1)
fps = gr.Slider(label='fps', minimum=5, maximum=50, value=8, step=1)

input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)

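For context on the Image-Seq path: the `extract` button next to this slider is expected to unpack the uploaded archive into `./assets/<archive-stem>/`, which is where `seg_track_anything.py` later reads the frames from. The actual handler in `app.py` is not shown in this diff; the sketch below is only a plausible stand-in (the function name and `assets_dir` default are assumptions), included to make the expected on-disk layout concrete.

```python
import os
import zipfile

def extract_img_seq(input_img_seq, assets_dir="./assets"):
    """Hypothetical extract handler: unzip an Image-Seq archive into
    ./assets/<archive-stem>/ so the tracker can read the frames from there."""
    # Gradio's File component exposes the uploaded temp file via .name,
    # which is also how seg_track_anything.py derives the sequence name.
    stem = os.path.basename(input_img_seq.name).split('.')[0]
    target = os.path.join(assets_dir, stem)
    os.makedirs(target, exist_ok=True)
    with zipfile.ZipFile(input_img_seq.name) as zf:
        # The zip is assumed to already contain a top-level <stem>/ folder,
        # matching the structure required by the tutorial.
        zf.extractall(assets_dir)
    return sorted(os.listdir(target))
```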
Binary file modified assets/840_iSXIa0hE8Ek.zip
101 changes: 55 additions & 46 deletions seg_track_anything.py
@@ -53,6 +53,12 @@ def draw_mask(img, mask, alpha=0.5, id_countour=False):

return img_mask.astype(img.dtype)

def create_dir(dir_path):
if os.path.isdir(dir_path):
os.system(f"rm -r {dir_path}")

os.makedirs(dir_path)
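An editor's note on the new `create_dir` helper above: it shells out to `rm -r`, which assumes a POSIX shell and paths that need no quoting. A functionally equivalent, more portable sketch (not part of this commit) would use `shutil`:

```python
import os
import shutil

def create_dir(dir_path):
    # Drop any stale results from a previous run, then recreate the directory.
    # shutil.rmtree avoids spawning a shell and handles spaces in paths.
    if os.path.isdir(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)
```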

aot_model2ckpt = {
"deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth",
"deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV",
@@ -63,28 +69,44 @@ def draw_mask(img, mask, alpha=0.5, id_countour=False):
def tracking_objects_in_video(SegTracker, input_video, input_img_seq, fps):

if input_video is not None:
return video_type_input_tracking(SegTracker, input_video)
video_name = os.path.basename(input_video).split('.')[0]
elif input_img_seq is not None:
return img_seq_type_input_tracking(SegTracker, input_img_seq, fps)

return None, None
file_name = input_img_seq.name.split('/')[-1].split('.')[0]
file_path = f'./assets/{file_name}'
imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
video_name = file_name
else:
return None, None

def video_type_input_tracking(SegTracker, input_video):
video_name = os.path.basename(input_video).split('.')[0]
# create dir to save result
tracking_result_dir = f'{os.path.join(os.path.dirname(__file__), "tracking_results", f"{video_name}")}'
create_dir(tracking_result_dir)

io_args = {
'input_video': f'{input_video}',
'output_mask_dir': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_masks',
'output_video': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.mp4', # keep same format as input video
'output_gif': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.gif',
'tracking_result_dir': tracking_result_dir,
'output_mask_dir': f'{tracking_result_dir}/{video_name}_masks',
'output_masked_frame_dir': f'{tracking_result_dir}/{video_name}_masked_frames',
'output_video': f'{tracking_result_dir}/{video_name}_seg.mp4', # keep same format as input video
'output_gif': f'{tracking_result_dir}/{video_name}_seg.gif',
}

if input_video is not None:
return video_type_input_tracking(SegTracker, input_video, io_args, video_name)
elif input_img_seq is not None:
return img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps)


def video_type_input_tracking(SegTracker, input_video, io_args, video_name):

# source video to segment
cap = cv2.VideoCapture(input_video)
fps = cap.get(cv2.CAP_PROP_FPS)
# output masks
output_dir = io_args['output_mask_dir']
if not os.path.exists(output_dir):
os.makedirs(output_dir)

# create dir to save predicted mask and masked frame
output_mask_dir = io_args['output_mask_dir']
create_dir(io_args['output_mask_dir'])
create_dir(io_args['output_masked_frame_dir'])

pred_list = []
masked_pred_list = []

@@ -111,7 +133,7 @@ def video_type_input_tracking(SegTracker, input_video):
track_mask = SegTracker.track(frame)
# find new objects, and update tracker with new objects
new_obj_mask = SegTracker.find_new_objs(track_mask,seg_mask)
save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')
save_prediction(new_obj_mask, output_mask_dir, str(frame_idx).zfill(5) + '_new.png')
pred_mask = track_mask + new_obj_mask
# segtracker.restart_tracker()
SegTracker.add_reference(frame, pred_mask)
@@ -120,7 +142,7 @@ def video_type_input_tracking(SegTracker, input_video):
torch.cuda.empty_cache()
gc.collect()

save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')
save_prediction(pred_mask, output_mask_dir, str(frame_idx).zfill(5) + '.png')
pred_list.append(pred_mask)

print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()),end='\r')
@@ -133,16 +155,16 @@ def video_type_input_tracking(SegTracker, input_video):
##################

# draw pred mask on frame and save as a video
cap = cv2.VideoCapture(io_args['input_video'])
cap = cv2.VideoCapture(input_video)
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# if io_args['input_video'][-3:]=='mp4':
# if input_video[-3:]=='mp4':
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# elif io_args['input_video'][-3:] == 'avi':
# elif input_video[-3:] == 'avi':
# fourcc = cv2.VideoWriter_fourcc(*"MJPG")
# # fourcc = cv2.VideoWriter_fourcc(*"XVID")
# else:
@@ -158,8 +180,9 @@ def video_type_input_tracking(SegTracker, input_video):
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
pred_mask = pred_list[frame_idx]
masked_frame = draw_mask(frame, pred_mask)
masked_pred_list.append(masked_frame)
cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_idx}.png", masked_frame[:, :, ::-1])

masked_pred_list.append(masked_frame)
masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)
out.write(masked_frame)
print('frame {} writed'.format(frame_idx),end='\r')
@@ -184,22 +207,13 @@ def video_type_input_tracking(SegTracker, input_video):
return io_args['output_video'], f"./assets/{video_name}_pred_mask.zip"


def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):
file_name = input_img_seq.name.split('/')[-1].split('.')[0]
file_path = f'./assets/{file_name}'
imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
def img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps):

# create dir to save predicted mask and masked frame
output_mask_dir = io_args['output_mask_dir']
create_dir(io_args['output_mask_dir'])
create_dir(io_args['output_masked_frame_dir'])

video_name = file_name
io_args = {
'output_mask_dir': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_masks',
'output_video': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.mp4', # keep same format as input video
'output_gif': f'{os.path.join(os.path.dirname(__file__), "assets")}/{video_name}_seg.gif',
}

# output masks
output_dir = io_args['output_mask_dir']
if not os.path.exists(output_dir):
os.makedirs(output_dir)
pred_list = []
masked_pred_list = []

@@ -210,6 +224,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):

with torch.cuda.amp.autocast():
for img_path in imgs_path:
frame_name = os.path.basename(img_path).split('.')[0]
frame = cv2.imread(img_path)
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)

@@ -224,7 +239,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):
track_mask = SegTracker.track(frame)
# find new objects, and update tracker with new objects
new_obj_mask = SegTracker.find_new_objs(track_mask,seg_mask)
save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')
save_prediction(new_obj_mask, output_mask_dir, f'{frame_name}_new.png')
pred_mask = track_mask + new_obj_mask
# segtracker.restart_tracker()
SegTracker.add_reference(frame, pred_mask)
@@ -233,7 +248,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):
torch.cuda.empty_cache()
gc.collect()

save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')
save_prediction(pred_mask, output_mask_dir, f'{frame_name}.png')
pred_list.append(pred_mask)

print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()),end='\r')
@@ -248,13 +263,6 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):
height, width = pred_list[0].shape
fourcc = cv2.VideoWriter_fourcc(*"mp4v")

# if io_args['input_video'][-3:]=='mp4':
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# elif io_args['input_video'][-3:] == 'avi':
# fourcc = cv2.VideoWriter_fourcc(*"MJPG")
# # fourcc = cv2.VideoWriter_fourcc(*"XVID")
# else:
# fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

frame_idx = 0
@@ -265,6 +273,7 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):
pred_mask = pred_list[frame_idx]
masked_frame = draw_mask(frame, pred_mask)
masked_pred_list.append(masked_frame)
cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_idx}.png", masked_frame[:, :, ::-1])

masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)
out.write(masked_frame)
@@ -279,12 +288,12 @@ def img_seq_type_input_tracking(SegTracker, input_img_seq, fps):
print("{} saved".format(io_args['output_gif']))

# zip predicted mask
os.system(f"zip -r ./assets/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")
os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")

# manually release memory (after cuda out of memory)
del SegTracker
torch.cuda.empty_cache()
gc.collect()


return io_args['output_video'], f"./assets/{video_name}_pred_mask.zip"
return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"
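Taken together, the new `io_args` move the per-run outputs from `./assets/` to `./tracking_results/<video_name>/`. A rough sketch of the resulting layout, based on the paths above (the predicted-mask zip location shown here is the one used for Image-Seq input; for video input the zip still appears to be written under `./assets/`):

```
- tracking_results
    - <video_name>
        - <video_name>_masks            # per-frame prediction masks (PNG)
        - <video_name>_masked_frames    # frames with the masks drawn on
        - <video_name>_seg.mp4
        - <video_name>_seg.gif
        - <video_name>_pred_mask.zip    # Image-Seq input, per this diff
```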
11 changes: 6 additions & 5 deletions tutorial/tutorial for Image-Sequence input.md
@@ -4,13 +4,14 @@
**The structure of test-data-seq.zip must look like the example below. Please confirm that the image names are in ascending order.**
```
- test-data-seq
- 0.png
- 1.png
- 2.png
- 3.png
- 000000.png
- 000001.png
- 000002.png
- 000003.png
....
- x.png
- 0000xx.png
```
**Note: Please ensure that the image names sort in ascending alphabetical order.**
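If your frames are not already zero-padded, a small helper can rename them while packing the zip into the expected `<archive-stem>/<frame>.png` layout. This is an editor's sketch, not part of the repository; the function name, `pad` width, and defaults are illustrative:

```python
import os
import zipfile

def pack_img_seq(frames_dir, out_zip="test-data-seq.zip", pad=6):
    """Pack frames_dir into out_zip as <stem>/000000.png, <stem>/000001.png, ...

    Assumes the frames are already in the intended order when sorted by
    their current filenames.
    """
    stem = os.path.splitext(os.path.basename(out_zip))[0]
    names = sorted(os.listdir(frames_dir))
    with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf:
        for i, name in enumerate(names):
            ext = os.path.splitext(name)[1]
            zf.write(os.path.join(frames_dir, name), f"{stem}/{str(i).zfill(pad)}{ext}")
```

For example, `pack_img_seq("./my_frames")` would produce `test-data-seq.zip` containing a `test-data-seq` folder with frames named `000000.png`, `000001.png`, and so on.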

## Use the WebUI to get test Image-Sequence data
### 1. Switch to the `Image-Seq type input` tab.
