Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

How do I generate prompts for tracking anything? #500

Open
DapengFeng opened this issue Dec 18, 2024 · 1 comment
Open

How do I generate prompts for tracking anything? #500

DapengFeng opened this issue Dec 18, 2024 · 1 comment

Comments

@DapengFeng
Copy link

DapengFeng commented Dec 18, 2024

import torch
from sam2.build_sam import build_sam2_video_predictor
from sam2.utils.amg import build_point_grid
import numpy as np
import os
import cv2
import rerun as rr
import time
import matplotlib.pyplot as plt

# Start a Rerun recording named "tum_rgbd_sam2"; spawn=True also launches a
# viewer that receives everything show_anns logs.
rr.init("tum_rgbd_sam2", spawn=True)

# Qualitative palette used by show_anns to colour masks by object id.
cmap = plt.get_cmap("tab10")


def show_anns(anns, img, idxs, borders=True):
    """Render per-object masks beside a frame and log the pair to Rerun.

    The frame and a colour-coded mask image are stacked side by side with
    ``np.hstack`` and logged under ``"image"`` as a BGR image, stamped with
    the current wall-clock time.

    Args:
        anns: Per-object masks (must support ``len()`` and iteration). Each
            element is reshaped to ``(h, w, 1)`` and moved with ``.cpu()``;
            presumably a boolean torch tensor such as ``masks > 0.0`` from the
            SAM2 video predictor -- confirm against callers.
        img: Frame whose first two dims are ``(h, w)``; treated as BGR
            ``uint8`` because it is logged with ``rr.ColorModel.BGR``.
        idxs: Object ids parallel to ``anns``. A ``None`` id uses palette
            colour 0.
        borders: When true, draw each mask's external contour in red.

    Returns:
        None. Nothing is logged when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    mask_img = np.zeros_like(img, dtype=np.float32)
    h, w = img.shape[:2]
    for ann, obj_id in zip(anns, idxs):
        cmap_idx = 0 if obj_id is None else obj_id
        # Wrap ids around the palette: integer lookups at or past cmap.N all
        # map to the single "over" colour, making objects >= 10 indistinguishable.
        rgb = np.array(cmap(cmap_idx % cmap.N)[:3])
        # Matplotlib yields RGB, but this image is BGR, so reverse the channels.
        color = rgb[::-1]
        m = ann.reshape(h, w, 1).cpu().numpy()
        mask_img += m * color.reshape(1, 1, -1)
        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # epsilon is a pixel distance, so 0.01 only drops (near-)collinear
            # points; it does not visibly smooth the outline.
            contours = [
                cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                for contour in contours
            ]
            # (0, 0, 1) is red in BGR; mask_img has 3 channels, so no alpha.
            cv2.drawContours(mask_img, contours, -1, (0, 0, 1), thickness=1)

    # Overlapping masks sum past 1.0; clip so the uint8 cast cannot wrap.
    mask_u8 = (np.clip(mask_img, 0.0, 1.0) * 255).astype(np.uint8)
    img = np.hstack((img, mask_u8))

    rr.set_time_seconds("timestamp", time.time())
    rr.log("image", rr.Image(img, rr.ColorModel.BGR))

# Model: SAM2.1 Hiera-Tiny video predictor (checkpoint + matching config).
sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"

sam2 = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda", apply_postprocessing=False)

# Sorted RGB frame filenames; images[frame_idx] is used below to reload the
# frame that propagate_in_video just produced masks for.
# NOTE(review): assumes init_state orders frames the same way as sorted() and
# that the folder holds only frame images -- confirm against SAM2's loader.
images = sorted(os.listdir("/home/dapeng/data/TUM_RGBD/rgbd_dataset_freiburg1_desk/rgb"))

# Prompt grid: build_point_grid(4) presumably yields 4*4 (x, y) points
# normalised to [0, 1]; multiplying by (640, 480) -- read here as
# (width, height) -- turns them into pixel coordinates.
origin_image_size = (640, 480)
point_coords = build_point_grid(4)
# Shape (num_points, 1, 2): each grid point becomes a single-click prompt.
point_coords = point_coords.reshape(-1, 1, 2) * np.array(origin_image_size).reshape(1, 1, 2)
# One label of 1 per click (presumably "positive/foreground" in SAM's convention).
point_labels = np.ones((point_coords.shape[0], 1), dtype=np.int32)

# No-grad inference with bfloat16 autocast on CUDA.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = sam2.init_state("/home/dapeng/data/TUM_RGBD/rgbd_dataset_freiburg1_desk/rgb")
    # Frame 0: one object per grid point, object id == grid index i.
    # The returned values are overwritten every iteration and never used.
    for i in range(point_coords.shape[0]):
      frame_idx, object_ids, masks = sam2.add_new_points_or_box(state, 0, i, points=point_coords[i], labels=point_labels[i])
    for frame_idx, object_ids, masks in sam2.propagate_in_video(state):
      # NOTE(review): re-adds the same fixed grid prompts on every yielded
      # frame, mutating `state` while propagate_in_video is still iterating.
      # Return values are discarded, so what is shown below is the `masks`
      # yielded by the generator for this frame. Whether these re-prompts
      # influence later frames depends on SAM2 internals -- confirm.
      # Object ids stay fixed to the grid indices, so objects that enter the
      # scene later never receive new ids (the open question of this issue).
      for i in range(point_coords.shape[0]):
        sam2.add_new_points_or_box(state, frame_idx, i, points=point_coords[i], labels=point_labels[i])
      img = cv2.imread(os.path.join("/home/dapeng/data/TUM_RGBD/rgbd_dataset_freiburg1_desk/rgb", images[frame_idx]), cv2.IMREAD_COLOR)
      # Threshold at 0.0 to get boolean masks (presumably mask logits).
      show_anns(masks > 0.0, img, object_ids)
2024-12-18.19-55-19.mp4
@DapengFeng
Copy link
Author

DapengFeng commented Dec 18, 2024

Similar issues: #13, #105, #185, and #224.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant