diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index ef6238aaf5..c24bb42803 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -178,22 +178,50 @@ def sample_method_fisheye( if isinstance(mask, torch.Tensor) and not self.config.ignore_mask: indices = self.sample_method(batch_size, num_images, image_height, image_width, mask=mask, device=device) else: - rand_samples = torch.rand((batch_size, 3), device=device) - # convert random samples tto radius and theta - radii = self.config.fisheye_crop_radius * torch.sqrt(rand_samples[:, 1]) - theta = 2.0 * torch.pi * rand_samples[:, 2] - - # convert radius and theta to x and y between -radii and radii - x = radii * torch.cos(theta) - y = radii * torch.sin(theta) + # Rejection sampling. + valid: Optional[torch.Tensor] = None + indices = None + while True: + samples_needed = batch_size if valid is None else int(batch_size - torch.sum(valid).item()) + + # Check if done! + if samples_needed == 0: + break + + rand_samples = torch.rand((samples_needed, 2), device=device) + # Convert random samples to radius and theta. + radii = self.config.fisheye_crop_radius * torch.sqrt(rand_samples[:, 0]) + theta = 2.0 * torch.pi * rand_samples[:, 1] + + # Convert radius and theta to x and y. + x = (radii * torch.cos(theta) + image_width // 2).long() + y = (radii * torch.sin(theta) + image_height // 2).long() + sampled_indices = torch.stack( + [torch.randint(0, num_images, size=(samples_needed,), device=device), y, x], dim=-1 + ) - # Multiply by the batch size and height/width to get pixel indices. - indices = torch.floor( - torch.stack([rand_samples[:, 0], y, x], dim=1) - * torch.tensor([num_images, image_height // 2, image_width // 2], device=device) - + torch.tensor([0, image_height // 2, image_width // 2], device=device) - ).long() + # Update indices. + if valid is None: + indices = sampled_indices + valid = ( + (sampled_indices[:, 1] >= 0) + & (sampled_indices[:, 1] < image_height) + & (sampled_indices[:, 2] >= 0) + & (sampled_indices[:, 2] < image_width) + ) + else: + assert indices is not None + not_valid = ~valid + indices[not_valid, :] = sampled_indices + valid[not_valid] = ( + (sampled_indices[:, 1] >= 0) + & (sampled_indices[:, 1] < image_height) + & (sampled_indices[:, 2] >= 0) + & (sampled_indices[:, 2] < image_width) + ) + assert indices is not None + assert indices.shape == (batch_size, 3) return indices def collate_image_dataset_batch(self, batch: Dict, num_rays_per_batch: int, keep_full_image: bool = False):