Speed up mask sampling with rejection sampling #2585

Merged
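Summary of the change: previously, when a mask was supplied, `sample_method` enumerated every valid pixel with `torch.nonzero` and then drew from that list, which is expensive for large masks. This PR instead draws uniform random pixel indices and re-draws only the ones that land outside the mask (rejection sampling), reverting to the old behavior if the mask turns out to be too sparse. A minimal standalone sketch of the idea, not the PR's exact code (the helper name `sample_in_mask` is illustrative):

import torch

def sample_in_mask(mask: torch.Tensor, batch_size: int, max_iters: int = 100) -> torch.Tensor:
    """Illustrative rejection sampler: draw (y, x) pixels until all fall inside `mask`.

    mask: (H, W) boolean tensor, True where sampling is allowed.
    Returns a (batch_size, 2) tensor of pixel indices; a few entries may still be
    invalid if `max_iters` is exhausted (the PR then falls back to nonzero-based sampling).
    """
    height, width = mask.shape
    # Initial uniform draw over the full image.
    indices = (torch.rand((batch_size, 2)) * torch.tensor([height, width])).long()
    for _ in range(max_iters):
        valid = mask[indices[:, 0], indices[:, 1]]
        if bool(valid.all()):
            break
        # Re-draw only the samples that landed outside the mask.
        num_invalid = int((~valid).sum())
        indices[~valid] = (torch.rand((num_invalid, 2)) * torch.tensor([height, width])).long()
    return indices

With a mostly-valid mask this converges in a handful of iterations, whereas `torch.nonzero(mask)` has to materialize every valid coordinate on every call.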
47 changes: 39 additions & 8 deletions nerfstudio/data/pixel_samplers.py
@@ -18,6 +18,7 @@

 import random
 from dataclasses import dataclass, field
+import warnings
 from typing import (
     Dict,
     Optional,
@@ -49,6 +50,10 @@ class PixelSamplerConfig(InstantiateConfig):
"""List of whether or not camera i is equirectangular."""
fisheye_crop_radius: Optional[float] = None
"""Set to the radius (in pixels) for fisheye cameras."""
rejection_sample_mask: bool = True
"""Whether or not to use rejection sampling when sampling images with masks"""
max_num_iterations: int = 100
"""If rejection sampling masks, the maximum number of times to sample"""


class PixelSampler:
@@ -95,15 +100,41 @@ def sample_method(
             num_images: number of images to sample over
             mask: mask of possible pixels in an image to sample from.
         """
+        indices = (
+            torch.rand((batch_size, 3), device=device)
+            * torch.tensor([num_images, image_height, image_width], device=device)
+        ).long()
+
         if isinstance(mask, torch.Tensor):
-            nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
-            chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
-            indices = nonzero_indices[chosen_indices]
-        else:
-            indices = (
-                torch.rand((batch_size, 3), device=device)
-                * torch.tensor([num_images, image_height, image_width], device=device)
-            ).long()
+            if self.config.rejection_sample_mask:
+                num_valid = 0
+                for _ in range(self.config.max_num_iterations):
+                    c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
+                    chosen_indices_validity = mask[..., 0][c, y, x].bool()
+                    num_valid = int(torch.sum(chosen_indices_validity).item())
+                    if num_valid == batch_size:
+                        break
+                    else:
+                        replacement_indices = (
+                            torch.rand((batch_size - num_valid, 3), device=device)
+                            * torch.tensor([num_images, image_height, image_width], device=device)
+                        ).long()
+                        indices[~chosen_indices_validity] = replacement_indices
+
+                if num_valid != batch_size:
Contributor (review comment):
Instead of just raising a warning, I think it would make sense to default back to the slow non-rejection sampling if this occurs. I would still issue a warning, but it would suck for training to fail on the off chance that not enough valid indices are generated in time.

+                    warnings.warn(
+                        """
+                        Masked sampling failed, mask is either empty or mostly empty.
+                        Reverting behavior to non-rejection sampling. Consider setting
+                        pipeline.datamanager.pixel-sampler.rejection-sample-mask to False
+                        or increasing pipeline.datamanager.pixel-sampler.max-num-iterations
+                        """
+                    )
+                    self.config.rejection_sample_mask = False
+            else:
+                nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
+                chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
+                indices = nonzero_indices[chosen_indices]

         return indices
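
A back-of-the-envelope check on the 100-iteration cap: if a fraction p of pixels is valid, a draw that is still invalid survives each round with probability (1 - p), so the expected number of still-invalid samples after k rounds is batch_size * (1 - p)^k. Solving for the round at which that drops below one gives k ≈ log(batch_size) / -log(1 - p). A small sketch of that estimate (the function name is illustrative):

import math

def expected_rounds(valid_fraction: float, batch_size: int) -> float:
    # Round at which the expected count of still-invalid samples falls below one:
    # batch_size * (1 - valid_fraction)**k < 1  =>  k > log(batch_size) / -log(1 - valid_fraction)
    return math.log(batch_size) / -math.log(1.0 - valid_fraction)

print(expected_rounds(0.5, 4096))   # ~12 rounds: well under max_num_iterations = 100
print(expected_rounds(0.01, 4096))  # ~828 rounds: a mostly-empty mask, so the warning fires

The second case is the mostly-empty-mask situation the Contributor's comment is about; setting `self.config.rejection_sample_mask = False` in the diff above is the requested fallback to nonzero-based sampling.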
