diff --git a/segmentation/model/post_processing/cupy_utils.py b/segmentation/model/post_processing/cupy_utils.py
new file mode 100644
index 0000000..92df157
--- /dev/null
+++ b/segmentation/model/post_processing/cupy_utils.py
@@ -0,0 +1,137 @@
+from typing import Any
+import torch
+import cupy
+
+
+@cupy.memoize(for_each_device=True)
+def cupy_launch(function, kernel):
+    return cupy.cuda.compile_with_cache(kernel).get_function(function)
+
+
+kernel_count_classes_per_instance = '''
+    extern "C" __global__ void kernel_count_classes_per_instance(
+        const int n,
+        const long* sem_seg,
+        const long* ins_seg,
+        const int* is_thing_arr,
+        int* ins_seg_count_mat,
+        int* stuff_areas_count,
+        const int h,
+        const int w,
+        const int num_classes
+    ) {
+        for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+            const int i = (intIndex / w) % h;
+            const int j = intIndex % w;
+
+            const int ins_id = ins_seg[i * w + j];
+            const int sem_class = sem_seg[i * w + j];
+
+            if (ins_id > 0 && is_thing_arr[sem_class] == 1) {
+                atomicAdd(&ins_seg_count_mat[ins_id * num_classes + sem_class], 1);
+            }
+
+            atomicAdd(&stuff_areas_count[sem_class], 1);
+        }
+    }
+'''
+
+kernel_paste_instance_and_semantic = '''
+    extern "C" __global__ void kernel_paste_instance_and_semantic(
+        const int n,
+        const long* sem_seg,
+        const long* ins_seg,
+        long* pan_seg,
+        const long* instance_classes,
+        const int* semseg_areas,
+        const int* thing_list,
+        const int label_divisor,
+        const int stuff_area,
+        const int h,
+        const int w
+    ) {
+        for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+            const int i = (intIndex / w) % h;
+            const int j = intIndex % w;
+            const int ins_id = ins_seg[i * w + j];
+            const int sem_class = sem_seg[i * w + j];
+
+            if (ins_id > 0 && thing_list[sem_class] == 1) {
+                pan_seg[i * w + j] = instance_classes[ins_id] * label_divisor + ins_id;
+            }
+            else if (ins_id == 0 && thing_list[sem_class] == 0 && semseg_areas[sem_class] >= stuff_area) {
+                pan_seg[i * w + j] = sem_class * label_divisor;
+            }
+        }
+    }
+'''
+
+
+class _FunctionCountInstanceClassesAndStuffArea(torch.autograd.Function):
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        raise NotImplementedError
+
+    @staticmethod
+    def forward(ctx, sem_seg, ins_seg, is_thing_arr, max_instances, num_classes):
+        sem_seg = sem_seg.squeeze()
+        ins_seg = ins_seg.squeeze()
+        ss_h, ss_w = sem_seg.shape
+        ins_h, ins_w = ins_seg.shape
+        assert ss_h == ins_h
+        assert ss_w == ins_w
+
+        ins_seg_count_mat = torch.zeros([max_instances, num_classes], dtype=torch.int32).cuda()
+        stuff_areas_count_mat = torch.zeros(num_classes, dtype=torch.int32).cuda()
+
+        if sem_seg.is_cuda:
+            n = sem_seg.nelement()
+            cupy_launch('kernel_count_classes_per_instance', kernel_count_classes_per_instance)(
+                grid=((n + 511) // 512,),
+                block=(512,),
+                args=[n, sem_seg.data_ptr(), ins_seg.data_ptr(), is_thing_arr.data_ptr(),
+                      ins_seg_count_mat.data_ptr(), stuff_areas_count_mat.data_ptr(), ss_h, ss_w, num_classes]
+            )
+        else:
+            raise NotImplementedError()
+        return ins_seg_count_mat, stuff_areas_count_mat
+
+
+class _FunctionPasteInstanceAndSemantic(torch.autograd.Function):
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        raise NotImplementedError
+
+    @staticmethod
+    def forward(ctx, sem_seg, ins_seg, pan_seg, ins_classes,
+                semseg_areas, is_thing_arr, stuff_area, void_label, label_divisor):
+        sem_seg = sem_seg.squeeze()
+        ins_seg = ins_seg.squeeze()
+        pan_seg = pan_seg.squeeze()
+        ss_h, ss_w = sem_seg.shape
+        ins_h, ins_w = ins_seg.shape
+        pan_h, pan_w = pan_seg.shape
+        assert pan_h == ss_h == ins_h
+        assert pan_w == ss_w == ins_w
+
+        if sem_seg.is_cuda:
+            n = sem_seg.nelement()
+            cupy_launch('kernel_paste_instance_and_semantic', kernel_paste_instance_and_semantic)(
+                grid=((n + 511) // 512,),
+                block=(512,),
+                args=[n, sem_seg.data_ptr(), ins_seg.data_ptr(), pan_seg.data_ptr(), ins_classes.data_ptr(),
+                      semseg_areas.data_ptr(), is_thing_arr.data_ptr(), label_divisor, stuff_area, ss_h, ss_w]
+            )
+        else:
+            raise NotImplementedError()
+        return pan_seg
+
+
+def count_classes_per_instance_and_stuff_areas(sem_seg, ins_seg, is_thing_arr, max_instances, num_classes):
+    return _FunctionCountInstanceClassesAndStuffArea.apply(sem_seg, ins_seg, is_thing_arr, max_instances, num_classes)
+
+
+def merge_instance_and_semantic(sem_seg, ins_seg, pan_seg, ins_classes, semseg_areas, is_thing_arr, stuff_area,
+                                void_label, label_divisor):
+    return _FunctionPasteInstanceAndSemantic.apply(sem_seg, ins_seg, pan_seg, ins_classes, semseg_areas,
+                                                   is_thing_arr, stuff_area, void_label, label_divisor)
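Both autograd functions raise `NotImplementedError` off-GPU, so the patch has no CPU fallback to test against. A plain-PyTorch mirror of the counting pass is easy to write with `torch.bincount`; the sketch below is illustrative only (the helper name is ours, not part of the patch) and, like `kernel_count_classes_per_instance`, it adds every pixel of a class to the stuff-area histogram, including pixels covered by instances:

```python
import torch

def count_reference(sem_seg, ins_seg, is_thing_arr, max_instances, num_classes):
    # CPU mirror of kernel_count_classes_per_instance (hypothetical test helper).
    sem = sem_seg.reshape(-1).long()
    ins = ins_seg.reshape(-1).long()
    # joint histogram over (instance id, semantic class), thing pixels only
    thing = (ins > 0) & (is_thing_arr[sem] == 1)
    pair = ins[thing] * num_classes + sem[thing]
    counts = torch.bincount(pair, minlength=max_instances * num_classes)
    counts = counts.reshape(max_instances, num_classes).int()
    # like the kernel, count all pixels of each class, instance pixels included
    areas = torch.bincount(sem, minlength=num_classes).int()
    return counts, areas
```

Comparing its output against `count_classes_per_instance_and_stuff_areas` (after moving the results to CPU) on random label maps is a quick way to catch indexing mistakes in the kernel.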
diff --git a/segmentation/model/post_processing/instance_post_processing.py b/segmentation/model/post_processing/instance_post_processing.py
index cf02f60..46e5b09 100755
--- a/segmentation/model/post_processing/instance_post_processing.py
+++ b/segmentation/model/post_processing/instance_post_processing.py
@@ -5,8 +5,8 @@
 
 import torch
 import torch.nn.functional as F
-
 from .semantic_post_processing import get_semantic_segmentation
+from .cupy_utils import count_classes_per_instance_and_stuff_areas, merge_instance_and_semantic
 
 __all__ = ['find_instance_center', 'get_instance_segmentation', 'get_panoptic_segmentation']
 
@@ -39,15 +39,17 @@ def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=3, top_k=None):
     assert len(ctr_hmp.size()) == 2, 'Something is wrong with center heatmap dimension.'
 
     # find non-zero elements
-    ctr_all = torch.nonzero(ctr_hmp > 0)
+    ctr_all = torch.nonzero(ctr_hmp > 0, as_tuple=True)
+    centers = torch.stack(ctr_all, 1)
     if top_k is None:
-        return ctr_all
-    elif ctr_all.size(0) < top_k:
-        return ctr_all
+        return centers
+    elif len(centers) < top_k:
+        return centers
     else:
         # find top k centers.
-        top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)
-        return torch.nonzero(ctr_hmp > top_k_scores[-1])
+        scores = ctr_hmp[ctr_all]
+        _, indices = torch.topk(scores, top_k)
+        return centers[indices]
 
 
 def group_pixels(ctr, offsets):
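The rewritten `top_k` branch fixes an edge case, not just style: the old strict threshold `ctr_hmp > top_k_scores[-1]` drops every center whose score ties the k-th value, so fewer than `top_k` centers could be returned, while indexing with `topk` over the non-zero scores returns exactly `top_k` rows. A minimal repro (ours, for illustration):

```python
import torch

hmp = torch.tensor([[0.9, 0.0],
                    [0.5, 0.5]])  # two peaks tied at the k-th score
k = 2

# old behaviour: strict '>' excludes both tied peaks -> 1 center
top_k_scores, _ = torch.topk(torch.flatten(hmp), k)
old = torch.nonzero(hmp > top_k_scores[-1])

# new behaviour: index the k best non-zero scores -> exactly 2 centers
ctr_all = torch.nonzero(hmp > 0, as_tuple=True)
centers = torch.stack(ctr_all, 1)
_, indices = torch.topk(hmp[ctr_all], k)
new = centers[indices]

print(old.size(0), new.size(0))  # 1 2
```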
@@ -70,10 +72,8 @@
     y_coord = torch.arange(height, dtype=offsets.dtype, device=offsets.device).repeat(1, width, 1).transpose(1, 2)
     x_coord = torch.arange(width, dtype=offsets.dtype, device=offsets.device).repeat(1, height, 1)
     coord = torch.cat((y_coord, x_coord), dim=0)
-
     ctr_loc = coord + offsets
     ctr_loc = ctr_loc.reshape((2, height * width)).transpose(1, 0)
-
     # ctr: [K, 2] -> [K, 1, 2]
     # ctr_loc = [H*W, 2] -> [1, H*W, 2]
     ctr = ctr.unsqueeze(1)
@@ -120,7 +120,7 @@
     return thing_seg * ins_seg, ctr.unsqueeze(0)
 
 
-def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stuff_area, void_label):
+def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stuff_area, void_label, num_classes=19):
     """
     Post-processing for panoptic segmentation, by merging semantic segmentation label and class agnostic
         instance segmentation label.
@@ -131,56 +131,35 @@ def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stu
         thing_list: A List of thing class id.
         stuff_area: An Integer, remove stuff whose area is less tan stuff_area.
         void_label: An Integer, indicates the region has no confident prediction.
+        num_classes: An Integer, number of semantic classes.
     Returns:
         A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).
     Raises:
         ValueError, if batch size is not 1.
     """
-    # In case thing mask does not align with semantic prediction
     pan_seg = torch.zeros_like(sem_seg) + void_label
-    thing_seg = ins_seg > 0
-    semantic_thing_seg = torch.zeros_like(sem_seg)
-    for thing_class in thing_list:
-        semantic_thing_seg[sem_seg == thing_class] = 1
-
-    # keep track of instance id for each class
-    class_id_tracker = {}
+    tl = torch.tensor(thing_list).view(-1)
+    is_thing_arr = torch.zeros(num_classes, dtype=torch.int32)
+    is_thing_arr[tl] = 1
+    is_thing_arr = is_thing_arr.cuda()
 
     # paste thing by majority voting
-    instance_ids = torch.unique(ins_seg)
-    for ins_id in instance_ids:
-        if ins_id == 0:
-            continue
-        # Make sure only do majority voting within semantic_thing_seg
-        thing_mask = (ins_seg == ins_id) & (semantic_thing_seg == 1)
-        if torch.nonzero(thing_mask).size(0) == 0:
-            continue
-        class_id, _ = torch.mode(sem_seg[thing_mask].view(-1, ))
-        if class_id.item() in class_id_tracker:
-            new_ins_id = class_id_tracker[class_id.item()]
-        else:
-            class_id_tracker[class_id.item()] = 1
-            new_ins_id = 1
-        class_id_tracker[class_id.item()] += 1
-        pan_seg[thing_mask] = class_id * label_divisor + new_ins_id
-
-    # paste stuff to unoccupied area
-    class_ids = torch.unique(sem_seg)
-    for class_id in class_ids:
-        if class_id.item() in thing_list:
-            # thing class
-            continue
-        # calculate stuff area
-        stuff_mask = (sem_seg == class_id) & (~thing_seg)
-        area = torch.nonzero(stuff_mask).size(0)
-        if area >= stuff_area:
-            pan_seg[stuff_mask] = class_id * label_divisor
+    max_ids = int(ins_seg.max().item()) + 1
+    instance_classes_mat, stuff_areas = count_classes_per_instance_and_stuff_areas(
+        sem_seg, ins_seg, is_thing_arr, max_ids, num_classes
+    )
+
+    instance_classes = instance_classes_mat.argmax(1)
+
+    pan_seg = merge_instance_and_semantic(sem_seg, ins_seg, pan_seg, instance_classes, stuff_areas,
+                                          is_thing_arr, stuff_area, void_label, label_divisor)
+    pan_seg = pan_seg.view(1, *pan_seg.shape[-2:])
 
     return pan_seg
 
 
 def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor, stuff_area, void_label,
-                              threshold=0.1, nms_kernel=3, top_k=None, foreground_mask=None):
+                              threshold=0.1, nms_kernel=3, top_k=None, foreground_mask=None, num_classes=None):
     """
     Post-processing for panoptic segmentation.
     Arguments:
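Two semantic changes in `merge_semantic_and_instance` are worth noting. Majority voting is now a per-instance class histogram plus `argmax` instead of a `torch.mode` loop, and thing pixels keep their global `ins_seg` id rather than being renumbered from 1 within each class (the old `class_id_tracker`); both schemes are valid as long as ids stay below `label_divisor`. A small CPU illustration of the vote and the resulting panoptic ids (the numbers are made up):

```python
import torch

# Row i of the count matrix is a histogram of semantic classes over the
# thing pixels of instance i, so argmax recovers the majority-vote class.
counts = torch.tensor([[0, 0, 0],   # instance 0 = no instance
                       [2, 5, 1],   # instance 1 -> class 1
                       [0, 3, 9]])  # instance 2 -> class 2
instance_classes = counts.argmax(1)            # tensor([0, 1, 2])

# Thing pixels are encoded as class * label_divisor + global instance id,
# stuff pixels as class * label_divisor (instance id 0), as in the kernel.
label_divisor = 1000
pan_id = instance_classes[2].item() * label_divisor + 2
print(pan_id // label_divisor, pan_id % label_divisor)  # 2 2
```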
@@ -199,6 +178,7 @@
         top_k: An Integer, top k centers to keep.
         foreground_mask: A Tensor of shape [N, 2, H, W] of raw foreground mask, where N is the batch size,
             we only support N=1. Or, a processed Tensor of shape [1, H, W].
+        num_classes: An Integer, number of semantic classes.
     Returns:
         A Tensor of shape [1, H, W] (to be gathered by distributed data parallel), int64.
     Raises:
@@ -232,6 +212,7 @@
     instance, center = get_instance_segmentation(semantic, ctr_hmp, offsets, thing_list,
                                                  threshold=threshold, nms_kernel=nms_kernel, top_k=top_k,
                                                  thing_seg=thing_seg)
-    panoptic = merge_semantic_and_instance(semantic, instance, label_divisor, thing_list, stuff_area, void_label)
+    panoptic = merge_semantic_and_instance(semantic, instance, label_divisor, thing_list, stuff_area, void_label,
+                                           num_classes=num_classes)
 
     return panoptic, center
diff --git a/tools/test_net_single_core.py b/tools/test_net_single_core.py
index e2de345..fcee9b4 100755
--- a/tools/test_net_single_core.py
+++ b/tools/test_net_single_core.py
@@ -258,7 +258,8 @@
                                     threshold=config.POST_PROCESSING.CENTER_THRESHOLD,
                                     nms_kernel=config.POST_PROCESSING.NMS_KERNEL,
                                     top_k=config.POST_PROCESSING.TOP_K_INSTANCE,
-                                    foreground_mask=foreground_pred)
+                                    foreground_mask=foreground_pred,
+                                    num_classes=data_loader.dataset.num_classes)
             else:
                 panoptic_pred = None
             torch.cuda.synchronize(device)
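With these changes, callers must thread `num_classes` through: `get_panoptic_segmentation` defaults it to `None`, which would fail inside `merge_semantic_and_instance` if left unset. A hypothetical end-to-end call on random inputs, using typical Cityscapes-style constants (illustrative only; shapes follow the docstrings):

```python
import torch
from segmentation.model.post_processing.instance_post_processing import get_panoptic_segmentation

N, C, H, W = 1, 19, 512, 1024
sem = torch.randn(N, C, H, W, device='cuda')      # semantic logits
ctr_hmp = torch.rand(N, 1, H, W, device='cuda')   # center heatmap
offsets = torch.randn(N, 2, H, W, device='cuda')  # center offsets
thing_list = [11, 12, 13, 14, 15, 16, 17, 18]     # Cityscapes thing ids

panoptic, center = get_panoptic_segmentation(
    sem, ctr_hmp, offsets, thing_list,
    label_divisor=1000, stuff_area=2048, void_label=255 * 1000,
    threshold=0.1, nms_kernel=7, top_k=200,
    num_classes=C)                                # must be supplied
print(panoptic.shape)  # torch.Size([1, 512, 1024])
```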