Version update #31 (Open)

wants to merge 3 commits into base: master

Changes from all commits
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
+dataset
+*.pyc
+__pycache__
+.vscode
+*outputs*
+wandb
+logs
+results
+site
+*.DS_Store
2 changes: 1 addition & 1 deletion cfgs/config_msn_partseg.yaml
@@ -3,7 +3,7 @@ common:

   num_points: 2048
   num_classes: 50
-  batch_size: 28
+  batch_size: 12

   base_lr: 0.001
   lr_clip: 0.00001
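For reference, the halved batch size takes effect wherever this file is parsed; a minimal sketch of reading the value (PyYAML is assumed here, and the repo's actual config loader may differ):

```python
# Hypothetical check of the updated setting; illustrative only.
import yaml

with open("cfgs/config_msn_partseg.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["common"]["batch_size"])  # 12 after this change
```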
2 changes: 1 addition & 1 deletion data/ShapeNetPartLoader.py
@@ -18,7 +18,7 @@ class ShapeNetPart():
     def __init__(self, root, num_points = 2048, split='train', normalize=True, transforms = None):
         self.transforms = transforms
         self.num_points = num_points
-        self.root = root
+        self.root = "/home/thomas/HELIX/research/Relation-Shape-CNN/dataset/shapenet"
         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
         self.normalize = normalize
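Hardcoding an absolute, machine-specific path makes the loader unusable anywhere else. A sketch of a more portable variant (the `DATASET_DIR` environment variable is an assumption for illustration, not part of this PR):

```python
import os

def resolve_dataset_root(default_root: str) -> str:
    # Prefer an environment override; fall back to the constructor argument.
    return os.environ.get("DATASET_DIR", default_root)

# Inside __init__ this would read: self.root = resolve_dataset_root(root)
```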
2 changes: 2 additions & 0 deletions models/rscnn_msn_seg.py
@@ -138,6 +138,7 @@ def forward(self, pointcloud: torch.cuda.FloatTensor, cls):

         l_xyz, l_features = [xyz], [features]
         for i in range(len(self.SA_modules)):
+            print("down", i)
             if i < 5:
                 li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
                 if li_xyz is not None:
@@ -151,6 +152,7 @@ def forward(self, pointcloud: torch.cuda.FloatTensor, cls):
         _, global_out2_feat = self.SA_modules[5](l_xyz[3], l_features[3])

         for i in range(-1, -(len(self.FP_modules) + 1), -1):
+            print("up", i)
             l_features[i - 1 - 1] = self.FP_modules[i](
                 l_xyz[i - 1 - 1], l_xyz[i - 1], l_features[i - 1 - 1], l_features[i - 1]
             )
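The inserted `print()` calls trace each set-abstraction ("down") and feature-propagation ("up") stage of the forward pass. A quieter sketch using the standard `logging` module (an alternative for illustration, not what the PR does):

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("rscnn_msn_seg")

# Stage counts are illustrative; the model iterates over its SA and FP module lists.
for i in range(6):
    logger.debug("down %d", i)
for i in range(-1, -5, -1):
    logger.debug("up %d", i)
```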
1 change: 1 addition & 0 deletions train_partseg.py
@@ -1,4 +1,5 @@
 import torch
+torch.backends.cudnn.enabled = False
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_sched
 import torch.nn as nn
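Disabling cuDNN globally is a common workaround for cuDNN kernel errors, at some speed cost. A sketch of the setting alongside narrower knobs (the alternatives are assumptions about intent, not part of this PR):

```python
import torch

# As added by this PR: bypass cuDNN entirely.
torch.backends.cudnn.enabled = False

# Narrower alternatives (illustrative, not in this PR): keep cuDNN but make
# its algorithm choice reproducible instead of switching it off.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
```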
129 changes: 117 additions & 12 deletions utils/pointnet2_modules.py
@@ -1,14 +1,119 @@
+from typing import *
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-import pointnet2_utils
+import torch_points as tp
 import pytorch_utils as pt_utils
 from typing import List
 import numpy as np
 import time
 import math
+class QueryAndGroup(nn.Module):
+    r"""
+    Groups points that fall within a ball query of the given radius.
+
+    Parameters
+    ----------
+    radius : float32
+        Radius of the ball
+    nsample : int32
+        Maximum number of points to gather in the ball
+    """
+
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(
+        self,
+        xyz: torch.Tensor,
+        new_xyz: torch.Tensor,
+        features: torch.Tensor = None,
+        fps_idx: torch.IntTensor = None
+    ) -> Tuple[torch.Tensor]:
+        r"""
+        Parameters
+        ----------
+        xyz : torch.Tensor
+            xyz coordinates of the features (B, N, 3)
+        new_xyz : torch.Tensor
+            centroids (B, npoint, 3)
+        features : torch.Tensor
+            Descriptors of the features (B, C, N)
+        Returns
+        -------
+        new_features : torch.Tensor
+            (B, 3 + 3 + C, npoint, nsample) tensor
+        """
+
+        idx = tp.ball_query(self.radius, self.nsample, xyz, new_xyz, mode='dense')
+        xyz_trans = xyz.transpose(1, 2).contiguous()
+        grouped_xyz = tp.grouping_operation(
+            xyz_trans, idx
+        )  # (B, 3, npoint, nsample)
+        raw_grouped_xyz = grouped_xyz
+        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+        if features is not None:
+            grouped_features = tp.grouping_operation(features, idx)
+            if self.use_xyz:
+                new_features = torch.cat([raw_grouped_xyz, grouped_xyz, grouped_features],
+                                         dim=1)  # (B, 3 + 3 + C, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "use_xyz must be True when features is None!"
+            new_features = torch.cat([raw_grouped_xyz, grouped_xyz], dim=1)
+
+        return new_features
+
+
+class GroupAll(nn.Module):
+    r"""
+    Groups all points into a single neighborhood.
+    """
+
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(
+        self,
+        xyz: torch.Tensor,
+        new_xyz: torch.Tensor,
+        features: torch.Tensor = None
+    ) -> Tuple[torch.Tensor]:
+        r"""
+        Parameters
+        ----------
+        xyz : torch.Tensor
+            xyz coordinates of the features (B, N, 3)
+        new_xyz : torch.Tensor
+            Ignored
+        features : torch.Tensor
+            Descriptors of the features (B, C, N)
+        Returns
+        -------
+        new_features : torch.Tensor
+            (B, C + 3, 1, N) tensor
+        """
+
+        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features],
+                                         dim=1)  # (B, 3 + C, 1, N)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
+
 class _PointnetSAModuleBase(nn.Module):

     def __init__(self):
@@ -38,8 +143,8 @@ def forward(self, xyz: torch.Tensor,
         new_features_list = []
         xyz_flipped = xyz.transpose(1, 2).contiguous()
         if self.npoint is not None:
-            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
-            new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous()
+            fps_idx = tp.furthest_point_sample(xyz, self.npoint)  # (B, npoint)
+            new_xyz = tp.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous()
             fps_idx = fps_idx.data
         else:
             new_xyz = None
@@ -82,7 +187,7 @@ def __init__(
         mlps: List[List[int]],
         use_xyz: bool = True,
         bias = True,
-        init = nn.init.kaiming_normal,
+        init = nn.init.kaiming_normal_,
         first_layer = False,
         relation_prior = 1
     ):
@@ -112,7 +217,7 @@ def __init__(
                                      stride = (1, 1), bias = bias)
         init(xyz_raising.weight)
         if bias:
-            nn.init.constant(xyz_raising.bias, 0)
+            nn.init.constant_(xyz_raising.bias, 0)
         elif npoint is not None:
             mapping_func1 = nn.Conv2d(in_channels = in_channels, out_channels = math.floor(C_out / 4), kernel_size = (1, 1),
                                       stride = (1, 1), bias = bias)
@@ -122,14 +227,14 @@ def __init__(
         init(mapping_func1.weight)
         init(mapping_func2.weight)
         if bias:
-            nn.init.constant(mapping_func1.bias, 0)
-            nn.init.constant(mapping_func2.bias, 0)
+            nn.init.constant_(mapping_func1.bias, 0)
+            nn.init.constant_(mapping_func2.bias, 0)

         # channel raising mapping
         cr_mapping = nn.Conv1d(in_channels = C_in if not first_layer else 16, out_channels = C_out, kernel_size = 1,
                                stride = 1, bias = bias)
         init(cr_mapping.weight)
-        nn.init.constant(cr_mapping.bias, 0)
+        nn.init.constant_(cr_mapping.bias, 0)

         if first_layer:
             mapping = [mapping_func1, mapping_func2, cr_mapping, xyz_raising]
@@ -140,8 +245,8 @@ def __init__(
             radius = radii[i]
             nsample = nsamples[i]
             self.groupers.append(
-                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
-                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+                QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+                if npoint is not None else GroupAll(use_xyz)
             )
             mlp_spec = mlps[i]
             if use_xyz:
@@ -224,12 +329,12 @@ def forward(
             (B, mlp[-1], n) tensor of the features of the unknown features
         """

-        dist, idx = pointnet2_utils.three_nn(unknown, known)
+        dist, idx = tp.three_nn(unknown, known)
         dist_recip = 1.0 / (dist + 1e-8)
         norm = torch.sum(dist_recip, dim=2, keepdim=True)
         weight = dist_recip / norm

-        interpolated_feats = pointnet2_utils.three_interpolate(
+        interpolated_feats = tp.three_interpolate(
             known_feats, idx, weight
         )
         if unknow_feats is not None:
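A minimal usage sketch for the new QueryAndGroup module (shapes follow its docstring; the tensor sizes are illustrative, `utils/` is assumed to be on the import path, and torch_points needs CUDA support):

```python
import torch
from pointnet2_modules import QueryAndGroup  # this PR's module

B, N, npoint, C = 2, 1024, 256, 32
xyz = torch.randn(B, N, 3).cuda()
new_xyz = xyz[:, :npoint, :].contiguous()  # stand-in for FPS centroids
features = torch.randn(B, C, N).cuda()

grouper = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
out = grouper(xyz, new_xyz, features)  # (B, 3 + 3 + C, npoint, nsample)
print(out.shape)                       # torch.Size([2, 38, 256, 32])
```

The feature-propagation hunk keeps the standard inverse-distance interpolation: each unknown point's feature is a weighted sum of its three nearest known points, with weights proportional to 1/d and normalized to sum to 1.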
30 changes: 16 additions & 14 deletions utils/pytorch_utils/pytorch_utils.py
@@ -21,7 +21,7 @@ def __init__(
         self,
         C_in,
         C_out,
-        activation = nn.ReLU(inplace=True),
+        activation = nn.ReLU(),
         mapping = None,
         relation_prior = 1,
         first_layer = False
@@ -44,7 +44,6 @@ def __init__(
             self.xyz_raising = mapping[3]

     def forward(self, input):  # input: (B, 3 + 3 + C_in, npoint, centroid + nsample)
-
         x = input[:, 3:, :, :]  # (B, C_in, npoint, nsample+1), input features
         C_in = x.size()[1]
         nsample = x.size()[3]
@@ -62,13 +61,16 @@ def forward(self, input):  # input: (B, 3 + 3 + C_in, npoint, centroid + nsample)
             h_xi_xj = torch.cat((h_xi_xj, coord_xi, abs_coord, delta_x), dim = 1)
         elif self.relation_prior == 2:
             h_xi_xj = torch.cat((h_xi_xj, coord_xi, zero_vec, abs_coord, zero_vec, delta_x, zero_vec), dim = 1)
+        del coord_xi, abs_coord, delta_x

-        h_xi_xj = self.mapping_func2(self.activation(self.bn_mapping(self.mapping_func1(h_xi_xj))))
+        h_xi_xj = self.mapping_func1(h_xi_xj)
+        h_xi_xj = self.activation(self.bn_mapping(h_xi_xj))
+        h_xi_xj = self.mapping_func2(h_xi_xj)
         if self.first_layer:
             x = self.activation(self.bn_xyz_raising(self.xyz_raising(x)))
         x = F.max_pool2d(self.activation(self.bn_rsconv(torch.mul(h_xi_xj, x))), kernel_size = (1, nsample)).squeeze(3)  # (B, C_in, npoint)
+        del h_xi_xj
         x = self.activation(self.bn_channel_raising(self.cr_mapping(x)))

         return x
@@ -139,7 +141,7 @@ def __init__(
         self,
         C_in,
         C_out,
-        init=nn.init.kaiming_normal,
+        init=nn.init.kaiming_normal_,
         bias = True,
         activation = nn.ReLU(inplace=True)
     ):
@@ -152,7 +154,7 @@ def __init__(

         init(self.conv_avg.weight)
         if bias:
-            nn.init.constant(self.conv_avg.bias, 0)
+            nn.init.constant_(self.conv_avg.bias, 0)

     def forward(self, x):
         nsample = x.size()[3]
@@ -198,8 +200,8 @@ def __init__(self, in_size, batch_norm=None, name=""):
         super().__init__()
         self.add_module(name + "bn", batch_norm(in_size))

-        nn.init.constant(self[0].weight, 1.0)
-        nn.init.constant(self[0].bias, 0)
+        nn.init.constant_(self[0].weight, 1.0)
+        nn.init.constant_(self[0].bias, 0)


class BatchNorm1d(_BNBase):
@@ -251,7 +253,7 @@ def __init__(
         )
         init(conv_unit.weight)
         if bias:
-            nn.init.constant(conv_unit.bias, 0)
+            nn.init.constant_(conv_unit.bias, 0)

         if bn:
             if not preact:
@@ -288,7 +290,7 @@ def __init__(
         padding: int = 0,
         activation=nn.ReLU(inplace=True),
         bn: bool = False,
-        init=nn.init.kaiming_normal,
+        init=nn.init.kaiming_normal_,
         bias: bool = True,
         preact: bool = False,
         name: str = ""
@@ -322,7 +324,7 @@ def __init__(
         padding: Tuple[int, int] = (0, 0),
         activation=nn.ReLU(inplace=True),
         bn: bool = False,
-        init=nn.init.kaiming_normal,
+        init=nn.init.kaiming_normal_,
         bias: bool = True,
         preact: bool = False,
         name: str = ""
@@ -356,7 +358,7 @@ def __init__(
         padding: Tuple[int, int, int] = (0, 0, 0),
         activation=nn.ReLU(inplace=True),
         bn: bool = False,
-        init=nn.init.kaiming_normal,
+        init=nn.init.kaiming_normal_,
         bias: bool = True,
         preact: bool = False,
         name: str = ""
@@ -397,7 +399,7 @@ def __init__(
         if init is not None:
             init(fc.weight)
         if not bn:
-            nn.init.constant(fc.bias, 0)
+            nn.init.constant_(fc.bias, 0)

         if preact:
             if bn:
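Most of the edits in this file migrate initializers deprecated in newer PyTorch releases (`nn.init.kaiming_normal`, `nn.init.constant`) to their in-place underscore variants. The memory-oriented changes in `SharedRSConv.forward` keep the same aggregation; a standalone sketch of that pattern (shapes illustrative, not the module's exact internals):

```python
import torch
import torch.nn.functional as F

B, C_in, npoint, nsample = 2, 64, 128, 32
x = torch.randn(B, C_in, npoint, nsample)        # grouped neighborhood features
h_xi_xj = torch.randn(B, C_in, npoint, nsample)  # learned relation weights

# Relation-weighted features, max-pooled over each neighborhood.
out = F.max_pool2d(torch.mul(h_xi_xj, x), kernel_size=(1, nsample)).squeeze(3)
print(out.shape)  # torch.Size([2, 64, 128])
```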