Skip to content

Commit

Permalink
refactor voxel-based segmentor
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiangxu-0103 committed Dec 3, 2023
1 parent 5559545 commit 1edbe4b
Show file tree
Hide file tree
Showing 26 changed files with 221 additions and 331 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/models/cylinder3d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
grid_shape = [480, 360, 32]
point_cloud_range = [0, -3.14159265359, -4, 50, 3.14159265359, 2]
model = dict(
type='Cylinder3D',
type='VoxelSegmentor',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
Expand Down
3 changes: 2 additions & 1 deletion configs/_base_/models/minkunet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model = dict(
type='MinkUNet',
type='VoxelSegmentor',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
Expand All @@ -26,6 +26,7 @@
type='MinkUNetHead',
channels=96,
num_classes=19,
batch_first=False,
dropout_ratio=0,
loss_ce=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
ignore_index=19),
Expand Down
3 changes: 2 additions & 1 deletion configs/_base_/models/spvcnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model = dict(
type='MinkUNet',
type='VoxelSegmentor',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
Expand Down Expand Up @@ -27,6 +27,7 @@
type='MinkUNetHead',
channels=96,
num_classes=19,
batch_first=False,
dropout_ratio=0,
loss_ce=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
ignore_index=19),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
_base_ = ['./minkunet18_w32_torchsparse_8xb2-amp-15e_semantickitti.py']

model = dict(
backbone=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
_base_ = ['./minkunet18_w32_torchsparse_8xb2-amp-15e_semantickitti.py']

model = dict(
backbone=dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@

model = dict(
data_preprocessor=dict(batch_first=True),
backbone=dict(sparseconv_backend='minkowski'))
backbone=dict(sparseconv_backend='minkowski'),
decode_head=dict(batch_first=True))
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

model = dict(
data_preprocessor=dict(batch_first=True),
backbone=dict(sparseconv_backend='spconv'))
backbone=dict(sparseconv_backend='spconv'),
decode_head=dict(batch_first=True))

optim_wrapper = dict(type='AmpOptimWrapper', loss_scale='dynamic')
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@

model = dict(
data_preprocessor=dict(batch_first=True),
backbone=dict(sparseconv_backend='spconv'))
backbone=dict(sparseconv_backend='spconv'),
decode_head=dict(batch_first=True))
2 changes: 1 addition & 1 deletion configs/spvcnn/spvcnn_w16_8xb2-amp-15e_semantickitti.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']
_base_ = ['./spvcnn_w32_8xb2-amp-15e_semantickitti.py']

