From 1df24da87701d5d086459c3c0b694cf752eb7e68 Mon Sep 17 00:00:00 2001
From: manuel cuevas
Date: Mon, 29 Jul 2024 07:03:59 -0700
Subject: [PATCH 1/3] added owlv2 support

---
 nanoowl/owl_predictor.py | 163 ++++++++++++++++++++++-----------------
 1 file changed, 91 insertions(+), 72 deletions(-)

diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py
index 1afb897..3fae698 100644
--- a/nanoowl/owl_predictor.py
+++ b/nanoowl/owl_predictor.py
@@ -21,6 +21,9 @@
 import tempfile
 import os
 from torchvision.ops import roi_align
+from transformers.models.owlv2.modeling_owlv2 import Owlv2ForObjectDetection
+from transformers.models.owlv2.processing_owlv2 import Owlv2Processor
+
 from transformers.models.owlvit.modeling_owlvit import OwlViTForObjectDetection
 from transformers.models.owlvit.processing_owlvit import OwlViTProcessor
 from dataclasses import dataclass
@@ -39,9 +42,9 @@ def _owl_center_to_corners_format_torch(bboxes_center):
     center_x, center_y, width, height = bboxes_center.unbind(-1)
     bbox_corners = torch.stack(
         [
-            (center_x - 0.5 * width),
-            (center_y - 0.5 * height),
-            (center_x + 0.5 * width),
+            (center_x - 0.5 * width),
+            (center_y - 0.5 * height),
+            (center_x + 0.5 * width),
             (center_y + 0.5 * height)
         ],
         dim=-1,
     )
@@ -50,11 +53,12 @@ def _owl_get_image_size(hf_name: str):
-
     image_sizes = {
         "google/owlvit-base-patch32": 768,
         "google/owlvit-base-patch16": 768,
         "google/owlvit-large-patch14": 840,
+        "google/owlv2-base-patch16-ensemble": 960,
+        "google/owlv2-large-patch14-ensemble": 1008,
     }
     return image_sizes[hf_name]
@@ -66,6 +70,8 @@ def _owl_get_patch_size(hf_name: str):
         "google/owlvit-base-patch32": 32,
         "google/owlvit-base-patch16": 16,
         "google/owlvit-large-patch14": 14,
+        "google/owlv2-base-patch16-ensemble": 16,
+        "google/owlv2-large-patch14-ensemble": 14,
     }
     return patch_sizes[hf_name]
@@ -141,25 +147,35 @@ class OwlDecodeOutput:
 class OwlPredictor(torch.nn.Module):
-
+
     def __init__(self,
-        model_name: str = "google/owlvit-base-patch32",
-        device: str = "cuda",
-        image_encoder_engine: Optional[str] = None,
-        image_encoder_engine_max_batch_size: int = 1,
-        image_preprocessor: Optional[ImagePreprocessor] = None
-    ):
+                 model_name: str = "google/owlvit-base-patch32",
+                 device: str = "cuda",
+                 image_encoder_engine: Optional[str] = None,
+                 image_encoder_engine_max_batch_size: int = 1,
+                 image_preprocessor: Optional[ImagePreprocessor] = None
+                 ):
         super().__init__()
         self.image_size = _owl_get_image_size(model_name)
         self.device = device
-        self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval()
-        self.processor = OwlViTProcessor.from_pretrained(model_name)
+
+        model_type = model_name.split("/")[1].split('-')[0]
+        if model_type == 'owlv2':
+            self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(self.device).eval()
+            self.processor = Owlv2Processor.from_pretrained(model_name)
+            self.base_model = self.model.owlv2
+
+        else:
+            self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval()
+            self.processor = OwlViTProcessor.from_pretrained(model_name)
+            self.base_model = self.model.owlvit
+
         self.patch_size = _owl_get_patch_size(model_name)
         self.num_patches_per_side = self.image_size // self.patch_size
         self.box_bias = _owl_compute_box_bias(self.num_patches_per_side).to(self.device)
-        self.num_patches = (self.num_patches_per_side)**2
+        self.num_patches = (self.num_patches_per_side) ** 2
         self.mesh_grid = torch.stack(
             torch.meshgrid(
                 torch.linspace(0., 1., self.image_size),
@@ -168,33 +184,35 @@ def __init__(self,
         ).to(self.device).float()
         self.image_encoder_engine = None
         if image_encoder_engine is not None:
-            image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine, image_encoder_engine_max_batch_size)
+            image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine,
+                                                                          image_encoder_engine_max_batch_size)
         self.image_encoder_engine = image_encoder_engine
-        self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval()
+        self.image_preprocessor = image_preprocessor.to(
+            self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval()
     def get_num_patches(self):
         return self.num_patches
     def get_device(self):
         return self.device
-
+
     def get_image_size(self):
         return (self.image_size, self.image_size)
-
+
     def encode_text(self, text: List[str]) -> OwlEncodeTextOutput:
         text_input = self.processor(text=text, return_tensors="pt")
         input_ids = text_input['input_ids'].to(self.device)
         attention_mask = text_input['attention_mask'].to(self.device)
-        text_outputs = self.model.owlvit.text_model(input_ids, attention_mask)
-        text_embeds = text_outputs[1]
-        text_embeds = self.model.owlvit.text_projection(text_embeds)
+        text_outputs = self.base_model.text_model(input_ids, attention_mask)
+        text_embeds = text_outputs[1]
+        text_embeds = self.base_model.text_projection(text_embeds)
         return OwlEncodeTextOutput(text_embeds=text_embeds)
     def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput:
-
-        vision_outputs = self.model.owlvit.vision_model(image)
+
+        vision_outputs = self.base_model.vision_model(image)
         last_hidden_state = vision_outputs[0]
-        image_embeds = self.model.owlvit.vision_model.post_layernorm(last_hidden_state)
+        image_embeds = self.base_model.vision_model.post_layernorm(last_hidden_state)
         class_token_out = image_embeds[:, :1, :]
         image_embeds = image_embeds[:, 1:, :] * class_token_out
         image_embeds = self.model.layer_norm(image_embeds) # 768 dim
@@ -220,7 +238,7 @@ def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput:
         )
         return output
-
+
     def encode_image_trt(self, image: torch.Tensor) -> OwlEncodeImageOutput:
         return self.image_encoder_engine(image)
@@ -230,7 +248,8 @@ def encode_image(self, image: torch.Tensor) -> OwlEncodeImageOutput:
         else:
             return self.encode_image_torch(image)
-    def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0):
+    def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True,
+                     padding_scale: float = 1.0):
         if len(rois) == 0:
             return torch.empty(
                 (0, image.shape[1], self.image_size, self.image_size),
@@ -244,13 +263,15 @@ def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
         cx = (rois[..., 0] + rois[..., 2]) / 2
         cy = (rois[..., 1] + rois[..., 3]) / 2
         s = torch.max(w, h)
-        rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1)
+        rois = torch.stack([cx - s, cy - s, cx + s, cy + s], dim=-1)
         # compute mask
         pad_x = (s - w) / (2 * s)
         pad_y = (s - h) / (2 * s)
-        mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None]))
-        mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None]))
+        mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (
+                self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None]))
+        mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (
+                self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None]))
         mask = (mask_x & mask_y)
         # extract rois
@@ -261,8 +282,8 @@ def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
         roi_images = (roi_images * mask[:, None, :, :])
         return roi_images, rois
-
-    def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0):
+
+    def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0):
         # with torch_timeit_sync("extract rois"):
         roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale)
         # with torch_timeit_sync("encode images"):
@@ -271,14 +292,14 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
         output.pred_boxes = pred_boxes
         return output
-    def decode(self,
-            image_output: OwlEncodeImageOutput,
-            text_output: OwlEncodeTextOutput,
-            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
-        ) -> OwlDecodeOutput:
+    def decode(self,
+               image_output: OwlEncodeImageOutput,
+               text_output: OwlEncodeTextOutput,
+               threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+               ) -> OwlDecodeOutput:
         if isinstance(threshold, (int, float)):
-            threshold = [threshold] * len(text_output.text_embeds) #apply single threshold to all labels
+            threshold = [threshold] * len(text_output.text_embeds) # apply single threshold to all labels
         num_input_images = image_output.image_class_embeds.shape[0]
@@ -288,7 +309,7 @@ def decode(self,
         query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
         logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
         logits = (logits + image_output.logit_shift) * image_output.logit_scale
-
+
         scores_sigmoid = torch.sigmoid(logits)
         scores_max = scores_sigmoid.max(dim=-1)
         labels = scores_max.indices
@@ -297,9 +318,9 @@ def decode(self,
         for i, thresh in enumerate(threshold):
             label_mask = labels == i
             score_mask = scores > thresh
-            obj_mask = torch.logical_and(label_mask,score_mask)
-            masks.append(obj_mask)
-
+            obj_mask = torch.logical_and(label_mask, score_mask)
+            masks.append(obj_mask)
+
         mask = masks[0]
         for mask_t in masks[1:]:
             mask = torch.logical_or(mask, mask_t)
@@ -329,18 +350,18 @@ def get_image_encoder_output_names():
         ]
         return names
+    def export_image_encoder_onnx(self,
+                                  output_path: str,
+                                  use_dynamic_axes: bool = True,
+                                  batch_size: int = 1,
+                                  onnx_opset=17
+                                  ):
-    def export_image_encoder_onnx(self,
-            output_path: str,
-            use_dynamic_axes: bool = True,
-            batch_size: int = 1,
-            onnx_opset=17
-        ):
-
         class TempModule(torch.nn.Module):
             def __init__(self, parent):
                 super().__init__()
                 self.parent = parent
+
             def forward(self, image):
                 output = self.parent.encode_image_torch(image)
                 return (
@@ -354,13 +375,13 @@ def forward(self, image):
         data = torch.randn(batch_size, 3, self.image_size, self.image_size).to(self.device)
         if use_dynamic_axes:
-            dynamic_axes = {
+            dynamic_axes = {
                 "image": {0: "batch"},
                 "image_embeds": {0: "batch"},
                 "image_class_embeds": {0: "batch"},
                 "logit_shift": {0: "batch"},
                 "logit_scale": {0: "batch"},
-                "pred_boxes": {0: "batch"}
+                "pred_boxes": {0: "batch"}
             }
         else:
             dynamic_axes = {}
@@ -368,15 +389,15 @@ def forward(self, image):
         model = TempModule(self)
         torch.onnx.export(
-            model,
-            data,
-            output_path,
-            input_names=self.get_image_encoder_input_names(),
+            model,
+            data,
+            output_path,
+            input_names=self.get_image_encoder_input_names(),
+            output_names=self.get_image_encoder_output_names(),
             dynamic_axes=dynamic_axes,
             opset_version=onnx_opset
         )
-
+
     @staticmethod
     def load_image_encoder_engine(engine_path: str, max_batch_size: int = 1):
         import tensorrt as trt
@@ -401,7 +422,6 @@ def __init__(self, base_module: TRTModule, max_batch_size: int):
             @torch.no_grad()
             def forward(self, image):
-
                 b = image.shape[0]
                 results = []
@@ -427,13 +447,13 @@ def forward(self, image):
         return image_encoder
-    def build_image_encoder_engine(self,
-            engine_path: str,
-            max_batch_size: int = 1,
-            fp16_mode = True,
-            onnx_path: Optional[str] = None,
-            onnx_opset: int = 17
-        ):
+    def build_image_encoder_engine(self,
+                                   engine_path: str,
+                                   max_batch_size: int = 1,
+                                   fp16_mode=True,
+                                   onnx_path: Optional[str] = None,
+                                   onnx_opset: int = 17
+                                   ):
         if onnx_path is None:
             onnx_dir = tempfile.mkdtemp()
@@ -441,7 +461,7 @@ def build_image_encoder_engine(self,
             self.export_image_encoder_onnx(onnx_path, onnx_opset=onnx_opset)
         args = ["/usr/src/tensorrt/bin/trtexec"]
-
+
         args.append(f"--onnx={onnx_path}")
         args.append(f"--saveEngine={engine_path}")
@@ -454,14 +474,14 @@ def build_image_encoder_engine(self,
         return self.load_image_encoder_engine(engine_path, max_batch_size)
-    def predict(self,
-            image: PIL.Image,
-            text: List[str],
-            text_encodings: Optional[OwlEncodeTextOutput],
-            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
-            pad_square: bool = True,
-
-        ) -> OwlDecodeOutput:
+    def predict(self,
+                image: PIL.Image,
+                text: List[str],
+                text_encodings: Optional[OwlEncodeTextOutput],
+                threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+                pad_square: bool = True,
+
+                ) -> OwlDecodeOutput:
         image_tensor = self.image_preprocessor.preprocess_pil_image(image)
@@ -473,4 +493,3 @@ def predict(self,
         image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
         return self.decode(image_encodings, text_encodings, threshold)
-

From 91baacd619bcf8c15bff1a0e985fc074bd7a202b Mon Sep 17 00:00:00 2001
From: manuel cuevas
Date: Mon, 29 Jul 2024 11:59:32 -0700
Subject: [PATCH 2/3] align_rois argument added

---
 nanoowl/build_image_encoder_engine.py |  4 +++-
 nanoowl/owl_predictor.py              | 15 ++++++++++++---
 setup.py                              |  2 +-
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/nanoowl/build_image_encoder_engine.py b/nanoowl/build_image_encoder_engine.py
index 6c5910b..8bcdceb 100644
--- a/nanoowl/build_image_encoder_engine.py
+++ b/nanoowl/build_image_encoder_engine.py
@@ -25,10 +25,12 @@
     parser.add_argument("--model_name", type=str, default="google/owlvit-base-patch32")
     parser.add_argument("--fp16_mode", type=bool, default=True)
     parser.add_argument("--onnx_opset", type=int, default=16)
+    parser.add_argument("--align_rois", type=bool, default=True)
     args = parser.parse_args()
     predictor = OwlPredictor(
-        model_name=args.model_name
+        model_name=args.model_name,
+        align_rois =args.align_rois,
     )
     predictor.build_image_encoder_engine(
diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py
index 3fae698..594e7c2 100644
--- a/nanoowl/owl_predictor.py
+++ b/nanoowl/owl_predictor.py
@@ -65,7 +65,6 @@ def _owl_get_image_size(hf_name: str):
 def _owl_get_patch_size(hf_name: str):
-
     patch_sizes = {
         "google/owlvit-base-patch32": 32,
         "google/owlvit-base-patch16": 16,
@@ -153,11 +152,13 @@ def __init__(self,
                  device: str = "cuda",
                  image_encoder_engine: Optional[str] = None,
                  image_encoder_engine_max_batch_size: int = 1,
-                 image_preprocessor: Optional[ImagePreprocessor] = None
+                 image_preprocessor: Optional[ImagePreprocessor] = None,
+                 align_rois=True,
                  ):
         super().__init__()
+        self.align_rois = align_rois
         self.image_size = _owl_get_image_size(model_name)
         self.device = device
@@ -275,7 +276,15 @@ def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
         mask = (mask_x & mask_y)
         # extract rois
-        roi_images = roi_align(image, [rois], output_size=self.get_image_size())
+        if self.align_rois:
+            roi_images = roi_align(image, [rois], output_size=self.get_image_size())
+        else:
+            # Crop the image for each object detected
+            roi_images = []
+            for i in range(len(rois)):
+                bbox = tuple(rois[i])
+                object_image = image.crop(bbox)
+                roi_images.append(object_image)
         # mask rois
         if pad_square:
diff --git a/setup.py b/setup.py
index 27230b0..a7de8dd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,6 @@
 setup(
     name="nanoowl",
-    version="0.0.0",
+    version="0.0.1",
     packages=find_packages()
 )
\ No newline at end of file

From 967fbd392a1ddb48a7d789200511d1910bfd7743 Mon Sep 17 00:00:00 2001
From: manuel cuevas
Date: Tue, 30 Jul 2024 12:04:22 -0700
Subject: [PATCH 3/3] setup version update

---
 setup.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a7de8dd..bba1abe 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,9 @@
 setup(
     name="nanoowl",
-    version="0.0.1",
+    version="0.0.2",
+    description='NanoOWL is a project that optimizes OWL-ViT to run '
+                '🔥 real-time 🔥 on NVIDIA Jetson Orin Platforms with '
+                'NVIDIA TensorRT',
     packages=find_packages()
 )
\ No newline at end of file
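
A minimal usage sketch of the predictor API after this series (illustrative only, not part of the patches). The OWLv2 checkpoint name and the align_rois flag come from the changes above; the image path, text prompts, and printed output fields are assumptions for demonstration, and the default device="cuda" assumes a CUDA-capable machine.

    # Illustrative sketch: exercises the new OWLv2 model routing and the
    # align_rois flag added by this patch series.
    import PIL.Image

    from nanoowl.owl_predictor import OwlPredictor

    predictor = OwlPredictor(
        model_name="google/owlv2-base-patch16-ensemble",  # "owlv2" prefix selects Owlv2ForObjectDetection
        align_rois=True,  # new flag; True keeps the original roi_align cropping path
    )

    image = PIL.Image.open("example.jpg")  # placeholder input image
    text = ["an owl", "a person"]          # placeholder prompts

    output = predictor.predict(
        image=image,
        text=text,
        text_encodings=predictor.encode_text(text),
        threshold=0.1,
    )
    print(output.labels, output.scores)  # decoded detections per prompt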