diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py
index f8a02dae58..8092db00a0 100644
--- a/nerfstudio/models/gaussian_splatting.py
+++ b/nerfstudio/models/gaussian_splatting.py
@@ -35,7 +35,11 @@
 from nerfstudio.cameras.cameras import Cameras
 from nerfstudio.data.scene_box import OrientedBox
-from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
+from nerfstudio.engine.callbacks import (
+    TrainingCallback,
+    TrainingCallbackAttributes,
+    TrainingCallbackLocation,
+)
 from nerfstudio.engine.optimizers import Optimizers
 
 # need following import for background color override
@@ -79,7 +83,9 @@ def SH2RGB(sh):
     return sh * C0 + 0.5
 
 
-def projection_matrix(znear, zfar, fovx, fovy, device: Union[str, torch.device] = "cpu"):
+def projection_matrix(
+    znear, zfar, fovx, fovy, device: Union[str, torch.device] = "cpu"
+):
     """
     Constructs an OpenGL-style perspective projection matrix.
     """
@@ -111,7 +117,7 @@ class GaussianSplattingModelConfig(ModelConfig):
     """period of steps where gaussians are culled and densified"""
     resolution_schedule: int = 250
     """training starts at 1/d resolution, every n steps this is doubled"""
-    background_color: Literal["random", "black", "white"] = "random"
+    background_color: Literal["random", "black", "white"] = "white"
     """Whether to randomize the background color."""
     num_downscales: int = 0
     """at the beginning, resolution is 1/2^d, where d is this number"""
@@ -196,9 +202,13 @@ def populate_modules(self):
             self.features_rest = torch.nn.Parameter(shs[:, 1:, :])
         else:
             self.features_dc = torch.nn.Parameter(torch.rand(self.num_points, 3))
-            self.features_rest = torch.nn.Parameter(torch.zeros((self.num_points, dim_sh - 1, 3)))
+            self.features_rest = torch.nn.Parameter(
+                torch.zeros((self.num_points, dim_sh - 1, 3))
+            )
 
-        self.opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(self.num_points, 1)))
+        self.opacities = torch.nn.Parameter(
+            torch.logit(0.1 * torch.ones(self.num_points, 1))
+        )
 
         # metrics
         from torchmetrics.image import PeakSignalNoiseRatio
@@ -240,7 +250,9 @@ def load_state_dict(self, dict, **kwargs):  # type: ignore
         self.opacities = torch.nn.Parameter(torch.zeros(newp, 1, device=self.device))
         self.features_dc = torch.nn.Parameter(torch.zeros(newp, 3, device=self.device))
         self.features_rest = torch.nn.Parameter(
-            torch.zeros(newp, num_sh_bases(self.config.sh_degree) - 1, 3, device=self.device)
+            torch.zeros(
+                newp, num_sh_bases(self.config.sh_degree) - 1, 3, device=self.device
+            )
         )
         super().load_state_dict(dict, **kwargs)
 
@@ -256,7 +268,9 @@ def k_nearest_sklearn(self, x: torch.Tensor, k: int):
         # Build the nearest neighbors model
         from sklearn.neighbors import NearestNeighbors
 
-        nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)
+        nn_model = NearestNeighbors(
+            n_neighbors=k + 1, algorithm="auto", metric="euclidean"
+        ).fit(x_np)
 
         # Find the k-nearest neighbors
         distances, indices = nn_model.kneighbors(x_np)
@@ -295,13 +309,20 @@ def dup_in_optim(self, optimizer, dup_mask, new_params, n=2):
         param_state = optimizer.state[param]
         repeat_dims = (n,) + tuple(1 for _ in range(param_state["exp_avg"].dim() - 1))
         param_state["exp_avg"] = torch.cat(
-            [param_state["exp_avg"], torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(*repeat_dims)],
+            [
+                param_state["exp_avg"],
+                torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(
+                    *repeat_dims
+                ),
+            ],
             dim=0,
         )
         param_state["exp_avg_sq"] = torch.cat(
             [
                 param_state["exp_avg_sq"],
-                torch.zeros_like(param_state["exp_avg_sq"][dup_mask.squeeze()]).repeat(*repeat_dims),
+                torch.zeros_like(param_state["exp_avg_sq"][dup_mask.squeeze()]).repeat(
+                    *repeat_dims
+                ),
             ],
             dim=0,
         )