model = dict(
backbone=dict(
Expand Down
2 changes: 1 addition & 1 deletion configs/spvcnn/spvcnn_w20_8xb2-amp-15e_semantickitti.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']
_base_ = ['./spvcnn_w32_8xb2-amp-15e_semantickitti.py']

model = dict(
backbone=dict(
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/configs/_base_/models/cylinder3d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.models import Cylinder3D
from mmdet3d.models.backbones import Asymm3DSpconv
from mmdet3d.models.data_preprocessors import Det3DDataPreprocessor
from mmdet3d.models.decode_heads.cylinder3d_head import Cylinder3DHead
from mmdet3d.models.losses import LovaszLoss
from mmdet3d.models.segmentors import VoxelSegmentor
from mmdet3d.models.voxel_encoders import SegVFE

grid_shape = [480, 360, 32]
point_cloud_range = [0, -3.14159265359, -4, 50, 3.14159265359, 2]
model = dict(
type=Cylinder3D,
type=VoxelSegmentor,
data_preprocessor=dict(
type=Det3DDataPreprocessor,
voxel=True,
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/configs/_base_/models/minkunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from mmdet3d.models.data_preprocessors.data_preprocessor import \
Det3DDataPreprocessor
from mmdet3d.models.decode_heads.minkunet_head import MinkUNetHead
from mmdet3d.models.segmentors.minkunet import MinkUNet
from mmdet3d.models.segmentors import VoxelSegmentor

model = dict(
type=MinkUNet,
type=VoxelSegmentor,
data_preprocessor=dict(
type=Det3DDataPreprocessor,
voxel=True,
Expand Down
16 changes: 9 additions & 7 deletions mmdet3d/models/backbones/cylinder3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from mmcv.ops import (SparseConv3d, SparseConvTensor, SparseInverseConv3d,
SubMConv3d)
from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType
Expand Down Expand Up @@ -457,12 +456,14 @@ def __init__(self,
indice_key='ddcm',
norm_cfg=norm_cfg)

def forward(self, voxel_features: Tensor, coors: Tensor,
batch_size: int) -> SparseConvTensor:
def forward(self, voxel_dict: dict) -> dict:
"""Forward pass."""
coors = coors.int()
ret = SparseConvTensor(voxel_features, coors, np.array(self.grid_size),
batch_size)
voxel_features = voxel_dict['voxel_feats']
voxel_coors = voxel_dict['voxel_coors']
voxel_coors = voxel_coors.int()
batch_size = voxel_dict['coors'][-1, 0].item() + 1
ret = SparseConvTensor(voxel_features, voxel_coors,

Check warning on line 465 in mmdet3d/models/backbones/cylinder3d.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/cylinder3d.py#L461-L465

Added lines #L461 - L465 were not covered by tests
np.array(self.grid_size), batch_size)
ret = self.down_context(ret)

down_skip_list = []
Expand All @@ -477,5 +478,6 @@ def forward(self, voxel_features: Tensor, coors: Tensor,

ddcm = self.ddcm(up)
ddcm.features = torch.cat((ddcm.features, up.features), 1)
voxel_dict['voxel_feats'] = ddcm

Check warning on line 481 in mmdet3d/models/backbones/cylinder3d.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/cylinder3d.py#L481

Added line #L481 was not covered by tests

return ddcm
return voxel_dict

Check warning on line 483 in mmdet3d/models/backbones/cylinder3d.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/cylinder3d.py#L483

Added line #L483 was not covered by tests
22 changes: 12 additions & 10 deletions mmdet3d/models/backbones/minkunet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch import Tensor, nn
from torch import nn

from mmdet3d.models.layers.minkowski_engine_block import (
IS_MINKOWSKI_ENGINE_AVAILABLE, MinkowskiBasicBlock, MinkowskiBottleneck,
Expand Down Expand Up @@ -55,8 +55,8 @@ class MinkUNetBackbone(BaseModule):
decoder_blocks (List[int]): Number of blocks in each decode layer.
block_type (str): Type of block in encoder and decoder.
sparseconv_backend (str): Sparse convolutional backend.
init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`]
, optional): Initialization config dict.
init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`],
optional): Initialization config dict.
"""

def __init__(self,
Expand Down Expand Up @@ -196,17 +196,17 @@ def __init__(self,
[decoder_layer[0],
nn.Sequential(*decoder_layer[1:])]))

def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
def forward(self, voxel_dict: dict) -> dict:
"""Forward function.
Args:
voxel_features (Tensor): Voxel features in shape (N, C).
coors (Tensor): Coordinates in shape (N, 4),
the columns in the order of (x_idx, y_idx, z_idx, batch_idx).
voxel_dict (dict): Dict containing voxel features.
Returns:
Tensor: Backbone features.
dict: Backbone features.
"""
voxel_features = voxel_dict['voxels']
coors = voxel_dict['coors']

Check warning on line 209 in mmdet3d/models/backbones/minkunet_backbone.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/minkunet_backbone.py#L208-L209

Added lines #L208 - L209 were not covered by tests
if self.sparseconv_backend == 'torchsparse':
x = torchsparse.SparseTensor(voxel_features, coors)
elif self.sparseconv_backend == 'spconv':
Expand Down Expand Up @@ -240,6 +240,8 @@ def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
decoder_outs.append(x)

if self.sparseconv_backend == 'spconv':
return decoder_outs[-1].features
voxel_dict['voxel_feats'] = decoder_outs[-1].features

Check warning on line 243 in mmdet3d/models/backbones/minkunet_backbone.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/minkunet_backbone.py#L243

Added line #L243 was not covered by tests
else:
return decoder_outs[-1].F
voxel_dict['voxel_feats'] = decoder_outs[-1].F

Check warning on line 245 in mmdet3d/models/backbones/minkunet_backbone.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/minkunet_backbone.py#L245

Added line #L245 was not covered by tests

return voxel_dict

Check warning on line 247 in mmdet3d/models/backbones/minkunet_backbone.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/backbones/minkunet_backbone.py#L247

Added line #L247 was not covered by tests
14 changes: 6 additions & 8 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ def voxelize(self, points: List[Tensor],
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'minkunet':
voxels, coors = [], []
voxels, coors, point2voxel_maps, voxel_inds = [], [], [], []

Check warning on line 427 in mmdet3d/models/data_preprocessors/data_preprocessor.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/data_preprocessors/data_preprocessor.py#L427

Added line #L427 was not covered by tests
voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
for i, res in enumerate(points):
res_coors = torch.round(res[:, :3] / voxel_size).int()
res_coors -= res_coors.min(0)[0]

Expand All @@ -439,24 +439,22 @@ def voxelize(self, points: List[Tensor],
inds = np.random.choice(
inds, self.max_voxels, replace=False)
inds = torch.from_numpy(inds).cuda()
if hasattr(data_sample.gt_pts_seg, 'pts_semantic_mask'):
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
if self.batch_first:
res_voxel_coors = F.pad(
res_voxel_coors, (1, 0), mode='constant', value=i)
data_sample.batch_idx = res_voxel_coors[:, 0]
else:
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.batch_idx = res_voxel_coors[:, -1]
data_sample.point2voxel_map = point2voxel_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
point2voxel_maps.append(point2voxel_map)
voxel_inds.append(inds)

Check warning on line 453 in mmdet3d/models/data_preprocessors/data_preprocessor.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/data_preprocessors/data_preprocessor.py#L452-L453

Added lines #L452 - L453 were not covered by tests
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
voxel_dict['point2voxel_maps'] = point2voxel_maps
voxel_dict['voxel_inds'] = voxel_inds

Check warning on line 457 in mmdet3d/models/data_preprocessors/data_preprocessor.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/data_preprocessors/data_preprocessor.py#L456-L457

Added lines #L456 - L457 were not covered by tests

else:
raise ValueError(f'Invalid voxelization type {self.voxel_type}')
Expand Down
26 changes: 11 additions & 15 deletions mmdet3d/models/decode_heads/cylinder3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mmdet3d.models.data_preprocessors.voxelize import dynamic_scatter_3d
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import OptConfigType
from mmdet3d.utils import ConfigType, OptConfigType
from .decode_head import Base3DDecodeHead


Expand Down Expand Up @@ -88,36 +88,32 @@ def loss_by_feat(self, voxel_dict: dict,

return loss

def predict(
self,
voxel_dict: dict,
batch_data_samples: SampleList,
) -> List[Tensor]:
def predict(self, voxel_dict: dict, batch_input_metas: List[dict],
test_cfg: ConfigType) -> List[Tensor]:
"""Forward function for testing.
Args:
voxel_dict (dict): Features from backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`. We use `point2voxel_map` in this function.
batch_input_metas (List[dict]): Meta information of a batch of
samples.
test_cfg (dict or :obj:`ConfigDict`): The testing config.
Returns:
List[Tensor]: List of point-wise segmentation logits.
"""
voxel_dict = self.forward(voxel_dict)
seg_pred_list = self.predict_by_feat(voxel_dict, batch_data_samples)
seg_pred_list = self.predict_by_feat(voxel_dict, batch_input_metas)
return seg_pred_list

Check warning on line 106 in mmdet3d/models/decode_heads/cylinder3d_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/decode_heads/cylinder3d_head.py#L104-L106

Added lines #L104 - L106 were not covered by tests

def predict_by_feat(self, voxel_dict: dict,
batch_data_samples: SampleList) -> List[Tensor]:
batch_input_metas: List[dict]) -> Tensor:
"""Predict function.
Args:
voxel_dict (dict): The dict may contain `logits`,
`point2voxel_map`.
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
batch_input_metas (List[dict]): Meta information of a batch of
samples.
Returns:
List[Tensor]: List of point-wise segmentation logits.
Expand All @@ -126,7 +122,7 @@ def predict_by_feat(self, voxel_dict: dict,

seg_pred_list = []
coors = voxel_dict['voxel_coors']

Check warning on line 124 in mmdet3d/models/decode_heads/cylinder3d_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/decode_heads/cylinder3d_head.py#L124

Added line #L124 was not covered by tests
for batch_idx in range(len(batch_data_samples)):
for batch_idx in range(len(batch_input_metas)):
batch_mask = coors[:, 0] == batch_idx
seg_logits_sample = seg_logits[batch_mask]
point2voxel_map = voxel_dict['point2voxel_maps'][batch_idx].long()

Check warning on line 128 in mmdet3d/models/decode_heads/cylinder3d_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet3d/models/decode_heads/cylinder3d_head.py#L126-L128

Added lines #L126 - L128 were not covered by tests
Expand Down
Loading

0 comments on commit 1edbe4b

Please sign in to comment.