Skip to content

Commit

Permalink
u
Browse files Browse the repository at this point in the history
  • Loading branch information
AsDeadAsADodo committed Apr 12, 2023
1 parent 22680d4 commit 56d61f3
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 64 deletions.
12 changes: 3 additions & 9 deletions datasets/dtu_yao.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,20 @@ def __getitem__(self, idx):
proj_matrices.append(proj_mat)

if i == 0: # reference view
depth_values = np.arange(depth_min, depth_interval * self.ndepths + depth_min, depth_interval,
depth_values = 1.0/np.linspace(1.0/depth_min, 1.0/(depth_interval * self.ndepths + depth_min), self.ndepths,
dtype=np.float32)
mask = self.read_img(mask_filename)
depth = self.read_depth(depth_filename)

imgs = np.stack(imgs).transpose([0, 3, 1, 2])
proj_matrices = np.stack(proj_matrices)

return {"imgs": imgs,
return {
"imgs": imgs,
"proj_matrices": proj_matrices,
"depth": depth,
"depth_values": depth_values,
"mask": mask,
"intrinsics":intrinsics_list,
"extrinsics":extrinsics_list,
"depth_planes":{
"depth_start":depth_min,
"number":self.ndepths,
"depth_interval":depth_interval,
}
}


Expand Down
63 changes: 63 additions & 0 deletions models/convgru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from functools import reduce
from operator import __add__

class ConvGRUCell(nn.Module):
    """Convolutional GRU cell for 2D feature maps.

    The input ``x`` and the hidden state ``h`` are concatenated along the
    channel axis. One convolution produces the reset and update gates, a
    second convolution produces the candidate state, and the new state is
    ``update * h + (1 - update) * activation(candidate)``.

    Args:
        input_channel: Number of channels of the input ``x``.
        output_channel: Number of channels of the hidden state ``h`` (and of
            the returned output).
        kernel: Convolution kernel size, an int or a pair of ints.
        activation: Callable applied to the normalised candidate state.
        normalize: When true (default) the gates are instance-normalised and
            the candidate is group-normalised before their non-linearities;
            when false the normalisation layers are replaced by identities.
    """

    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel,
                 activation=nn.Tanh(),
                 normalize=True):
        super().__init__()
        self._input_channel = input_channel
        self._output_channel = output_channel
        self._kernel = kernel
        self._activation = activation
        self._normalize = normalize
        self._feature_axis = 1

        # Accept kernel=[3, 3] as well as (3, 3) or 3.
        kernel_size = kernel if isinstance(kernel, int) else tuple(kernel)
        # Both convolutions see cat(x, h), i.e. input + output channels.
        stacked_channel = input_channel + output_channel

        # padding='same' keeps the spatial size, so the gates can be
        # multiplied element-wise with the hidden state.
        # One convolution yields both gates: 2 * output_channel channels.
        self.gate_conv = nn.Conv2d(stacked_channel, 2 * output_channel,
                                   kernel_size, padding='same')
        self.conv2d = nn.Conv2d(stacked_channel, output_channel,
                                kernel_size, padding='same')

        if normalize:
            self.reset_gate_norm = nn.InstanceNorm2d(output_channel, affine=True)
            self.update_gate_norm = nn.InstanceNorm2d(output_channel, affine=True)
            self.output_norm = nn.GroupNorm(1, output_channel, 1e-5, True)
        else:
            self.reset_gate_norm = nn.Identity()
            self.update_gate_norm = nn.Identity()
            self.output_norm = nn.Identity()

    def forward(self, x, h):
        """Advance the cell by one step.

        Args:
            x: Input tensor with ``input_channel`` channels on axis 1.
            h: Previous hidden state with ``output_channel`` channels on
                axis 1 and the same batch/spatial sizes as ``x``.

        Returns:
            A pair ``(output, new_state)``; both entries are the same tensor.
        """
        # Plain tensors are used throughout: torch.autograd.Variable is
        # deprecated and tensors already carry autograd history.
        stacked = torch.cat((x, h), dim=self._feature_axis)
        gates = self.gate_conv(stacked)
        reset_gate, update_gate = torch.split(
            gates, self._output_channel, dim=self._feature_axis)

        # Each gate has its own normalisation layer and its own sigmoid.
        reset_gate = torch.sigmoid(self.reset_gate_norm(reset_gate))
        update_gate = torch.sigmoid(self.update_gate_norm(update_gate))

        candidate_in = torch.cat((x, reset_gate * h), dim=self._feature_axis)
        candidate = self.output_norm(self.conv2d(candidate_in))
        candidate = self._activation(candidate)

        output = update_gate * h + (1 - update_gate) * candidate
        return output, output


32 changes: 32 additions & 0 deletions models/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,39 @@ def forward(self, x):
dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True)
return dconv1

def homo_warping_depthwise(src_fea, src_proj, ref_proj, depth_value):
    """Warp source-view features into the reference view for one depth plane.

    Every reference pixel is back-projected at ``depth_value``, projected
    into the source view with ``src_proj @ inverse(ref_proj)`` and the source
    feature map is bilinearly sampled at the resulting location.

    Args:
        src_fea: Source feature map, indexed as [B, C, H, W].
        src_proj: Source projection matrices, indexed as [B, 4, 4].
        ref_proj: Reference projection matrices, indexed as [B, 4, 4].
        depth_value: One depth hypothesis per batch element (B values; it is
            reshaped to [B, 1, 1] internally).

    Returns:
        The warped feature map as float32, [B, C, H, W]; locations that fall
        outside the source map are filled with zeros.
    """
    batch = src_fea.shape[0]
    height, width = src_fea.shape[2], src_fea.shape[3]

    with torch.no_grad():
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot = proj[:, :3, :3]  # [B, 3, 3]
        trans = proj[:, :3, 3:4]  # [B, 3, 1]

        # Row-major pixel grid, built without torch.meshgrid so it does not
        # depend on that function's `indexing` default.
        rows = torch.arange(0, height, dtype=torch.float32, device=src_fea.device)
        cols = torch.arange(0, width, dtype=torch.float32, device=src_fea.device)
        y = rows.view(height, 1).expand(height, width).reshape(height * width)
        x = cols.view(1, width).expand(height, width).reshape(height * width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]

        rot_xyz = torch.matmul(rot, xyz)  # broadcasts to [B, 3, H*W]
        proj_xyz = rot_xyz * depth_value.reshape(batch, 1, 1) + trans  # [B, 3, H*W]

        # Avoid dividing by an exact zero in the perspective division.
        z = proj_xyz[:, 2:3, :]
        z = torch.where(z == 0, torch.full_like(z, 0.0001), z)
        proj_xy = proj_xyz[:, :2, :] / z  # [B, 2, H*W]

        # Map pixel 0 -> -1 and pixel (size - 1) -> +1.
        proj_x_normalized = proj_xy[:, 0, :] / ((width - 1) / 2) - 1
        proj_y_normalized = proj_xy[:, 1, :] / ((height - 1) / 2) - 1
        grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=2)  # [B, H*W, 2]

    # align_corners=True matches the normalisation above (-1/+1 are pixel
    # centres); the default (False) would shift every sample by a fraction of
    # a pixel.
    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, height, width, 2),
                                   mode='bilinear', padding_mode='zeros',
                                   align_corners=True)
    return warped_src_fea.type(torch.float32)
def homo_warping(src_fea, src_proj, ref_proj, depth_values):
# src_fea: [B, C, H, W]
# src_proj: [B, 4, 4]
Expand Down
96 changes: 41 additions & 55 deletions models/mvsnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .module import *
from .warping import get_homographies, warp_homographies
from .gru import GRU
from .convgru import ConvGRUCell


class FeatureNet(nn.Module):
Expand Down Expand Up @@ -119,80 +120,65 @@ def __init__(self, refine=True):
self.feature = FeatureNet()
self.cost_regularization = CostConvGRURegNet()

self.conv2d = nn.Conv2d(2, 1, (3,3),padding='same')
if self.refine:
self.refine_network = RefineNet()

def compute_cost_volume(self, warped):
    '''
    Build the variance cost volume across the view axis (axis 0).
    Warped: N x C x M x H x W
    returns: 1 x C x M x H x W
    '''
    # Var[X] = E[X^2] - E[X]^2, with the expectation taken over the views.
    mean_of_squares = (warped ** 2).mean(0)
    square_of_mean = warped.mean(0) ** 2
    variance = mean_of_squares - square_of_mean

    # Restore a leading axis of size one.
    return variance.unsqueeze(0)

def compute_depth(self, prob_volume, depth_start, depth_interval, depth_num):
    '''
    Compute the depth map (original author's note: needs to be confirmed).
    prob_volume: 1 x D x H x W
    returns: (depth_image, prob_image), each viewed as H x W
    '''
    _, _, height, width = prob_volume.shape

    # Winner-take-all over the depth axis: best probability and its plane index.
    best_prob, best_plane = torch.max(prob_volume, dim=1)

    # Depth of every plane: depth_start + k * depth_interval.
    plane_depths = depth_start + torch.arange(depth_num).float() * depth_interval
    plane_depths = plane_depths.to(prob_volume.device)

    picked = plane_depths.index_select(0, best_plane.flatten())
    depth_image = picked.view(height, width)
    prob_image = best_prob.view(height, width)

    return depth_image, prob_image

def forward(self, imgs, intrinsics, extrinsics,depth_planes):
def forward(self, imgs, proj_matrices,depth_value):
imgs = torch.unbind(imgs, 1)
proj_matrices = torch.unbind(proj_matrices,1)

assert len(imgs) == len(proj_matrices)

num_depth = len(depth_values)
# step 1. feature extraction
# in: images; out: 32-channel feature maps
features = [self.feature(img) for img in imgs]
features = torch.tensor([item.cpu().detach().numpy() for item in features]).cuda().squeeze()
intrinsics = torch.tensor([item.cpu().detach().numpy() for item in intrinsics]).cuda().squeeze()
extrinsics = torch.tensor([item.cpu().detach().numpy() for item in extrinsics]).cuda().squeeze()
ref_feature,src_features = features[0],features[1:]
ref_proj,src_projs = proj_matrices[0],proj_matrices[1:]


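# step 2. differentiable homography warping + GRU regularization of the cost volume
# The following three attributes are read as hard-coded keys; adjust them if the dataset describes its depth planes differently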
number_of_depth_planes = depth_planes["number"].item()
depth_interval = depth_planes["depth_interval"].item()
depth_start = depth_planes["depth_start"].item()
Hs = get_homographies(features, intrinsics, extrinsics, depth_start, depth_interval, number_of_depth_planes)

# N, C, D, H, W = warped.shape
depth_costs = []
costs_volume_reg = []

gru1_input_channel = 32
gru1_output_channel = 16
gru2_output_channel = 4
gru3_output_channel = 2
state1 = torch.zeros(B,gru1_fiters,H,W)
state2 = torch.zeros(B,gru2_fiters,H,W)
state3 = torch.zeros(B,gru3_fiters,H,W)

convGRUCell1 = ConvGRUCell(input_channel=gru1_input_channel,kernel=[3,3],output_channel=gru1_output_channel)
convGRUCell2 = ConvGRUCell(input_channel=gru1_output_channel,kernel=[3,3],output_channel=gru2_output_channel)
convGRUCell3 = ConvGRUCell(input_channel=gru2_output_channel,kernel=[3,3],output_channel=gru3_output_channel)



for d in range(number_of_depth_planes):
# feature map of the reference image
ref_f = features[:1]

# source features warped by homography onto the reference view's virtual plane
warped = warp_homographies(features[1:], Hs[1:, d])
all_f = torch.cat((ref_f, warped), 0)

# cost_d = 1 x C x H x W
cost_d = self.compute_cost_volume(all_f)
reg_cost = self.cost_regularization(-cost_d)

depth_costs.append(reg_cost)
ref_volume = ref_feature
warped_volume = None
for src_fea,src_proj in zip(src_features,src_projs):
warped_volume = homo_warping_depthwise(src_fea, src_proj, ref_proj, depth_values[:, d])
warped_volume = (warped_volume - ref_volume).pow_(2)
volume_variance = warped_volumes / len(src_features)
cost_map_reg1,state1 = convGRUCell1(-volume_variance,state1)
cost_map_reg2,state2 = convGRUCell2(cost_map_reg1,state2)
cost_map_reg3,state3 = convGRUCell3(cost_map_reg2,state3)
cost_map_reg = self.conv2d(cost_map_reg3)
costs_volume_reg.append(cost_map_reg)

prob_volume = torch.cat(depth_costs, 1)
prob_volume = torch.cat(costs_volume_reg, 1).squeeze(2)
#print(prob_volume.shape)
# softmax prob_volume
softmax_probs = torch.softmax(prob_volume, 1)

depth, prob_image = self.compute_depth(softmax_probs, depth_start, depth_interval, number_of_depth_planes)


return {'prob_volume': softmax_probs}


# step 4. depth map refinement
#if not self.refine:
Expand All @@ -205,7 +191,7 @@ def forward(self, imgs, intrinsics, extrinsics,depth_planes):
def mvsnet_loss(depth_est, depth_gt, mask,depth_value,return_prob_map=False):
# depth_value: B * NUM
# get depth mask
mask_true = mask > 0.5
mask_true = mask
valid_pixel_num = torch.sum(mask_true, dim=[1,2]) + 1e-6

shape = depth_gt.shape
Expand Down

0 comments on commit 56d61f3

Please sign in to comment.