From ef88bca7ef6c288c99a5225af6e8aea8470a41da Mon Sep 17 00:00:00 2001 From: Ruilong Li Date: Wed, 7 Feb 2024 21:29:04 +0000 Subject: [PATCH] vdb in examples --- examples/train_ngp_nerf_occ.py | 35 ++++++++++---- examples/utils.py | 83 +++++++++++++++------------------- 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/examples/train_ngp_nerf_occ.py b/examples/train_ngp_nerf_occ.py index 026c9e3..c2c02fb 100644 --- a/examples/train_ngp_nerf_occ.py +++ b/examples/train_ngp_nerf_occ.py @@ -59,9 +59,7 @@ def run(args): max_steps = 20000 init_batch_size = 1024 target_sample_batch_size = 1 << 18 - weight_decay = ( - 1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6 - ) + weight_decay = 1e-5 if args.scene in ["materials", "ficus", "drums"] else 1e-6 # scene parameters aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device) near_plane = 0.0 @@ -95,9 +93,26 @@ def run(args): **test_dataset_kwargs, ) - estimator = OccGridEstimator( - roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl - ).to(device) + if args.vdb: + from fvdb import sparse_grid_from_dense + + from nerfacc.estimators.vdb import VDBEstimator + + assert grid_nlvl == 1, "VDBEstimator only supports grid_nlvl=1" + voxel_sizes = (aabb[3:] - aabb[:3]) / grid_resolution + origins = aabb[:3] + voxel_sizes / 2 + grid = sparse_grid_from_dense( + 1, + (grid_resolution, grid_resolution, grid_resolution), + voxel_sizes=voxel_sizes, + origins=origins, + ) + estimator = VDBEstimator(grid).to(device) + estimator.aabbs = [aabb] + else: + estimator = OccGridEstimator( + roi_aabb=aabb, resolution=grid_resolution, levels=grid_nlvl + ).to(device) # setup the radiance field we want to train. grad_scaler = torch.cuda.amp.GradScaler(2**10) @@ -171,8 +186,7 @@ def occ_eval_fn(x): # dynamic batch size for rays to keep sample batch size constant. num_rays = len(pixels) num_rays = int( - num_rays - * (target_sample_batch_size / float(n_rendering_samples)) + num_rays * (target_sample_batch_size / float(n_rendering_samples)) ) train_dataset.update_num_rays(num_rays) @@ -278,6 +292,11 @@ def occ_eval_fn(x): choices=NERF_SYNTHETIC_SCENES + MIPNERF360_UNBOUNDED_SCENES, help="which scene to use", ) + parser.add_argument( + "--vdb", + action="store_true", + help="use VDBEstimator instead of OccGridEstimator", + ) args = parser.parse_args() run(args) diff --git a/examples/utils.py b/examples/utils.py index d3979f4..caecf8a 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -73,18 +73,12 @@ def render_image_with_occgrid( if len(rays_shape) == 3: height, width, _ = rays_shape num_rays = height * width - rays = namedtuple_map( - lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays - ) + rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays) else: num_rays, _ = rays_shape results = [] - chunk = ( - torch.iinfo(torch.int32).max - if radiance_field.training - else test_chunk_size - ) + chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size for i in range(0, num_rays, chunk): chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) @@ -92,35 +86,42 @@ def render_image_with_occgrid( rays_d = chunk_rays.viewdirs def sigma_fn(t_starts, t_ends, ray_indices): - t_origins = rays_o[ray_indices] - t_dirs = rays_d[ray_indices] - positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 - if timestamps is not None: - # dnerf - t = ( - timestamps[ray_indices] - if radiance_field.training - else timestamps.expand_as(positions[:, :1]) - ) - sigmas = radiance_field.query_density(positions, t) + if t_starts.shape[0] == 0: + sigmas = torch.empty((0, 1), device=t_starts.device) else: - sigmas = radiance_field.query_density(positions) + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + if timestamps is not None: + # dnerf + t = ( + timestamps[ray_indices] + if radiance_field.training + else timestamps.expand_as(positions[:, :1]) + ) + sigmas = radiance_field.query_density(positions, t) + else: + sigmas = radiance_field.query_density(positions) return sigmas.squeeze(-1) def rgb_sigma_fn(t_starts, t_ends, ray_indices): - t_origins = rays_o[ray_indices] - t_dirs = rays_d[ray_indices] - positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 - if timestamps is not None: - # dnerf - t = ( - timestamps[ray_indices] - if radiance_field.training - else timestamps.expand_as(positions[:, :1]) - ) - rgbs, sigmas = radiance_field(positions, t, t_dirs) + if t_starts.shape[0] == 0: + rgbs = torch.empty((0, 3), device=t_starts.device) + sigmas = torch.empty((0, 1), device=t_starts.device) else: - rgbs, sigmas = radiance_field(positions, t_dirs) + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + if timestamps is not None: + # dnerf + t = ( + timestamps[ray_indices] + if radiance_field.training + else timestamps.expand_as(positions[:, :1]) + ) + rgbs, sigmas = radiance_field(positions, t, t_dirs) + else: + rgbs, sigmas = radiance_field(positions, t_dirs) return rgbs, sigmas.squeeze(-1) ray_indices, t_starts, t_ends = estimator.sampling( @@ -180,9 +181,7 @@ def render_image_with_propnet( if len(rays_shape) == 3: height, width, _ = rays_shape num_rays = height * width - rays = namedtuple_map( - lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays - ) + rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays) else: num_rays, _ = rays_shape @@ -207,11 +206,7 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): return rgb, sigmas.squeeze(-1) results = [] - chunk = ( - torch.iinfo(torch.int32).max - if radiance_field.training - else test_chunk_size - ) + chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size for i in range(0, num_rays, chunk): chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) t_starts, t_ends = estimator.sampling( @@ -276,18 +271,14 @@ def render_image_with_occgrid_test( if len(rays_shape) == 3: height, width, _ = rays_shape num_rays = height * width - rays = namedtuple_map( - lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays - ) + rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays) else: num_rays, _ = rays_shape def rgb_sigma_fn(t_starts, t_ends, ray_indices): t_origins = rays.origins[ray_indices] t_dirs = rays.viewdirs[ray_indices] - positions = ( - t_origins + t_dirs * (t_starts[:, None] + t_ends[:, None]) / 2.0 - ) + positions = t_origins + t_dirs * (t_starts[:, None] + t_ends[:, None]) / 2.0 if timestamps is not None: # dnerf t = (