diff --git a/nerfacc/volrend.py b/nerfacc/volrend.py index 34ff1c0..fdb8cbe 100644 --- a/nerfacc/volrend.py +++ b/nerfacc/volrend.py @@ -88,11 +88,12 @@ def rendering( # Query sigma/alpha and color with gradients if rgb_sigma_fn is not None: - if t_starts.shape[0] != 0: - rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices) - else: - rgbs = torch.empty((0, 3), device=t_starts.device) - sigmas = torch.empty((0,), device=t_starts.device) + rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices) + # if t_starts.shape[0] != 0: + # rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices) + # else: + # rgbs = torch.empty((0, 3), device=t_starts.device) + # sigmas = torch.empty((0,), device=t_starts.device) assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format( rgbs.shape ) @@ -115,11 +116,12 @@ def rendering( "rgbs": rgbs, } elif rgb_alpha_fn is not None: - if t_starts.shape[0] != 0: - rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices) - else: - rgbs = torch.empty((0, 3), device=t_starts.device) - alphas = torch.empty((0,), device=t_starts.device) + rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices) + # if t_starts.shape[0] != 0: + # rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices) + # else: + # rgbs = torch.empty((0, 3), device=t_starts.device) + # alphas = torch.empty((0,), device=t_starts.device) assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format( rgbs.shape )