Feature/owlv2 support #36

Open · wants to merge 3 commits into base: main
nanoowl/build_image_encoder_engine.py (3 additions, 1 deletion)

@@ -25,10 +25,12 @@
 parser.add_argument("--model_name", type=str, default="google/owlvit-base-patch32")
 parser.add_argument("--fp16_mode", type=bool, default=True)
 parser.add_argument("--onnx_opset", type=int, default=16)
+parser.add_argument("--align_rois", type=bool, default=True)
 args = parser.parse_args()
 
 predictor = OwlPredictor(
-    model_name=args.model_name
+    model_name=args.model_name,
+    align_rois=args.align_rois,
 )
 
 predictor.build_image_encoder_engine(
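For context, a minimal sketch of how the new flag flows into the predictor. This mirrors the diff above; the checkpoint and engine path are illustrative, not taken from this PR:

# Sketch: the new align_rois argument reaching OwlPredictor (paths illustrative).
from nanoowl.owl_predictor import OwlPredictor

predictor = OwlPredictor(
    model_name="google/owlv2-base-patch16-ensemble",  # OWLv2 checkpoint enabled by this PR
    align_rois=True,                                  # new constructor argument
)
predictor.build_image_encoder_engine("data/owl_image_encoder.engine")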
nanoowl/owl_predictor.py (102 additions, 74 deletions)
@@ -21,6 +21,9 @@
 import tempfile
 import os
 from torchvision.ops import roi_align
+from transformers.models.owlv2.modeling_owlv2 import Owlv2ForObjectDetection
+from transformers.models.owlv2.processing_owlv2 import Owlv2Processor
+
 from transformers.models.owlvit.modeling_owlvit import OwlViTForObjectDetection
 from transformers.models.owlvit.processing_owlvit import OwlViTProcessor
 from dataclasses import dataclass
@@ -39,9 +42,9 @@ def _owl_center_to_corners_format_torch(bboxes_center):
     center_x, center_y, width, height = bboxes_center.unbind(-1)
     bbox_corners = torch.stack(
         [
-        (center_x - 0.5 * width),
-        (center_y - 0.5 * height),
-        (center_x + 0.5 * width),
+            (center_x - 0.5 * width),
+            (center_y - 0.5 * height),
+            (center_x + 0.5 * width),
             (center_y + 0.5 * height)
         ],
         dim=-1,
@@ -50,22 +53,24 @@
 
 
 def _owl_get_image_size(hf_name: str):
-
     image_sizes = {
         "google/owlvit-base-patch32": 768,
         "google/owlvit-base-patch16": 768,
         "google/owlvit-large-patch14": 840,
+        "google/owlv2-base-patch16-ensemble": 960,
+        "google/owlv2-large-patch14-ensemble": 1008,
     }
 
     return image_sizes[hf_name]
 
 
 def _owl_get_patch_size(hf_name: str):
-
     patch_sizes = {
         "google/owlvit-base-patch32": 32,
         "google/owlvit-base-patch16": 16,
         "google/owlvit-large-patch14": 14,
+        "google/owlv2-base-patch16-ensemble": 16,
+        "google/owlv2-large-patch14-ensemble": 14,
     }
 
     return patch_sizes[hf_name]
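These two lookup tables determine the ViT patch grid used downstream: google/owlv2-base-patch16-ensemble gives 960 / 16 = 60 patches per side (3600 patches), versus 768 / 32 = 24 per side (576) for the default OWL-ViT checkpoint. A quick sanity check of the arithmetic the constructor performs later:

image_size = _owl_get_image_size("google/owlv2-base-patch16-ensemble")   # 960
patch_size = _owl_get_patch_size("google/owlv2-base-patch16-ensemble")   # 16
num_patches_per_side = image_size // patch_size                          # 60
num_patches = num_patches_per_side ** 2                                  # 3600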
@@ -141,25 +146,37 @@ class OwlDecodeOutput:
 
 
 class OwlPredictor(torch.nn.Module):
 
     def __init__(self,
-        model_name: str = "google/owlvit-base-patch32",
-        device: str = "cuda",
-        image_encoder_engine: Optional[str] = None,
-        image_encoder_engine_max_batch_size: int = 1,
-        image_preprocessor: Optional[ImagePreprocessor] = None
-    ):
+                 model_name: str = "google/owlvit-base-patch32",
+                 device: str = "cuda",
+                 image_encoder_engine: Optional[str] = None,
+                 image_encoder_engine_max_batch_size: int = 1,
+                 image_preprocessor: Optional[ImagePreprocessor] = None,
+                 align_rois=True,
+                 ):
 
         super().__init__()
 
+        self.align_rois = align_rois
         self.image_size = _owl_get_image_size(model_name)
         self.device = device
-        self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval()
-        self.processor = OwlViTProcessor.from_pretrained(model_name)
+
+        model_type = model_name.split("/")[1].split('-')[0]
+        if model_type == 'owlv2':
+            self.model = Owlv2ForObjectDetection.from_pretrained(model_name).to(self.device).eval()
+            self.processor = Owlv2Processor.from_pretrained(model_name)
+            self.base_model = self.model.owlv2
+
+        else:
+            self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval()
+            self.processor = OwlViTProcessor.from_pretrained(model_name)
+            self.base_model = self.model.owlvit
+
         self.patch_size = _owl_get_patch_size(model_name)
         self.num_patches_per_side = self.image_size // self.patch_size
         self.box_bias = _owl_compute_box_bias(self.num_patches_per_side).to(self.device)
-        self.num_patches = (self.num_patches_per_side)**2
+        self.num_patches = (self.num_patches_per_side) ** 2
        self.mesh_grid = torch.stack(
            torch.meshgrid(
                torch.linspace(0., 1., self.image_size),
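The constructor now dispatches on the model family parsed out of the Hugging Face checkpoint name. A quick sketch of what that string split yields for the two supported families:

# "google/owlv2-base-patch16-ensemble" -> "owlv2-base-patch16-ensemble" -> "owlv2"
model_type = "google/owlv2-base-patch16-ensemble".split("/")[1].split("-")[0]  # "owlv2"
model_type = "google/owlvit-base-patch32".split("/")[1].split("-")[0]          # "owlvit"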
@@ -168,33 +185,35 @@ def __init__(self,
         ).to(self.device).float()
         self.image_encoder_engine = None
         if image_encoder_engine is not None:
-            image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine, image_encoder_engine_max_batch_size)
+            image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine,
+                                                                          image_encoder_engine_max_batch_size)
         self.image_encoder_engine = image_encoder_engine
-        self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval()
+        self.image_preprocessor = image_preprocessor.to(
+            self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval()
 
     def get_num_patches(self):
         return self.num_patches
 
     def get_device(self):
         return self.device
 
     def get_image_size(self):
         return (self.image_size, self.image_size)
 
     def encode_text(self, text: List[str]) -> OwlEncodeTextOutput:
         text_input = self.processor(text=text, return_tensors="pt")
         input_ids = text_input['input_ids'].to(self.device)
         attention_mask = text_input['attention_mask'].to(self.device)
-        text_outputs = self.model.owlvit.text_model(input_ids, attention_mask)
+        text_outputs = self.base_model.text_model(input_ids, attention_mask)
         text_embeds = text_outputs[1]
-        text_embeds = self.model.owlvit.text_projection(text_embeds)
+        text_embeds = self.base_model.text_projection(text_embeds)
         return OwlEncodeTextOutput(text_embeds=text_embeds)
 
     def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput:
-        vision_outputs = self.model.owlvit.vision_model(image)
-
+        vision_outputs = self.base_model.vision_model(image)
         last_hidden_state = vision_outputs[0]
-        image_embeds = self.model.owlvit.vision_model.post_layernorm(last_hidden_state)
+        image_embeds = self.base_model.vision_model.post_layernorm(last_hidden_state)
         class_token_out = image_embeds[:, :1, :]
         image_embeds = image_embeds[:, 1:, :] * class_token_out
         image_embeds = self.model.layer_norm(image_embeds)  # 768 dim
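Because the text and vision towers are now reached through self.base_model, encode_text is identical for both families. A minimal sketch, continuing the earlier example (prompts illustrative):

text_encodings = predictor.encode_text(["an owl", "a glove"])
# text_encodings.text_embeds holds one embedding row per prompt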
@@ -220,7 +239,7 @@ def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput:
         )
 
         return output
-
+
     def encode_image_trt(self, image: torch.Tensor) -> OwlEncodeImageOutput:
         return self.image_encoder_engine(image)
 
@@ -230,7 +249,8 @@ def encode_image(self, image: torch.Tensor) -> OwlEncodeImageOutput:
         else:
             return self.encode_image_torch(image)
 
-    def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0):
+    def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True,
+                     padding_scale: float = 1.0):
         if len(rois) == 0:
             return torch.empty(
                 (0, image.shape[1], self.image_size, self.image_size),
@@ -244,25 +264,35 @@ def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
         cx = (rois[..., 0] + rois[..., 2]) / 2
         cy = (rois[..., 1] + rois[..., 3]) / 2
         s = torch.max(w, h)
-        rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1)
+        rois = torch.stack([cx - s, cy - s, cx + s, cy + s], dim=-1)
 
         # compute mask
         pad_x = (s - w) / (2 * s)
         pad_y = (s - h) / (2 * s)
-        mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None]))
-        mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None]))
+        mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (
+                self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None]))
+        mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (
+                self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None]))
         mask = (mask_x & mask_y)
 
         # extract rois
-        roi_images = roi_align(image, [rois], output_size=self.get_image_size())
+        if self.align_rois:
+            roi_images = roi_align(image, [rois], output_size=self.get_image_size())
+        else:
+            # Crop each detected region individually; image is a (1, C, H, W)
+            # tensor, so crop by slicing and resize back to the encoder input size.
+            roi_images = []
+            for i in range(len(rois)):
+                x0, y0, x1, y1 = (int(v.clamp(min=0)) for v in rois[i])
+                crop = image[..., y0:y1, x0:x1]
+                crop = torch.nn.functional.interpolate(
+                    crop, size=self.get_image_size(), mode="bilinear")
+                roi_images.append(crop)
+            roi_images = torch.cat(roi_images, dim=0)
 
         # mask rois
         if pad_square:
             roi_images = (roi_images * mask[:, None, :, :])
 
         return roi_images, rois
-    def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0):
+
+    def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0):
         # with torch_timeit_sync("extract rois"):
         roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale)
         # with torch_timeit_sync("encode images"):
@@ -271,14 +301,14 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
         output.pred_boxes = pred_boxes
         return output
 
-    def decode(self,
-        image_output: OwlEncodeImageOutput,
-        text_output: OwlEncodeTextOutput,
-        threshold: Union[int, float, List[Union[int, float]]] = 0.1,
-    ) -> OwlDecodeOutput:
+    def decode(self,
+               image_output: OwlEncodeImageOutput,
+               text_output: OwlEncodeTextOutput,
+               threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+               ) -> OwlDecodeOutput:
 
         if isinstance(threshold, (int, float)):
-            threshold = [threshold] * len(text_output.text_embeds)  #apply single threshold to all labels
+            threshold = [threshold] * len(text_output.text_embeds)  # apply single threshold to all labels
 
         num_input_images = image_output.image_class_embeds.shape[0]
 
@@ -288,7 +318,7 @@ def decode(self,
         query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
         logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
         logits = (logits + image_output.logit_shift) * image_output.logit_scale
 
         scores_sigmoid = torch.sigmoid(logits)
         scores_max = scores_sigmoid.max(dim=-1)
         labels = scores_max.indices
@@ -297,9 +327,9 @@ def decode(self,
         for i, thresh in enumerate(threshold):
             label_mask = labels == i
             score_mask = scores > thresh
-            obj_mask = torch.logical_and(label_mask,score_mask)
-            masks.append(obj_mask)
+            obj_mask = torch.logical_and(label_mask, score_mask)
+            masks.append(obj_mask)
 
         mask = masks[0]
         for mask_t in masks[1:]:
             mask = torch.logical_or(mask, mask_t)
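As the masking loop shows, decode accepts either a single threshold or one threshold per label. Continuing the sketch (image_encodings from encode_rois, text_encodings from encode_text; values illustrative):

output = predictor.decode(image_encodings, text_encodings, threshold=0.1)          # one threshold for all labels
output = predictor.decode(image_encodings, text_encodings, threshold=[0.1, 0.3])   # per-label thresholds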
@@ -329,18 +359,18 @@ def get_image_encoder_output_names():
         ]
         return names
 
-    def export_image_encoder_onnx(self,
-        output_path: str,
-        use_dynamic_axes: bool = True,
-        batch_size: int = 1,
-        onnx_opset=17
-    ):
-
+    def export_image_encoder_onnx(self,
+                                  output_path: str,
+                                  use_dynamic_axes: bool = True,
+                                  batch_size: int = 1,
+                                  onnx_opset=17
+                                  ):
+
         class TempModule(torch.nn.Module):
             def __init__(self, parent):
                 super().__init__()
                 self.parent = parent
 
             def forward(self, image):
                 output = self.parent.encode_image_torch(image)
                 return (
@@ -354,29 +384,29 @@ def forward(self, image):
         data = torch.randn(batch_size, 3, self.image_size, self.image_size).to(self.device)
 
         if use_dynamic_axes:
-            dynamic_axes = {
+            dynamic_axes = {
                 "image": {0: "batch"},
                 "image_embeds": {0: "batch"},
                 "image_class_embeds": {0: "batch"},
                 "logit_shift": {0: "batch"},
                 "logit_scale": {0: "batch"},
-                "pred_boxes": {0: "batch"}
+                "pred_boxes": {0: "batch"}
             }
         else:
             dynamic_axes = {}
 
         model = TempModule(self)
 
         torch.onnx.export(
-            model,
-            data,
-            output_path,
-            input_names=self.get_image_encoder_input_names(),
+            model,
+            data,
+            output_path,
+            input_names=self.get_image_encoder_input_names(),
             output_names=self.get_image_encoder_output_names(),
             dynamic_axes=dynamic_axes,
             opset_version=onnx_opset
         )
 
     @staticmethod
     def load_image_encoder_engine(engine_path: str, max_batch_size: int = 1):
         import tensorrt as trt
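The export and engine-build path is untouched by this PR apart from formatting. A sketch of driving it from Python, continuing the earlier example (file names illustrative; build_image_encoder_engine shells out to trtexec):

predictor.export_image_encoder_onnx("image_encoder.onnx", onnx_opset=17)
engine = predictor.build_image_encoder_engine("image_encoder.engine", fp16_mode=True)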
@@ -401,7 +431,6 @@ def __init__(self, base_module: TRTModule, max_batch_size: int):
 
             @torch.no_grad()
             def forward(self, image):
-
                 b = image.shape[0]
 
                 results = []
@@ -427,21 +456,21 @@
 
         return image_encoder
 
-    def build_image_encoder_engine(self,
-            engine_path: str,
-            max_batch_size: int = 1,
-            fp16_mode = True,
-            onnx_path: Optional[str] = None,
-            onnx_opset: int = 17
-        ):
+    def build_image_encoder_engine(self,
+                                   engine_path: str,
+                                   max_batch_size: int = 1,
+                                   fp16_mode=True,
+                                   onnx_path: Optional[str] = None,
+                                   onnx_opset: int = 17
+                                   ):
 
         if onnx_path is None:
             onnx_dir = tempfile.mkdtemp()
             onnx_path = os.path.join(onnx_dir, "image_encoder.onnx")
             self.export_image_encoder_onnx(onnx_path, onnx_opset=onnx_opset)
 
         args = ["/usr/src/tensorrt/bin/trtexec"]
 
         args.append(f"--onnx={onnx_path}")
         args.append(f"--saveEngine={engine_path}")
 
@@ -454,14 +483,14 @@ def build_image_encoder_engine(self,
 
         return self.load_image_encoder_engine(engine_path, max_batch_size)
 
-    def predict(self,
-            image: PIL.Image,
-            text: List[str],
-            text_encodings: Optional[OwlEncodeTextOutput],
-            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
-            pad_square: bool = True,
-        ) -> OwlDecodeOutput:
+    def predict(self,
+                image: PIL.Image,
+                text: List[str],
+                text_encodings: Optional[OwlEncodeTextOutput],
+                threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+                pad_square: bool = True,
+                ) -> OwlDecodeOutput:
 
         image_tensor = self.image_preprocessor.preprocess_pil_image(image)

@@ -473,4 +502,3 @@ def predict(self,
         image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
 
         return self.decode(image_encodings, text_encodings, threshold)
-
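Putting it together, an end-to-end sketch of the predictor with an OWLv2 checkpoint. The image path and prompts are illustrative, and it is assumed (from the Optional type and the folded body of predict) that passing text_encodings=None makes predict encode the prompts itself:

# End-to-end sketch; assumes predict() falls back to encode_text when
# text_encodings is None, as in the upstream nanoowl code.
import PIL.Image
from nanoowl.owl_predictor import OwlPredictor

predictor = OwlPredictor("google/owlv2-base-patch16-ensemble", align_rois=True)
image = PIL.Image.open("assets/owl_glove.jpg")  # illustrative path
output = predictor.predict(image, text=["an owl", "a glove"],
                           text_encodings=None, threshold=0.1)
print(output.labels, output.scores)  # per-detection label indices and scores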
