To perform inference using Meta's Segment Anything Model 2.1 (SAM 2.1) in a Python notebook, follow these steps:
- **Clone the SAM 2.1 repository:**

  ```python
  !git clone https://github.com/facebookresearch/sam2.git
  ```

  This command clones the SAM 2.1 repository into your local environment.
- **Navigate to the SAM 2.1 directory:**

  ```python
  %cd sam2
  ```

  Ensure you are in the directory containing `setup.py`.
- **Install SAM 2.1 in editable (development) mode:**

  ```python
  !pip install -e ".[dev]" -q
  ```

  This installs SAM 2.1 along with its development dependencies; quoting `".[dev]"` keeps the extras specifier from being interpreted by the shell.
- **Download the model checkpoints:**

  ```python
  %cd checkpoints
  !./download_ckpts.sh
  %cd ..
  ```

  This script downloads the SAM 2.1 model checkpoints into the `checkpoints/` directory; the final `%cd ..` returns you to the repository root for the remaining steps.
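  As an optional sanity check, you can list what was downloaded; the `*.pt` glob below just lists any checkpoint files, since exact filenames depend on the release.

  ```python
  # Optional: confirm the checkpoint files are present in checkpoints/.
  # Exact filenames depend on the release; a *.pt glob is assumed here.
  !ls -lh checkpoints/*.pt
  ```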
- **Install the `supervision` library:**

  ```python
  !pip install supervision -q
  ```

  The `supervision` library aids in visualizing the segmentation results.
- **Import the required libraries:**

  ```python
  import numpy as np
  import supervision as sv
  import torch
  from PIL import Image

  from sam2.build_sam import build_sam2
  from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
  ```
- **Set up the model and mask generator:**

  ```python
  model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
  checkpoint = "path_to_checkpoint.pt"  # Replace with the actual path

  sam2 = build_sam2(model_cfg, checkpoint, device="cuda")
  mask_generator = SAM2AutomaticMaskGenerator(sam2)
  ```

  Replace `"path_to_checkpoint.pt"` with the path to your downloaded checkpoint (for the `b+` config above, that is the base-plus checkpoint fetched by the download script). If you are not running on a GPU, see the device-fallback sketch below.
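  If CUDA may not be available (for example on a CPU-only runtime), a minimal sketch of the same setup with a runtime device check looks like this, reusing the imports from the earlier step; the checkpoint path is still a placeholder you must fill in.

  ```python
  # Pick a GPU when available, otherwise fall back to CPU (slower, but functional).
  device = "cuda" if torch.cuda.is_available() else "cpu"

  model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
  checkpoint = "path_to_checkpoint.pt"  # Placeholder: point this at your checkpoint file

  sam2 = build_sam2(model_cfg, checkpoint, device=device)
  mask_generator = SAM2AutomaticMaskGenerator(sam2)
  ```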
- **Load and process the input image:**

  ```python
  image_path = "path_to_image.jpg"  # Replace with your image path
  image = Image.open(image_path)
  image_np = np.array(image)
  ```

  Ensure the image is in RGB format (see the conversion sketch below).
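  If your input might be grayscale or carry an alpha channel, converting it explicitly guarantees an `(H, W, 3)` array. A minimal sketch, reusing the same `image_path` placeholder:

  ```python
  # Force RGB so image_np always has shape (H, W, 3),
  # even for grayscale or RGBA inputs.
  image = Image.open(image_path).convert("RGB")
  image_np = np.array(image)
  ```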
- **Generate segmentation masks:**

  ```python
  result = mask_generator.generate(image_np)

  detections = sv.Detections.from_sam(sam_result=result)
  mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

  annotated_image = image_np.copy()
  annotated_image = mask_annotator.annotate(annotated_image, detections=detections)
  ```
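  If you want to inspect the raw output before annotating, `generate` returns a list of per-mask dictionaries. The key names below (`"segmentation"`, `"area"`) follow the original SAM output format and are an assumption here; print the keys first to confirm what your version returns.

  ```python
  # Peek at the raw mask generator output.
  print(f"{len(result)} masks generated")
  print("keys of first mask:", sorted(result[0].keys()))
  # "area" is assumed to be one of those keys (as in the original SAM format).
  print("largest mask area:", max(m["area"] for m in result))
  ```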
- **Visualize the results:**

  ```python
  import matplotlib.pyplot as plt

  fig, axes = plt.subplots(1, 2, figsize=(10, 5))
  axes[0].imshow(annotated_image)
  axes[0].set_title('Segmented Image')
  axes[0].axis('off')
  axes[1].imshow(image_np)
  axes[1].set_title('Original Image')
  axes[1].axis('off')
  plt.tight_layout()
  plt.show()
  ```

  This code displays the segmented and original images side by side; a sketch for saving the annotated result follows.
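  To keep the annotated image as a file (the filename below is arbitrary), you can write it back out with Pillow:

  ```python
  # Save the annotated result next to the notebook; any filename works.
  Image.fromarray(annotated_image).save("segmented_output.png")
  ```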
For more detailed examples and advanced usage, refer to the SAM 2.1 GitHub repository.