@@ -332,14 +353,17 @@ def after_train(self, step: int):
             else:
                 assert self.vis_counts is not None
                 self.vis_counts[visible_mask] = self.vis_counts[visible_mask] + 1
-                self.xys_grad_norm[visible_mask] = grads[visible_mask] + self.xys_grad_norm[visible_mask]
+                self.xys_grad_norm[visible_mask] = (
+                    grads[visible_mask] + self.xys_grad_norm[visible_mask]
+                )
 
             # update the max screen size, as a ratio of number of pixels
             if self.max_2Dsize is None:
                 self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
             newradii = self.radii.detach()[visible_mask]
             self.max_2Dsize[visible_mask] = torch.maximum(
-                self.max_2Dsize[visible_mask], newradii / float(max(self.last_size[0], self.last_size[1]))
+                self.max_2Dsize[visible_mask],
+                newradii / float(max(self.last_size[0], self.last_size[1])),
             )
 
     def set_crop(self, crop_box: Optional[OrientedBox]):
@@ -361,16 +385,30 @@ def refinement_after(self, optimizers: Optimizers, step):
             reset_interval = self.config.reset_alpha_every * self.config.refine_every
             do_densification = (
                 self.step < self.config.stop_split_at
-                and self.step % reset_interval > self.num_train_data + self.config.refine_every
+                and self.step % reset_interval
+                > self.num_train_data + self.config.refine_every
             )
             if do_densification:
                 # then we densify
-                assert self.xys_grad_norm is not None and self.vis_counts is not None and self.max_2Dsize is not None
-                avg_grad_norm = (self.xys_grad_norm / self.vis_counts) * 0.5 * max(self.last_size[0], self.last_size[1])
+                assert (
+                    self.xys_grad_norm is not None
+                    and self.vis_counts is not None
+                    and self.max_2Dsize is not None
+                )
+                avg_grad_norm = (
+                    (self.xys_grad_norm / self.vis_counts)
+                    * 0.5
+                    * max(self.last_size[0], self.last_size[1])
+                )
                 high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze()
-                splits = (self.scales.exp().max(dim=-1).values > self.config.densify_size_thresh).squeeze()
+                splits = (
+                    self.scales.exp().max(dim=-1).values
+                    > self.config.densify_size_thresh
+                ).squeeze()
                 if self.step < self.config.stop_screen_size_at:
-                    splits |= (self.max_2Dsize > self.config.split_screen_size).squeeze()
+                    splits |= (
+                        self.max_2Dsize > self.config.split_screen_size
+                    ).squeeze()
                 splits &= high_grads
                 nsamps = self.config.n_split_samples
                 (
@@ -382,7 +420,10 @@ def refinement_after(self, optimizers: Optimizers, step):
                     split_quats,
                 ) = self.split_gaussians(splits, nsamps)
 
-                dups = (self.scales.exp().max(dim=-1).values <= self.config.densify_size_thresh).squeeze()
+                dups = (
+                    self.scales.exp().max(dim=-1).values
+                    <= self.config.densify_size_thresh
+                ).squeeze()
                 dups &= high_grads
                 (
                     dup_means,
@@ -392,19 +433,43 @@ def refinement_after(self, optimizers: Optimizers, step):
                     dup_scales,
                     dup_quats,
                 ) = self.dup_gaussians(dups)
-                self.means = Parameter(torch.cat([self.means.detach(), split_means, dup_means], dim=0))
+                self.means = Parameter(
+                    torch.cat([self.means.detach(), split_means, dup_means], dim=0)
+                )
                 self.features_dc = Parameter(
-                    torch.cat([self.features_dc.detach(), split_features_dc, dup_features_dc], dim=0)
+                    torch.cat(
+                        [self.features_dc.detach(), split_features_dc, dup_features_dc],
+                        dim=0,
+                    )
                 )
                 self.features_rest = Parameter(
-                    torch.cat([self.features_rest.detach(), split_features_rest, dup_features_rest], dim=0)
+                    torch.cat(
+                        [
+                            self.features_rest.detach(),
+                            split_features_rest,
+                            dup_features_rest,
+                        ],
+                        dim=0,
+                    )
+                )
+                self.opacities = Parameter(
+                    torch.cat(
+                        [self.opacities.detach(), split_opacities, dup_opacities], dim=0
+                    )
+                )
+                self.scales = Parameter(
+                    torch.cat([self.scales.detach(), split_scales, dup_scales], dim=0)
+                )
+                self.quats = Parameter(
+                    torch.cat([self.quats.detach(), split_quats, dup_quats], dim=0)
                 )
