Fix bug related to refine_scale2d and introduce a new parameter pause refine after reset (#354)
jb-ye authored Aug 22, 2024
1 parent 5ec2670 commit 44544f7
Showing 1 changed file with 14 additions and 4 deletions.
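
As context for the diff below, here is a minimal sketch of how the new option might be configured; the top-level import and the concrete value are assumptions, while DefaultStrategy, its existing fields, and pause_refine_after_reset come from this change:

# Hypothetical usage sketch (not part of the commit). Assumes DefaultStrategy
# is importable from the top-level gsplat package; the value chosen here follows
# the new docstring's suggestion of using the number of images in the training set.
from gsplat import DefaultStrategy

num_train_images = 300  # illustrative value, e.g. the size of the training set

strategy = DefaultStrategy(
    refine_start_iter=500,
    refine_stop_iter=15_000,
    reset_every=3000,
    refine_every=100,
    # New in this commit: skip the grow/refine step for this many iterations
    # after each opacity reset (0 keeps the previous behavior).
    pause_refine_after_reset=num_train_images,
)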
18 changes: 14 additions & 4 deletions gsplat/strategy/default.py
@@ -46,7 +46,10 @@ class DefaultStrategy(Strategy):
         refine_start_iter (int): Start refining GSs after this iteration. Default is 500.
         refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000.
         reset_every (int): Reset opacities every this steps. Default is 3000.
-        refine_every (int): Reine GSs every this steps. Default is 100.
+        refine_every (int): Refine GSs every this steps. Default is 100.
+        pause_refine_after_reset (int): Pause refining GSs until this number of steps after
+            reset, Default is 0 (no pause at all) and one might want to set this number to the
+            number of images in training set.
         absgrad (bool): Use absolute gradients for GS splitting. Default is False.
         revised_opacity (bool): Whether to use revised opacity heuristic from
             arXiv:2404.06109 (experimental). Default is False.
@@ -80,6 +83,7 @@ class DefaultStrategy(Strategy):
     refine_stop_iter: int = 15_000
     reset_every: int = 3000
     refine_every: int = 100
+    pause_refine_after_reset: int = 0
     absgrad: bool = False
     revised_opacity: bool = False
     verbose: bool = False
@@ -155,7 +159,11 @@ def step_post_backward(

         self._update_state(params, state, info, packed=packed)

-        if step > self.refine_start_iter and step % self.refine_every == 0:
+        if (
+            step > self.refine_start_iter
+            and step % self.refine_every == 0
+            and step % self.reset_every >= self.pause_refine_after_reset
+        ):
             # grow GSs
             n_dupli, n_split = self._grow_gs(params, optimizers, state, step)
             if self.verbose:
@@ -175,6 +183,8 @@ def step_post_backward(
             # reset running stats
             state["grad2d"].zero_()
             state["count"].zero_()
+            if self.refine_scale2d_stop_iter > 0:
+                state["radii"].zero_()
             torch.cuda.empty_cache()

         if step % self.reset_every == 0:
@@ -258,9 +268,9 @@ def _grow_gs(
         n_dupli = is_dupli.sum().item()

         is_large = ~is_small
-        if step < self.refine_scale2d_stop_iter:
-            is_large |= state["radii"] > self.grow_scale2d
         is_split = is_grad_high & is_large
+        if step < self.refine_scale2d_stop_iter:
+            is_split |= state["radii"] > self.grow_scale2d
         n_split = is_split.sum().item()

         # first duplicate
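
To make the new gate in step_post_backward concrete, here is a standalone arithmetic sketch of the step % reset_every >= pause_refine_after_reset condition; the constants are illustrative, not taken from the commit:

# Standalone illustration of the refine gate added above.
refine_start_iter = 500
refine_every = 100
reset_every = 3000
pause_refine_after_reset = 300  # e.g. the number of training images

def refine_allowed(step: int) -> bool:
    # Mirrors the condition in DefaultStrategy.step_post_backward.
    return (
        step > refine_start_iter
        and step % refine_every == 0
        and step % reset_every >= pause_refine_after_reset
    )

assert refine_allowed(2900)       # 2900 % 3000 = 2900, well past the last reset
assert not refine_allowed(3100)   # 3100 % 3000 = 100 < 300: too soon after the reset at 3000
assert not refine_allowed(3200)   # 3200 % 3000 = 200 < 300: still paused
assert refine_allowed(3300)       # 3300 % 3000 = 300 >= 300: refinement resumes

The last hunk is the refine_scale2d fix referenced in the title: the 2D-radius test now feeds is_split directly instead of is_large, so GSs that appear large in screen space are split even when their accumulated gradient is below the threshold.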
