Skip to content

Commit

Permalink
Refine KITTI evaluation code.
Browse files Browse the repository at this point in the history
  • Loading branch information
will-jl944 authored Aug 25, 2022
1 parent dc2ce65 commit 958886e
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions paddle3d/thirdparty/kitti_object_eval_python/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,22 +362,22 @@ def fused_compute_statistics(overlaps,
dc_num += dc_nums[i]


def calculate_iou_partly(gt_annos,
dt_annos,
def calculate_iou_partly(dt_annos,
gt_annos,
metric,
num_parts=50,
z_axis=1,
z_center=1.0):
"""fast iou algorithm. this function can be used independently to
do result analysis.
Args:
gt_annos: dict, must from get_label_annos() in kitti_common.py
dt_annos: dict, must from get_label_annos() in kitti_common.py
gt_annos: dict, must from get_label_annos() in kitti_common.py
metric: eval type. 0: bbox, 1: bev, 2: 3d
num_parts: int. a parameter for fast calculate algorithm
z_axis: height axis. kitti camera use 1, lidar use 2.
"""
assert len(gt_annos) == len(dt_annos)
assert len(dt_annos) == len(gt_annos)
total_dt_num = np.stack([len(a["name"]) for a in dt_annos], 0)
total_gt_num = np.stack([len(a["name"]) for a in gt_annos], 0)
num_examples = len(gt_annos)
Expand All @@ -392,7 +392,7 @@ def calculate_iou_partly(gt_annos,
if metric == 0:
gt_boxes = np.concatenate([a["bbox"] for a in gt_annos_part], 0)
dt_boxes = np.concatenate([a["bbox"] for a in dt_annos_part], 0)
overlap_part = image_box_overlap(gt_boxes, dt_boxes)
overlap_part = image_box_overlap(dt_boxes, gt_boxes)
elif metric == 1:
loc = np.concatenate(
[a["location"][:, bev_axes] for a in gt_annos_part], 0)
Expand All @@ -408,8 +408,8 @@ def calculate_iou_partly(gt_annos,
rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0)
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
overlap_part = bev_box_overlap(gt_boxes,
dt_boxes).astype(np.float64)
overlap_part = bev_box_overlap(dt_boxes,
gt_boxes).astype(np.float64)
elif metric == 2:
loc = np.concatenate([a["location"] for a in gt_annos_part], 0)
dims = np.concatenate([a["dimensions"] for a in gt_annos_part], 0)
Expand All @@ -422,7 +422,7 @@ def calculate_iou_partly(gt_annos,
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
overlap_part = d3_box_overlap(
gt_boxes, dt_boxes, z_axis=z_axis,
dt_boxes, gt_boxes, z_axis=z_axis,
z_center=z_center).astype(np.float64)
else:
raise ValueError("unknown metric")
Expand All @@ -431,21 +431,21 @@ def calculate_iou_partly(gt_annos,
overlaps = []
example_idx = 0
for j, num_part in enumerate(split_parts):
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
dt_annos_part = dt_annos[example_idx:example_idx + num_part]
# gt_annos_part = gt_annos[example_idx:example_idx + num_part]
# dt_annos_part = dt_annos[example_idx:example_idx + num_part]
gt_num_idx, dt_num_idx = 0, 0
for i in range(num_part):
gt_box_num = total_gt_num[example_idx + i]
dt_box_num = total_dt_num[example_idx + i]
overlaps.append(
parted_overlaps[j][gt_num_idx:gt_num_idx +
gt_box_num, dt_num_idx:dt_num_idx +
dt_box_num])
parted_overlaps[j][dt_num_idx:dt_num_idx +
dt_box_num, gt_num_idx:gt_num_idx +
gt_box_num])
gt_num_idx += gt_box_num
dt_num_idx += dt_box_num
example_idx += num_part

return overlaps, parted_overlaps, total_gt_num, total_dt_num
return overlaps, parted_overlaps, total_dt_num, total_gt_num


def _prepare_data(gt_annos, dt_annos, current_class, difficulty):
Expand Down

0 comments on commit 958886e

Please sign in to comment.