-                self.opacities = Parameter(torch.cat([self.opacities.detach(), split_opacities, dup_opacities], dim=0))
-                self.scales = Parameter(torch.cat([self.scales.detach(), split_scales, dup_scales], dim=0))
-                self.quats = Parameter(torch.cat([self.quats.detach(), split_quats, dup_quats], dim=0))
                 # append zeros to the max_2Dsize tensor
                 self.max_2Dsize = torch.cat(
-                    [self.max_2Dsize, torch.zeros_like(split_scales[:, 0]), torch.zeros_like(dup_scales[:, 0])],
+                    [
+                        self.max_2Dsize,
+                        torch.zeros_like(split_scales[:, 0]),
+                        torch.zeros_like(dup_scales[:, 0]),
+                    ],
                     dim=0,
                 )
 
@@ -416,11 +481,21 @@ def refinement_after(self, optimizers: Optimizers, step):
                 # After a guassian is split into two new gaussians, the original one should also be pruned.
                 splits_mask = torch.cat(
-                    (splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device=self.device, dtype=torch.bool))
+                    (
+                        splits,
+                        torch.zeros(
+                            nsamps * splits.sum() + dups.sum(),
+                            device=self.device,
+                            dtype=torch.bool,
+                        ),
+                    )
                 )
 
                 deleted_mask = self.cull_gaussians(splits_mask)
-            elif self.step >= self.config.stop_split_at and self.config.continue_cull_post_densification:
+            elif (
+                self.step >= self.config.stop_split_at
+                and self.config.continue_cull_post_densification
+            ):
                 deleted_mask = self.cull_gaussians()
             else:
                 # if we donot allow culling post refinement, no more gaussians will be pruned.
@@ -429,11 +504,17 @@ def refinement_after(self, optimizers: Optimizers, step):
             if deleted_mask is not None:
                 self.remove_from_all_optim(optimizers, deleted_mask)
 
-            if self.step < self.config.stop_split_at and self.step % reset_interval == self.config.refine_every:
+            if (
+                self.step < self.config.stop_split_at
+                and self.step % reset_interval == self.config.refine_every
+            ):
                 # Reset value is set to be twice of the cull_alpha_thresh
                 reset_value = self.config.cull_alpha_thresh * 2.0
                 self.opacities.data = torch.clamp(
-                    self.opacities.data, max=torch.logit(torch.tensor(reset_value, device=self.device)).item()
+                    self.opacities.data,
+                    max=torch.logit(
+                        torch.tensor(reset_value, device=self.device)
+                    ).item(),
                 )
                 # reset the exp of optimizer
                 optim = optimizers.optimizers["opacity"]
@@ -453,18 +534,25 @@ def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None):
         """
         n_bef = self.num_points
         # cull transparent ones
-        culls = (torch.sigmoid(self.opacities) < self.config.cull_alpha_thresh).squeeze()
+        culls = (
+            torch.sigmoid(self.opacities) < self.config.cull_alpha_thresh
+        ).squeeze()
         below_alpha_count = torch.sum(culls).item()
         toobigs_count = 0
         if extra_cull_mask is not None:
             culls = culls | extra_cull_mask
         if self.step > self.config.refine_every * self.config.reset_alpha_every:
             # cull huge ones
-            toobigs = (torch.exp(self.scales).max(dim=-1).values > self.config.cull_scale_thresh).squeeze()
+            toobigs = (
+                torch.exp(self.scales).max(dim=-1).values
+                > self.config.cull_scale_thresh
+            ).squeeze()
             if self.step < self.config.stop_screen_size_at:
                 # cull big screen space
                 assert self.max_2Dsize is not None
-                toobigs = toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze()
+                toobigs = (
+                    toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze()
+                )
             culls = culls | toobigs
             toobigs_count = torch.sum(toobigs).item()
         self.means = Parameter(self.means[~culls].detach())
@@ -487,12 +575,18 @@ def split_gaussians(self, split_mask, samps):
         """
         n_splits = split_mask.sum().item()
