From 44544f7bdbbb9f0002e86025967e6e7dc877942d Mon Sep 17 00:00:00 2001 From: "J.Y." <132313008+jb-ye@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:50:28 -0400 Subject: [PATCH] Fix bug related to refine_scale2d and introduce a new parameter pause refine after reset (#354) --- gsplat/strategy/default.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 0e311fea9..30ffae6cd 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -46,7 +46,10 @@ class DefaultStrategy(Strategy): refine_start_iter (int): Start refining GSs after this iteration. Default is 500. refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. reset_every (int): Reset opacities every this steps. Default is 3000. - refine_every (int): Reine GSs every this steps. Default is 100. + refine_every (int): Refine GSs every this steps. Default is 100. + pause_refine_after_reset (int): Pause refining GSs until this number of steps after + reset, Default is 0 (no pause at all) and one might want to set this number to the + number of images in training set. absgrad (bool): Use absolute gradients for GS splitting. Default is False. revised_opacity (bool): Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental). Default is False. @@ -80,6 +83,7 @@ class DefaultStrategy(Strategy): refine_stop_iter: int = 15_000 reset_every: int = 3000 refine_every: int = 100 + pause_refine_after_reset: int = 0 absgrad: bool = False revised_opacity: bool = False verbose: bool = False @@ -155,7 +159,11 @@ def step_post_backward( self._update_state(params, state, info, packed=packed) - if step > self.refine_start_iter and step % self.refine_every == 0: + if ( + step > self.refine_start_iter + and step % self.refine_every == 0 + and step % self.reset_every >= self.pause_refine_after_reset + ): # grow GSs n_dupli, n_split = self._grow_gs(params, optimizers, state, step) if self.verbose: @@ -175,6 +183,8 @@ def step_post_backward( # reset running stats state["grad2d"].zero_() state["count"].zero_() + if self.refine_scale2d_stop_iter > 0: + state["radii"].zero_() torch.cuda.empty_cache() if step % self.reset_every == 0: @@ -258,9 +268,9 @@ def _grow_gs( n_dupli = is_dupli.sum().item() is_large = ~is_small - if step < self.refine_scale2d_stop_iter: - is_large |= state["radii"] > self.grow_scale2d is_split = is_grad_high & is_large + if step < self.refine_scale2d_stop_iter: + is_split |= state["radii"] > self.grow_scale2d n_split = is_split.sum().item() # first duplicate