Jsaric/cupy postprocessing #71

Open
wants to merge 2 commits into master

Changes from all commits
138 changes: 138 additions & 0 deletions segmentation/model/post_processing/cupy_utils.py
@@ -0,0 +1,138 @@
from typing import Any
import torch
import cupy
import re


@cupy.memoize(for_each_device=True)
def cupy_launch(function, kernel):
return cupy.cuda.compile_with_cache(kernel).get_function(function)


kernel_count_classes_per_instance = '''
extern "C" __global__ void kernel_count_classes_per_instance(
const int n,
const long* sem_seg,
const long* ins_seg,
const int* is_thing_arr,
int* ins_seg_count_mat,
int* stuff_areas_count,
const int h,
const int w,
const int num_classes
) {
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int i = ( intIndex / w) % h;
const int j = intIndex % w;

const int ins_id = ins_seg[i * w + j];
const int sem_class = sem_seg[i * w + j];

if(ins_id > 0 && is_thing_arr[sem_class] == 1) {
atomicAdd(&ins_seg_count_mat[ins_id * num_classes + sem_class], 1);
}

atomicAdd(&stuff_areas_count[sem_class], 1);
}
}
'''

kernel_paste_instance_and_semantic = '''
extern "C" __global__ void kernel_paste_instance_and_semantic(
const int n,
const long* sem_seg,
const long* ins_seg,
long* pan_seg,
const long* instance_classes,
const int* semseg_areas,
const int* thing_list,
const int label_divisor,
const int stuff_area,
const int h,
const int w
){
for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int i = ( intIndex / w) % h;
const int j = intIndex % w;
const int ins_id = ins_seg[i * w + j];
const int sem_class = sem_seg[i * w + j];

if(ins_id > 0 && thing_list[sem_class] == 1) {
pan_seg[i * w + j] = instance_classes[ins_id] * label_divisor + ins_id;
}
else if(ins_id == 0 && thing_list[sem_class] == 0 && semseg_areas[sem_class] >= stuff_area) {
pan_seg[i * w + j] = sem_class * label_divisor;
}
}
}
'''


class _FunctionCountInstanceClassesAndStuffArea(torch.autograd.Function):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

@staticmethod
def forward(self, sem_seg, ins_seg, is_thing_arr, max_instances, num_classes):
sem_seg = sem_seg.squeeze()
ins_seg = ins_seg.squeeze()
ss_h, ss_w = sem_seg.shape
ins_h, ins_w = ins_seg.shape
assert (ss_h == ins_h)
assert (ss_w == ins_w)

ins_seg_count_mat = torch.zeros([max_instances, num_classes], dtype=torch.int32).cuda()
stuff_areas_count_mat = torch.zeros(num_classes, dtype=torch.int32).cuda()

if sem_seg.is_cuda:
n = sem_seg.nelement()
cupy_launch('kernel_count_classes_per_instance', kernel_count_classes_per_instance)(
grid=tuple([int((n + 512 - 1) / 512), ]),
block=tuple([512, ]),
args=[n, sem_seg.data_ptr(), ins_seg.data_ptr(), is_thing_arr.data_ptr(), ins_seg_count_mat.data_ptr(),
stuff_areas_count_mat.data_ptr(), ss_h, ss_w, num_classes]
)
else:
raise NotImplementedError()
return ins_seg_count_mat, stuff_areas_count_mat


class _FunctionPasteInstanceAndSemantic(torch.autograd.Function):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError

@staticmethod
def forward(self, sem_seg, ins_seg, pan_seg, ins_classes,
semseg_areas, is_thing_arr, stuff_area, void_label, label_divisor):
sem_seg = sem_seg.squeeze()
ins_seg = ins_seg.squeeze()
pan_seg = pan_seg.squeeze()
ss_h, ss_w = sem_seg.shape
ins_h, ins_w = ins_seg.shape
pan_h, pan_w = pan_seg.shape
assert (pan_h == ss_h == ins_h)
assert (pan_w == ss_w == ins_w)

if sem_seg.is_cuda:
n = sem_seg.nelement()
cupy_launch('kernel_paste_instance_and_semantic', kernel_paste_instance_and_semantic)(
grid=tuple([int((n + 512 - 1) / 512)]),
block=tuple([512]),
args=[n, sem_seg.data_ptr(), ins_seg.data_ptr(), pan_seg.data_ptr(), ins_classes.data_ptr(),
semseg_areas.data_ptr(), is_thing_arr.data_ptr(), label_divisor, stuff_area, ss_h, ss_w]
)
else:
raise NotImplementedError()
return pan_seg


def count_classes_per_instance_and_stuff_areas(sem_seg, ins_seg, is_thing_arr, max_instances, num_classes):
return _FunctionCountInstanceClassesAndStuffArea.apply(sem_seg, ins_seg, is_thing_arr, max_instances, num_classes)


def merge_instance_and_semantic(sem_seg, ins_seg, pan_seg, ins_classes, semseg_areas, is_thing_arr, stuff_area,
void_label, label_divisor):
return _FunctionPasteInstanceAndSemantic.apply(sem_seg, ins_seg, pan_seg, ins_classes, semseg_areas,
is_thing_arr, stuff_area, void_label, label_divisor)
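
For review purposes, a minimal call sketch for the two helpers added above. The shapes, thing-class ids, and the `void_label`/`label_divisor`/`stuff_area` values below are illustrative assumptions, not taken from this PR; dtypes follow the `long*`/`int*` kernel arguments.

```python
# Illustrative usage of the new helpers; concrete shapes and ids are made up.
import torch
from segmentation.model.post_processing.cupy_utils import (
    count_classes_per_instance_and_stuff_areas, merge_instance_and_semantic)

num_classes, label_divisor, stuff_area, void_label = 19, 1000, 2048, 255

# int64 label maps, matching the `long*` kernel arguments
sem_seg = torch.randint(0, num_classes, (1, 512, 1024), dtype=torch.int64).cuda()
ins_seg = torch.randint(0, 40, (1, 512, 1024), dtype=torch.int64).cuda()

is_thing_arr = torch.zeros(num_classes, dtype=torch.int32).cuda()
is_thing_arr[11:] = 1                     # hypothetical thing-class ids

max_instances = int(ins_seg.max()) + 1
counts, stuff_areas = count_classes_per_instance_and_stuff_areas(
    sem_seg, ins_seg, is_thing_arr, max_instances, num_classes)

instance_classes = counts.argmax(1)       # majority semantic class per instance id
pan_seg = torch.full_like(sem_seg, void_label)
pan_seg = merge_instance_and_semantic(
    sem_seg, ins_seg, pan_seg, instance_classes, stuff_areas,
    is_thing_arr, stuff_area, void_label, label_divisor)
```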
78 changes: 30 additions & 48 deletions segmentation/model/post_processing/instance_post_processing.py
@@ -5,8 +5,8 @@

import torch
import torch.nn.functional as F

from .semantic_post_processing import get_semantic_segmentation
from .cupy_utils import count_classes_per_instance_and_stuff_areas, merge_instance_and_semantic

__all__ = ['find_instance_center', 'get_instance_segmentation', 'get_panoptic_segmentation']

@@ -39,15 +39,17 @@ def find_instance_center(ctr_hmp, threshold=0.1, nms_kernel=3, top_k=None):
assert len(ctr_hmp.size()) == 2, 'Something is wrong with center heatmap dimension.'

# find non-zero elements
ctr_all = torch.nonzero(ctr_hmp > 0)
ctr_all = torch.nonzero(ctr_hmp > 0, as_tuple=True)
centers = torch.stack(ctr_all, 1)
if top_k is None:
return ctr_all
elif ctr_all.size(0) < top_k:
return ctr_all
return centers
elif len(centers) < top_k:
return centers
else:
# find top k centers.
top_k_scores, _ = torch.topk(torch.flatten(ctr_hmp), top_k)
return torch.nonzero(ctr_hmp > top_k_scores[-1])
scores = ctr_hmp[ctr_all]
_, indices = torch.topk(scores, top_k)
return centers[indices]


def group_pixels(ctr, offsets):
@@ -70,10 +72,8 @@ def group_pixels(ctr, offsets):
y_coord = torch.arange(height, dtype=offsets.dtype, device=offsets.device).repeat(1, width, 1).transpose(1, 2)
x_coord = torch.arange(width, dtype=offsets.dtype, device=offsets.device).repeat(1, height, 1)
coord = torch.cat((y_coord, x_coord), dim=0)

ctr_loc = coord + offsets
ctr_loc = ctr_loc.reshape((2, height * width)).transpose(1, 0)

# ctr: [K, 2] -> [K, 1, 2]
# ctr_loc = [H*W, 2] -> [1, H*W, 2]
ctr = ctr.unsqueeze(1)
@@ -120,7 +120,7 @@ def get_instance_segmentation(sem_seg, ctr_hmp, offsets, thing_list, threshold=0
return thing_seg * ins_seg, ctr.unsqueeze(0)


def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stuff_area, void_label):
def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stuff_area, void_label, num_classes=19):
"""
Post-processing for panoptic segmentation, by merging semantic segmentation label and class agnostic
instance segmentation label.
Expand All @@ -131,56 +131,36 @@ def merge_semantic_and_instance(sem_seg, ins_seg, label_divisor, thing_list, stu
thing_list: A List of thing class id.
stuff_area: An Integer, remove stuff whose area is less than stuff_area.
void_label: An Integer, indicates the region has no confident prediction.
top_k: An Integer, top k centers to keep.
num_classes: An Integer, number of semantic classes.
Returns:
A Tensor of shape [1, H, W] (to be gathered by distributed data parallel).
Raises:
ValueError, if batch size is not 1.
"""
# In case thing mask does not align with semantic prediction
pan_seg = torch.zeros_like(sem_seg) + void_label
thing_seg = ins_seg > 0
semantic_thing_seg = torch.zeros_like(sem_seg)
for thing_class in thing_list:
semantic_thing_seg[sem_seg == thing_class] = 1

# keep track of instance id for each class
class_id_tracker = {}
tl = torch.tensor(thing_list).view(-1)
is_thing_arr = torch.zeros(num_classes, dtype=torch.int32)
is_thing_arr[tl] = 1
is_thing_arr = is_thing_arr.cuda()

# paste thing by majority voting
instance_ids = torch.unique(ins_seg)
for ins_id in instance_ids:
if ins_id == 0:
continue
# Make sure only do majority voting within semantic_thing_seg
thing_mask = (ins_seg == ins_id) & (semantic_thing_seg == 1)
if torch.nonzero(thing_mask).size(0) == 0:
continue
class_id, _ = torch.mode(sem_seg[thing_mask].view(-1, ))
if class_id.item() in class_id_tracker:
new_ins_id = class_id_tracker[class_id.item()]
else:
class_id_tracker[class_id.item()] = 1
new_ins_id = 1
class_id_tracker[class_id.item()] += 1
pan_seg[thing_mask] = class_id * label_divisor + new_ins_id

# paste stuff to unoccupied area
class_ids = torch.unique(sem_seg)
for class_id in class_ids:
if class_id.item() in thing_list:
# thing class
continue
# calculate stuff area
stuff_mask = (sem_seg == class_id) & (~thing_seg)
area = torch.nonzero(stuff_mask).size(0)
if area >= stuff_area:
pan_seg[stuff_mask] = class_id * label_divisor
max_ids = ins_seg.max() + 1
instance_classes_mat, stuff_areas = count_classes_per_instance_and_stuff_areas(
sem_seg, ins_seg, is_thing_arr, max_ids, num_classes
)

instance_classes = instance_classes_mat.argmax(1)

pan_seg = merge_instance_and_semantic(sem_seg, ins_seg, pan_seg, instance_classes, stuff_areas,
is_thing_arr, stuff_area, void_label, label_divisor)

pan_seg = pan_seg.view(1, *pan_seg.shape[-2:])
return pan_seg


def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor, stuff_area, void_label,
threshold=0.1, nms_kernel=3, top_k=None, foreground_mask=None):
threshold=0.1, nms_kernel=3, top_k=None, foreground_mask=None, num_classes=None):
"""
Post-processing for panoptic segmentation.
Arguments:
Expand All @@ -199,6 +179,7 @@ def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor,
top_k: An Integer, top k centers to keep.
foreground_mask: A Tensor of shape [N, 2, H, W] of raw foreground mask, where N is the batch size,
we only support N=1. Or, a processed Tensor of shape [1, H, W].
num_classes: An Integer, number of semantic classes.
Returns:
A Tensor of shape [1, H, W] (to be gathered by distributed data parallel), int64.
Raises:
@@ -232,6 +213,7 @@ def get_panoptic_segmentation(sem, ctr_hmp, offsets, thing_list, label_divisor,
instance, center = get_instance_segmentation(semantic, ctr_hmp, offsets, thing_list,
threshold=threshold, nms_kernel=nms_kernel, top_k=top_k,
thing_seg=thing_seg)
panoptic = merge_semantic_and_instance(semantic, instance, label_divisor, thing_list, stuff_area, void_label)
panoptic = merge_semantic_and_instance(semantic, instance, label_divisor, thing_list, stuff_area, void_label,
num_classes=num_classes)

return panoptic, center
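
The CUDA kernels replace the removed Python loops over `torch.unique` with two fused passes. A pure-PyTorch reference of what the fused path computes is sketched below (illustrative, not part of this PR; useful only as a cross-check). Note that the fused version keeps the original instance ids rather than re-numbering them per class, and counts stuff areas over all pixels of a class, so results can differ slightly from the removed loop-based code.

```python
# Pure-PyTorch reference mirroring the two CUDA kernels (illustrative, untested here).
import torch

def merge_reference(sem_seg, ins_seg, is_thing_arr, label_divisor, stuff_area, void_label):
    sem_seg, ins_seg = sem_seg.squeeze(), ins_seg.squeeze()
    pan_seg = torch.full_like(sem_seg, void_label)
    thing_pixels = is_thing_arr[sem_seg] == 1
    # Thing segments: majority semantic class under each instance, keep the instance id.
    for ins_id in torch.unique(ins_seg):
        if ins_id == 0:
            continue
        mask = (ins_seg == ins_id) & thing_pixels
        if mask.any():
            cls, _ = torch.mode(sem_seg[mask].view(-1))
            pan_seg[mask] = cls * label_divisor + ins_id
    # Stuff segments: keep classes whose total area reaches stuff_area, outside instances.
    for cls in torch.unique(sem_seg):
        if is_thing_arr[cls] == 1:
            continue
        cls_mask = sem_seg == cls
        if cls_mask.sum() >= stuff_area:
            pan_seg[cls_mask & (ins_seg == 0)] = cls * label_divisor
    return pan_seg.view(1, *pan_seg.shape[-2:])
```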
3 changes: 2 additions & 1 deletion tools/test_net_single_core.py
@@ -258,7 +258,8 @@ def main():
threshold=config.POST_PROCESSING.CENTER_THRESHOLD,
nms_kernel=config.POST_PROCESSING.NMS_KERNEL,
top_k=config.POST_PROCESSING.TOP_K_INSTANCE,
foreground_mask=foreground_pred)
foreground_mask=foreground_pred,
num_classes=data_loader.dataset.num_classes)
else:
panoptic_pred = None
torch.cuda.synchronize(device)
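
One portability note on `cupy_utils.py` (an observation, not a change in this PR): `cupy.cuda.compile_with_cache` has been deprecated in recent CuPy releases. If it is unavailable in the target environment, a launcher built on the public `cupy.RawKernel` API could be substituted; the sketch below is an assumption about that environment, not something verified against this branch.

```python
# Hypothetical drop-in for cupy_launch on CuPy versions without compile_with_cache.
import cupy

@cupy.memoize(for_each_device=True)
def cupy_launch(function, kernel):
    # RawKernel compiles lazily on first launch; memoize caches it per device.
    return cupy.RawKernel(kernel, function)
```

`RawKernel` objects are launched with `grid`, `block`, and `args` (e.g. `k((grid,), (block,), tuple(args))`), so the two existing call sites may need the `args` sequence passed as a tuple and, depending on the CuPy version, the arguments passed positionally.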