-        CONSOLE.log(f"Splitting {split_mask.sum().item()/self.num_points} gaussians: {n_splits}/{self.num_points}")
-        centered_samples = torch.randn((samps * n_splits, 3), device=self.device)  # Nx3 of axis-aligned scales
+        CONSOLE.log(
+            f"Splitting {split_mask.sum().item()/self.num_points} gaussians: {n_splits}/{self.num_points}"
+        )
+        centered_samples = torch.randn(
+            (samps * n_splits, 3), device=self.device
+        )  # Nx3 of axis-aligned scales
         scaled_samples = (
             torch.exp(self.scales[split_mask].repeat(samps, 1)) * centered_samples
         )  # how these scales are rotated
-        quats = self.quats[split_mask] / self.quats[split_mask].norm(dim=-1, keepdim=True)  # normalize them first
+        quats = self.quats[split_mask] / self.quats[split_mask].norm(
+            dim=-1, keepdim=True
+        )  # normalize them first
         rots = quat_to_rotmat(quats.repeat(samps, 1))  # how these scales are rotated
         rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze()
         new_means = rotated_samples + self.means[split_mask].repeat(samps, 1)
@@ -503,25 +597,45 @@ def split_gaussians(self, split_mask, samps):
         new_opacities = self.opacities[split_mask].repeat(samps, 1)
         # step 4, sample new scales
         size_fac = 1.6
-        new_scales = torch.log(torch.exp(self.scales[split_mask]) / size_fac).repeat(samps, 1)
-        self.scales[split_mask] = torch.log(torch.exp(self.scales[split_mask]) / size_fac)
+        new_scales = torch.log(torch.exp(self.scales[split_mask]) / size_fac).repeat(
+            samps, 1
+        )
+        self.scales[split_mask] = torch.log(
+            torch.exp(self.scales[split_mask]) / size_fac
+        )
         # step 5, sample new quats
         new_quats = self.quats[split_mask].repeat(samps, 1)
-        return new_means, new_features_dc, new_features_rest, new_opacities, new_scales, new_quats
+        return (
+            new_means,
+            new_features_dc,
+            new_features_rest,
+            new_opacities,
+            new_scales,
+            new_quats,
+        )
 
     def dup_gaussians(self, dup_mask):
         """
         This function duplicates gaussians that are too small
         """
         n_dups = dup_mask.sum().item()
-        CONSOLE.log(f"Duplicating {dup_mask.sum().item()/self.num_points} gaussians: {n_dups}/{self.num_points}")
+        CONSOLE.log(
+            f"Duplicating {dup_mask.sum().item()/self.num_points} gaussians: {n_dups}/{self.num_points}"
+        )
         dup_means = self.means[dup_mask]
         dup_features_dc = self.features_dc[dup_mask]
         dup_features_rest = self.features_rest[dup_mask]
         dup_opacities = self.opacities[dup_mask]
         dup_scales = self.scales[dup_mask]
         dup_quats = self.quats[dup_mask]
-        return dup_means, dup_features_dc, dup_features_rest, dup_opacities, dup_scales, dup_quats
+        return (
+            dup_means,
+            dup_features_dc,
+            dup_features_rest,
+            dup_opacities,
+            dup_scales,
+            dup_quats,
+        )
 
     @property
     def num_points(self):
@@ -531,7 +645,11 @@ def get_training_callbacks(
         self, training_callback_attributes: TrainingCallbackAttributes
     ) -> List[TrainingCallback]:
         cbs = []
