add continue_cull_post_densification option and some refactor
jb-ye authored and Jianbo Ye committed Jan 7, 2024
1 parent c2db378 commit f784463
Showing 2 changed files with 118 additions and 93 deletions.
nerfstudio/configs/method_configs.py (2 changes: 1 addition & 1 deletion)
@@ -601,7 +601,7 @@
method_configs["gaussian-splatting"] = TrainerConfig(
method_name="gaussian-splatting",
steps_per_eval_image=100,
steps_per_eval_batch=100,
steps_per_eval_batch=0,
steps_per_save=2000,
steps_per_eval_all_images=1000,
max_num_iterations=30000,
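The only change in this file is steps_per_eval_batch: 100 -> 0. Reviewer note (not part of the diff): my assumption, not stated anywhere in the commit, is that a zero interval acts as "disabled", so ray-batch evaluation simply never runs for the splatting method (which renders whole cameras rather than ray batches). A minimal self-contained sketch of that assumed behaviour, with step_check as a hypothetical stand-in for the trainer's interval guard:

def step_check(step: int, interval: int) -> bool:
    """Return True when `step` falls on a non-zero eval interval; 0 means never."""
    return interval != 0 and step % interval == 0

assert step_check(2000, 100)      # old config: run a batch eval every 100 steps
assert not step_check(2000, 0)    # new config: batch eval is skipped entirely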
nerfstudio/models/gaussian_splatting.py (209 changes: 117 additions & 92 deletions)
@@ -123,6 +123,8 @@ class GaussianSplattingModelConfig(ModelConfig):
"""threshold of opacity for culling gaussians. One can set it to a lower value (e.g. 0.005) for higher quality."""
cull_scale_thresh: float = 0.5
"""threshold of scale for culling huge gaussians"""
continue_cull_post_densification: bool = True
"""If True, continue to cull gaussians post refinement"""
reset_alpha_every: int = 30
"""Every this many refinement steps, reset the alpha"""
densify_grad_thresh: float = 0.0002
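Reviewer note (not part of the diff): a minimal sketch of toggling the new option when constructing the model config. The import path simply mirrors the file edited in this diff, and the values are illustrative rather than recommended defaults.

from nerfstudio.models.gaussian_splatting import GaussianSplattingModelConfig

model_config = GaussianSplattingModelConfig(
    cull_alpha_thresh=0.005,                 # lower opacity threshold, per the docstring above
    cull_scale_thresh=0.5,                   # unchanged threshold for culling huge gaussians
    continue_cull_post_densification=False,  # freeze the gaussian count once splitting stops
)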
@@ -283,6 +285,12 @@ def remove_from_optim(self, optimizer, deleted_mask, new_params):
optimizer.param_groups[0]["params"] = new_params
optimizer.state[new_params[0]] = param_state

def remove_from_all_optim(self, optimizers, deleted_mask):
param_groups = self.get_gaussian_param_groups()
for group, param in param_groups.items():
self.remove_from_optim(optimizers.optimizers[group], deleted_mask, param)
torch.cuda.empty_cache()

def dup_in_optim(self, optimizer, dup_mask, new_params, n=2):
"""adds the parameters to the optimizer"""
param = optimizer.param_groups[0]["params"][0]
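Reviewer note (not part of the diff): remove_from_all_optim here, and dup_in_all_optim in the next hunk, fold the previous per-group loops into single helpers. The invariant they preserve, as I read the collapsed remove_from_optim/dup_in_optim bodies, is that Adam's per-row moments stay aligned with the gaussians after rows are culled or duplicated. A standalone illustration of the culling case, with purely local names:

import torch

param = torch.nn.Parameter(torch.randn(5, 3))   # e.g. 5 gaussian means
exp_avg = torch.zeros_like(param)               # Adam first moment, one row per gaussian
deleted_mask = torch.tensor([False, True, False, False, True])

kept_param = torch.nn.Parameter(param[~deleted_mask].detach())
kept_exp_avg = exp_avg[~deleted_mask]           # slice the optimizer state with the same mask
assert kept_param.shape[0] == kept_exp_avg.shape[0] == 3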
@@ -304,7 +312,16 @@ def dup_in_optim(self, optimizer, dup_mask, new_params, n=2):
optimizer.param_groups[0]["params"] = new_params
del param

def dup_in_all_optim(self, optimizers, dup_mask, n):
param_groups = self.get_gaussian_param_groups()
for group, param in param_groups.items():
self.dup_in_optim(optimizers.optimizers[group], dup_mask, param, n)

def after_train(self, step: int):
assert step == self.step
# to save some training time, we no longer need to update those stats post refinement
if self.step >= self.config.stop_split_at:
return
with torch.no_grad():
# keep track of a moving average of grad norms
visible_mask = (self.radii > 0).flatten()
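Reviewer note (not part of the diff): the early return added to after_train stops accumulating xys_grad_norm / vis_counts / max_2Dsize once step >= stop_split_at, since those stats only feed the densification trigger computed in the next hunk. A tiny worked example of that trigger, with made-up numbers:

xys_grad_norm, vis_counts = 0.02, 40.0       # accumulated screen-space grad norm, times visible
last_size = (1080, 1920)                     # render height, width
avg_grad_norm = (xys_grad_norm / vis_counts) * 0.5 * max(last_size)   # 0.0005 * 0.5 * 1920 = 0.48
densify_grad_thresh = 0.0002                 # default from the config hunk above
print(avg_grad_norm > densify_grad_thresh)   # True -> candidate for split/duplication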
@@ -334,100 +351,108 @@ def set_background(self, back_color: torch.Tensor):
self.back_color = back_color

def refinement_after(self, optimizers: Optimizers, step):
if self.step > self.config.warmup_length and self.step < self.config.stop_split_at:
with torch.no_grad():
# Offset all the opacity reset logic by refine_every so that we don't
# save checkpoints right when the opacity is reset (saves every 2k)
# then cull
# only split/cull if we've seen every image since opacity reset
reset_interval = self.config.reset_alpha_every * self.config.refine_every
if self.step % reset_interval > self.num_train_data + self.config.refine_every:
# then we densify
assert (
assert step == self.step
if self.step <= self.config.warmup_length:
return
with torch.no_grad():
# Offset all the opacity reset logic by refine_every so that we don't
# save checkpoints right when the opacity is reset (saves every 2k)
# then cull
# only split/cull if we've seen every image since opacity reset
reset_interval = self.config.reset_alpha_every * self.config.refine_every
do_densification = (
self.step < self.config.stop_split_at
and self.step % reset_interval > self.num_train_data + self.config.refine_every
)
if do_densification:
# then we densify
assert (
self.xys_grad_norm is not None and self.vis_counts is not None and self.max_2Dsize is not None
)
avg_grad_norm = (
)
avg_grad_norm = (
(self.xys_grad_norm / self.vis_counts) * 0.5 * max(self.last_size[0], self.last_size[1])
)
high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze()
splits = (self.scales.exp().max(dim=-1).values > self.config.densify_size_thresh).squeeze()
if self.step < self.config.stop_screen_size_at:
splits |= (self.max_2Dsize > self.config.split_screen_size).squeeze()
splits &= high_grads
nsamps = self.config.n_split_samples
(
split_means,
split_features_dc,
split_features_rest,
split_opacities,
split_scales,
split_quats,
) = self.split_gaussians(splits, nsamps)

dups = (self.scales.exp().max(dim=-1).values <= self.config.densify_size_thresh).squeeze()
dups &= high_grads
(
dup_means,
dup_features_dc,
dup_features_rest,
dup_opacities,
dup_scales,
dup_quats,
) = self.dup_gaussians(dups)
self.means = Parameter(torch.cat([self.means.detach(), split_means, dup_means], dim=0))
self.features_dc = Parameter(
torch.cat([self.features_dc.detach(), split_features_dc, dup_features_dc], dim=0)
)
self.features_rest = Parameter(
torch.cat([self.features_rest.detach(), split_features_rest, dup_features_rest], dim=0)
)
self.opacities = Parameter(
torch.cat([self.opacities.detach(), split_opacities, dup_opacities], dim=0)
)
self.scales = Parameter(torch.cat([self.scales.detach(), split_scales, dup_scales], dim=0))
self.quats = Parameter(torch.cat([self.quats.detach(), split_quats, dup_quats], dim=0))
# append zeros to the max_2Dsize tensor
self.max_2Dsize = torch.cat(
[self.max_2Dsize, torch.zeros_like(split_scales[:, 0]), torch.zeros_like(dup_scales[:, 0])],
dim=0,
)

split_idcs = torch.where(splits)[0]
param_groups = self.get_gaussian_param_groups()
for group, param in param_groups.items():
self.dup_in_optim(optimizers.optimizers[group], split_idcs, param, n=nsamps)

dup_idcs = torch.where(dups)[0]
param_groups = self.get_gaussian_param_groups()
for group, param in param_groups.items():
self.dup_in_optim(optimizers.optimizers[group], dup_idcs, param, 1)

                # After a gaussian is split into two new gaussians, the original one should also be pruned.
splits_mask = torch.cat(
(splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device=self.device, dtype=torch.bool))
)
deleted_mask = self.cull_gaussians(splits_mask)
param_groups = self.get_gaussian_param_groups()
for group, param in param_groups.items():
self.remove_from_optim(optimizers.optimizers[group], deleted_mask, param)
torch.cuda.empty_cache()

if self.step % reset_interval == self.config.refine_every:
# Reset value is set to be twice of the cull_alpha_thresh
reset_value = self.config.cull_alpha_thresh * 2.0
self.opacities.data = torch.clamp(
self.opacities.data, max=torch.logit(torch.tensor(reset_value, device=self.device)).item()
)
# reset the exp of optimizer
optim = optimizers.optimizers["opacity"]
param = optim.param_groups[0]["params"][0]
param_state = optim.state[param]
param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"])
param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"])

self.xys_grad_norm = None
self.vis_counts = None
self.max_2Dsize = None
)
high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze()
splits = (self.scales.exp().max(dim=-1).values > self.config.densify_size_thresh).squeeze()
if self.step < self.config.stop_screen_size_at:
splits |= (self.max_2Dsize > self.config.split_screen_size).squeeze()
splits &= high_grads
nsamps = self.config.n_split_samples
(
split_means,
split_features_dc,
split_features_rest,
split_opacities,
split_scales,
split_quats,
) = self.split_gaussians(splits, nsamps)

dups = (self.scales.exp().max(dim=-1).values <= self.config.densify_size_thresh).squeeze()
dups &= high_grads
(
dup_means,
dup_features_dc,
dup_features_rest,
dup_opacities,
dup_scales,
dup_quats,
) = self.dup_gaussians(dups)
self.means = Parameter(torch.cat([self.means.detach(), split_means, dup_means], dim=0))
self.features_dc = Parameter(
torch.cat([self.features_dc.detach(), split_features_dc, dup_features_dc], dim=0)
)
self.features_rest = Parameter(
torch.cat([self.features_rest.detach(), split_features_rest, dup_features_rest], dim=0)
)
self.opacities = Parameter(
torch.cat([self.opacities.detach(), split_opacities, dup_opacities], dim=0)
)
self.scales = Parameter(torch.cat([self.scales.detach(), split_scales, dup_scales], dim=0))
self.quats = Parameter(torch.cat([self.quats.detach(), split_quats, dup_quats], dim=0))
# append zeros to the max_2Dsize tensor
self.max_2Dsize = torch.cat(
[self.max_2Dsize, torch.zeros_like(split_scales[:, 0]), torch.zeros_like(dup_scales[:, 0])],
dim=0,
)

split_idcs = torch.where(splits)[0]
self.dup_in_all_optim(optimizers, split_idcs, nsamps)

dup_idcs = torch.where(dups)[0]
self.dup_in_all_optim(optimizers, dup_idcs, 1)

            # After a gaussian is split into two new gaussians, the original one should also be pruned.
splits_mask = torch.cat(
(splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device=self.device, dtype=torch.bool))
)

if do_densification:
deleted_mask = self.cull_gaussians(splits_mask)
elif self.step >= self.config.stop_split_at and self.config.continue_cull_post_densification:
deleted_mask = self.cull_gaussians()
else:
                # if we do not allow culling post refinement, no more gaussians will be pruned.
deleted_mask = None

if deleted_mask is not None:
self.remove_from_all_optim(optimizers, deleted_mask)

if self.step < self.config.stop_split_at and self.step % reset_interval == self.config.refine_every:
# Reset value is set to be twice of the cull_alpha_thresh
reset_value = self.config.cull_alpha_thresh * 2.0
self.opacities.data = torch.clamp(
self.opacities.data, max=torch.logit(torch.tensor(reset_value, device=self.device)).item()
)
# reset the exp of optimizer
optim = optimizers.optimizers["opacity"]
param = optim.param_groups[0]["params"][0]
param_state = optim.state[param]
param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"])
param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"])

self.xys_grad_norm = None
self.vis_counts = None
self.max_2Dsize = None

def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None):
"""
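Reviewer note (not part of the diff): a condensed restatement of the new culling control flow in refinement_after, as I read it. The function below is a hypothetical stub for illustration, not part of the module's API, and the step numbers passed to it are made up.

def cull_decision(step, stop_split_at, in_densify_window, continue_cull_post_densification):
    if in_densify_window:
        # still splitting/duplicating: cull, and also prune the originals of split gaussians
        return "cull with splits_mask"
    if step >= stop_split_at and continue_cull_post_densification:
        # new option: keep pruning low-opacity / oversized gaussians after densification ends
        return "cull only"
    return None  # gaussian count is frozen from here on

print(cull_decision(26000, 15000, in_densify_window=False, continue_cull_post_densification=True))   # "cull only"
print(cull_decision(26000, 15000, in_densify_window=False, continue_cull_post_densification=False))  # None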
