Skip to content

Commit

Permalink
vdb in examples
Browse files Browse the repository at this point in the history
  • Loading branch information
liruilong940607 committed Feb 7, 2024
1 parent d646277 commit ef88bca
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 54 deletions.
35 changes: 27 additions & 8 deletions examples/train_ngp_nerf_occ.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def run(args):
max_steps = 20000
init_batch_size = 1024
target_sample_batch_size = 1 << 18
weight_decay = (
1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
)
weight_decay = 1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6
# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 0.0
Expand Down Expand Up @@ -95,9 +93,26 @@ def run(args):
**test_dataset_kwargs,
)

estimator = OccGridEstimator(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)
if args.vdb:
from fvdb import sparse_grid_from_dense

from nerfacc.estimators.vdb import VDBEstimator

assert grid_nlvl == 1, "VDBEstimator only supports grid_nlvl=1"
voxel_sizes = (aabb[3:] - aabb[:3]) / grid_resolution
origins = aabb[:3] + voxel_sizes / 2
grid = sparse_grid_from_dense(
1,
(grid_resolution, grid_resolution, grid_resolution),
voxel_sizes=voxel_sizes,
origins=origins,
)
estimator = VDBEstimator(grid).to(device)
estimator.aabbs = [aabb]
else:
estimator = OccGridEstimator(
roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl
).to(device)

# setup the radiance field we want to train.
grad_scaler = torch.cuda.amp.GradScaler(2**10)
Expand Down Expand Up @@ -171,8 +186,7 @@ def occ_eval_fn(x):
# dynamic batch size for rays to keep sample batch size constant.
num_rays = len(pixels)
num_rays = int(
num_rays
* (target_sample_batch_size / float(n_rendering_samples))
num_rays * (target_sample_batch_size / float(n_rendering_samples))
)
train_dataset.update_num_rays(num_rays)

Expand Down Expand Up @@ -278,6 +292,11 @@ def occ_eval_fn(x):
choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES,
help="which scene to use",
)
parser.add_argument(
"--vdb",
action="store_true",
help="use VDBEstimator instead of OccGridEstimator",
)
args = parser.parse_args()

run(args)
83 changes: 37 additions & 46 deletions examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,54 +73,55 @@ def render_image_with_occgrid(
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
else:
num_rays, _ = rays_shape

results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)

rays_o = chunk_rays.origins
rays_d = chunk_rays.viewdirs

def sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
sigmas = radiance_field.query_density(positions, t)
if t_starts.shape[0] == 0:
sigmas = torch.empty((0, 1), device=t_starts.device)
else:
sigmas = radiance_field.query_density(positions)
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
sigmas = radiance_field.query_density(positions, t)
else:
sigmas = radiance_field.query_density(positions)
return sigmas.squeeze(-1)

def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
if t_starts.shape[0] == 0:
rgbs = torch.empty((0, 3), device=t_starts.device)
sigmas = torch.empty((0, 1), device=t_starts.device)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
t_origins = rays_o[ray_indices]
t_dirs = rays_d[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
if timestamps is not None:
# dnerf
t = (
timestamps[ray_indices]
if radiance_field.training
else timestamps.expand_as(positions[:, :1])
)
rgbs, sigmas = radiance_field(positions, t, t_dirs)
else:
rgbs, sigmas = radiance_field(positions, t_dirs)
return rgbs, sigmas.squeeze(-1)

ray_indices, t_starts, t_ends = estimator.sampling(
Expand Down Expand Up @@ -180,9 +181,7 @@ def render_image_with_propnet(
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
else:
num_rays, _ = rays_shape

Expand All @@ -207,11 +206,7 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):
return rgb, sigmas.squeeze(-1)

results = []
chunk = (
torch.iinfo(torch.int32).max
if radiance_field.training
else test_chunk_size
)
chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
t_starts, t_ends = estimator.sampling(
Expand Down Expand Up @@ -276,18 +271,14 @@ def render_image_with_occgrid_test(
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(
lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
)
rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
else:
num_rays, _ = rays_shape

def rgb_sigma_fn(t_starts, t_ends, ray_indices):
t_origins = rays.origins[ray_indices]
t_dirs = rays.viewdirs[ray_indices]
positions = (
t_origins + t_dirs * (t_starts[:, None] + t_ends[:, None]) / 2.0
)
positions = t_origins + t_dirs * (t_starts[:, None] + t_ends[:, None]) / 2.0
if timestamps is not None:
# dnerf
t = (
Expand Down

0 comments on commit ef88bca

Please sign in to comment.