-        cbs.append(TrainingCallback([TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], self.step_cb))
+        cbs.append(
+            TrainingCallback(
+                [TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], self.step_cb
+            )
+        )
         # The order of these matters
         cbs.append(
             TrainingCallback(
@@ -573,7 +691,13 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
 
     def _get_downscale_factor(self):
         if self.training:
-            return 2 ** max((self.config.num_downscales - self.step // self.config.resolution_schedule), 0)
+            return 2 ** max(
+                (
+                    self.config.num_downscales
+                    - self.step // self.config.resolution_schedule
+                ),
+                0,
+            )
         else:
             return 1
 
@@ -591,10 +715,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             print("Called get_outputs with not a camera")
             return {}
         assert camera.shape[0] == 1, "Only one camera at a time"
-        if self.training:
-            # currently relies on the branch vickie/camera-grads
-            self.camera_optimizer.apply_to_camera(camera)
-
+
         # get the background color
         if self.training:
             if self.config.background_color == "random":
@@ -610,7 +731,11 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
-                return {"rgb": background.repeat(int(camera.height.item()), int(camera.width.item()), 1)}
+                return {
+                    "rgb": background.repeat(
+                        int(camera.height.item()), int(camera.width.item()), 1
+                    )
+                }
         else:
             crop_ids = None
         camera_downscale = self._get_downscale_factor()
@@ -619,7 +744,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         R = camera.camera_to_worlds[0, :3, :3]  # 3 x 3
         T = camera.camera_to_worlds[0, :3, 3:4]  # 3 x 1
         # flip the z and y axes to align with gsplat conventions
-        R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
+        R_edit = torch.diag(
+            torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype)
+        )
         R = R @ R_edit
         # analytic matrix inverse to get world2camera matrix
         R_inv = R.T
@@ -657,7 +784,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             scales_crop = self.scales
             quats_crop = self.quats
 
-        colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
+        colors_crop = torch.cat(
+            (features_dc_crop[:, None, :], features_rest_crop), dim=1
+        )
 
         self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = project_gaussians(  # type: ignore
             means_crop,
@@ -675,14 +804,20 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             tile_bounds,
         )  # type: ignore
         if (self.radii).sum() == 0:
-            return {"rgb": background.repeat(int(camera.height.item()), int(camera.width.item()), 1)}
+            return {
+                "rgb": background.repeat(
+                    int(camera.height.item()), int(camera.width.item()), 1
+                )
+            }
 
         # Important to allow xys grads to populate properly
         if self.training:
            self.xys.retain_grad()
 
         if self.config.sh_degree > 0:
-            viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3]  # (N, 3)
+            viewdirs = (
+                means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3]
+            )  # (N, 3)
             viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
             n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
             rgbs = spherical_harmonics(n, viewdirs, colors_crop)
@@ -721,7 +856,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
                 H,
                 W,
                 background=torch.ones(3, device=self.device) * 10,
-            )[..., 0:1]  # type: ignore
+            )[
+                ..., 0:1
+            ]  # type: ignore
 
         return {"rgb": rgb, "depth": depth_im}  # type: ignore
 
@@ -740,7 +877,9 @@ def get_gt_img(self, image: torch.Tensor):
             # torchvision can be slow to import, so we do it lazily.
             import torchvision.transforms.functional as TF
 
-            gt_img = TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
+            gt_img = TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(
+                1, 2, 0
+            )
         else:
             gt_img = image
         return gt_img.to(self.device)
@@ -760,7 +899,9 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
         metrics_dict["gaussian_count"] = self.num_points
         return metrics_dict
 
-    def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
+    def get_loss_dict(
+        self, outputs, batch, metrics_dict=None
+    ) -> Dict[str, torch.Tensor]:
         """Computes and returns the losses dict.
 
         Args:
@@ -770,12 +911,16 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
         """
         gt_img = self.get_gt_img(batch["image"])
         Ll1 = torch.abs(gt_img - outputs["rgb"]).mean()
-        simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], outputs["rgb"].permute(2, 0, 1)[None, ...])
+        simloss = 1 - self.ssim(
+            gt_img.permute(2, 0, 1)[None, ...],
+            outputs["rgb"].permute(2, 0, 1)[None, ...],
+        )
         if self.config.use_scale_regularization and self.step % 10 == 0:
             scale_exp = torch.exp(self.scales)
             scale_reg = (
                 torch.maximum(
-                    scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1), torch.tensor(self.config.max_gauss_ratio)
+                    scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1),
+                    torch.tensor(self.config.max_gauss_ratio),
                 )
                 - self.config.max_gauss_ratio
             )
@@ -784,12 +929,15 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             scale_reg = torch.tensor(0.0).to(self.device)
 
         return {
-            "main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss,
+            "main_loss": (1 - self.config.ssim_lambda) * Ll1
+            + self.config.ssim_lambda * simloss,
             "scale_reg": scale_reg,
         }
 
     @torch.no_grad()
-    def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
+    def get_outputs_for_camera(
+        self, camera: Cameras, obb_box: Optional[OrientedBox] = None
+    ) -> Dict[str, torch.Tensor]:
         """Takes in a camera, generates the raybundle, and computes the output of the model.
 
         Overridden for a camera-based gaussian model.
@@ -822,7 +970,9 @@ def get_image_metrics_and_images(
             import torchvision.transforms.functional as TF
 
             newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
-            predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
+            predicted_rgb = TF.resize(
+                outputs["rgb"].permute(2, 0, 1), newsize, antialias=None
+            ).permute(1, 2, 0)
         else:
             predicted_rgb = outputs["rgb"]