Skip to content

Commit

Permalink
Merge branch 'main' into brent/fix_tensor_dataclass_nontensor
Browse files Browse the repository at this point in the history
  • Loading branch information
kerrj authored Jan 18, 2024
2 parents 02dc47c + a8d5dc9 commit 3bc6c29
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 59 deletions.
22 changes: 22 additions & 0 deletions docs/nerfology/methods/splat.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
# Gaussian Splatting
<h4>Real-Time Radiance Field Rendering</h4>


```{button-link} https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/
:color: primary
:outline:
Paper Website
```

[3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) was proposed in SIGGRAPH 2023 from INRIA, and is a completely different method of representing radiance fields by explicitly storing a collection of 3D volumetric gaussians. These can be "splatted", or projected, onto a 2D image provided a camera pose, and rasterized to obtain per-pixel colors. Because rasterization is very fast on GPUs, this method can render much faster than neural representations of radiance fields.

### Installation

```{button-link} https://docs.gsplat.studio/
:color: primary
:outline:
GSplat
```

Nerfstudio uses [gsplat](https://github.com/nerfstudio-project/gsplat) as its gaussian rasterization backend, an in-house re-implementation which is designed to be more developer friendly. This can be installed with `pip install gsplat`. The associated CUDA code will be compiled the first time gaussian splatting is executed. Some users with PyTorch 2.0 have experienced issues with this, which can be resolved by either installing gsplat from source, or upgrading torch to 2.1.

### Data
Expand All @@ -13,6 +29,12 @@ Because gaussian splatting trains on *full images* instead of bundles of rays, t
### Running the Method
To run gaussian splatting, run `ns-train gaussian-splatting --data <data>`. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported.

#### Quality and Regularization
The default settings provided maintain a balance between speed, quality, and splat file size, but if you care more about quality than training speed or size, you can decrease the alpha cull threshold
(threshold to delete translucent gaussians) and disable culling after 15k steps like so: `ns-train gaussian-splatting --pipeline.model.cull_scale_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data <data>`

A common artifact in splatting is long, spikey gaussians. [PhysGaussian](https://xpandora.github.io/PhysGaussian/) proposes an example of a scale-regularizer that encourages gaussians to be more evenly shaped. To enable this, set the `use_scale_regularization` flag to `True`.

### Details
For more details on the method, see the [original paper](https://arxiv.org/abs/2308.04079). Additionally, for a detailed derivation of the gradients used in the gsplat library, see [here](https://arxiv.org/abs/2312.02121).

Expand Down
6 changes: 4 additions & 2 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
"""Specifies the image indices to use during eval; if None, uses all."""
cache_images: Literal["cpu", "gpu"] = "cpu"
"""Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device."""
cache_images_type: Literal["uint8", "float32"] = "float32"
"""The image type returned from manager, caching images in uint8 saves memory"""


class FullImageDatamanager(DataManager, Generic[TDataset]):
Expand Down Expand Up @@ -126,7 +128,7 @@ def cache_images(self, cache_images_option):
CONSOLE.log("Caching / undistorting train images")
for i in tqdm(range(len(self.train_dataset)), leave=False):
# cv2.undistort the images / cameras
data = self.train_dataset.get_data(i)
data = self.train_dataset.get_data(i, image_type=self.config.cache_images_type)
camera = self.train_dataset.cameras[i].reshape(())
K = camera.get_intrinsics_matrices().numpy()
if camera.distortion_params is None:
Expand Down Expand Up @@ -201,7 +203,7 @@ def cache_images(self, cache_images_option):
CONSOLE.log("Caching / undistorting eval images")
for i in tqdm(range(len(self.eval_dataset)), leave=False):
# cv2.undistort the images / cameras
data = self.eval_dataset.get_data(i)
data = self.eval_dataset.get_data(i, image_type=self.config.cache_images_type)
camera = self.eval_dataset.cameras[i].reshape(())
K = camera.get_intrinsics_matrices().numpy()
if camera.distortion_params is None:
Expand Down
39 changes: 33 additions & 6 deletions nerfstudio/data/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

from copy import deepcopy
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Literal

import numpy as np
import numpy.typing as npt
import torch
from jaxtyping import Float
from jaxtyping import Float, UInt8
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
Expand Down Expand Up @@ -77,24 +77,51 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]:
assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct."
return image

def get_image(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]:
"""Returns a 3 channel image.
def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]:
"""Returns a 3 channel image in float32 torch.Tensor.
Args:
image_idx: The image index in the dataset.
"""
image = torch.from_numpy(self.get_numpy_image(image_idx).astype("float32") / 255.0)
if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4:
assert (self._dataparser_outputs.alpha_color >= 0).all() and (
self._dataparser_outputs.alpha_color <= 1
).all(), "alpha color given is out of range between [0, 1]."
image = image[:, :, :3] * image[:, :, -1:] + self._dataparser_outputs.alpha_color * (1.0 - image[:, :, -1:])
return image

def get_data(self, image_idx: int) -> Dict:
def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_width num_channels"]:
"""Returns a 3 channel image in uint8 torch.Tensor.
Args:
image_idx: The image index in the dataset.
"""
image = torch.from_numpy(self.get_numpy_image(image_idx))
if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4:
assert (self._dataparser_outputs.alpha_color >= 0).all() and (
self._dataparser_outputs.alpha_color <= 1
).all(), "alpha color given is out of range between [0, 1]."
image = image[:, :, :3] * image[:, :, -1:] / 255.0 + 255.0 * self._dataparser_outputs.alpha_color * (
1.0 - image[:, :, -1:] / 255.0
)
image = torch.clamp(image, min=0, max=255).to(torch.uint8)
return image

def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "float32") -> Dict:
"""Returns the ImageDataset data as a dictionary.
Args:
image_idx: The image index in the dataset.
image_type: the type of images returned
"""
image = self.get_image(image_idx)
if image_type == "float32":
image = self.get_image_float32(image_idx)
elif image_type == "uint8":
image = self.get_image_uint8(image_idx)
else:
raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32")

data = {"image_idx": image_idx, "image": image}
if self._dataparser_outputs.mask_filenames is not None:
mask_filepath = self._dataparser_outputs.mask_filenames[image_idx]
Expand Down
50 changes: 42 additions & 8 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import random
import warnings
from dataclasses import dataclass, field
from typing import Dict, Optional, Type, Union

Expand All @@ -42,6 +43,10 @@ class PixelSamplerConfig(InstantiateConfig):
"""List of whether or not camera i is equirectangular."""
fisheye_crop_radius: Optional[float] = None
"""Set to the radius (in pixels) for fisheye cameras."""
rejection_sample_mask: bool = True
"""Whether or not to use rejection sampling when sampling images with masks"""
max_num_iterations: int = 100
"""If rejection sampling masks, the maximum number of times to sample"""


class PixelSampler:
Expand Down Expand Up @@ -88,15 +93,44 @@ def sample_method(
num_images: number of images to sample over
mask: mask of possible pixels in an image to sample from.
"""
indices = (
torch.rand((batch_size, 3), device=device)
* torch.tensor([num_images, image_height, image_width], device=device)
).long()

if isinstance(mask, torch.Tensor):
nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]
else:
indices = (
torch.rand((batch_size, 3), device=device)
* torch.tensor([num_images, image_height, image_width], device=device)
).long()
if self.config.rejection_sample_mask:
num_valid = 0
for _ in range(self.config.max_num_iterations):
c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
chosen_indices_validity = mask[..., 0][c, y, x].bool()
num_valid = int(torch.sum(chosen_indices_validity).item())
if num_valid == batch_size:
break
else:
replacement_indices = (
torch.rand((batch_size - num_valid, 3), device=device)
* torch.tensor([num_images, image_height, image_width], device=device)
).long()
indices[~chosen_indices_validity] = replacement_indices

if num_valid != batch_size:
warnings.warn(
"""
Masked sampling failed, mask is either empty or mostly empty.
Reverting behavior to non-rejection sampling. Consider setting
pipeline.datamanager.pixel-sampler.rejection-sample-mask to False
or increasing pipeline.datamanager.pixel-sampler.max-num-iterations
"""
)
self.config.rejection_sample_mask = False
nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]
else:
nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]

return indices

Expand Down
78 changes: 37 additions & 41 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
import numpy as np
import torch
from gsplat._torch_impl import quat_to_rotmat
from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.sh import SphericalHarmonics, num_sh_bases
from gsplat.project_gaussians import project_gaussians
from gsplat.rasterize import rasterize_gaussians
from gsplat.sh import num_sh_bases, spherical_harmonics
from pytorch_msssim import SSIM
from torch.nn import Parameter

Expand Down Expand Up @@ -324,7 +323,8 @@ def after_train(self, step: int):
with torch.no_grad():
# keep track of a moving average of grad norms
visible_mask = (self.radii > 0).flatten()
grads = self.xys.grad.detach().norm(dim=-1) # TODO fill in
assert self.xys.grad is not None
grads = self.xys.grad.detach().norm(dim=-1)
# print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}")
if self.xys_grad_norm is None:
self.xys_grad_norm = grads
Expand Down Expand Up @@ -629,13 +629,13 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
cy = camera.cy.item()
fovx = 2 * math.atan(camera.width / (2 * camera.fx))
fovy = 2 * math.atan(camera.height / (2 * camera.fy))
W, H = camera.width.item(), camera.height.item()
W, H = int(camera.width.item()), int(camera.height.item())
self.last_size = (H, W)
projmat = projection_matrix(0.001, 1000, fovx, fovy, device=self.device)
BLOCK_X, BLOCK_Y = 16, 16
tile_bounds = (
(W + BLOCK_X - 1) // BLOCK_X,
(H + BLOCK_Y - 1) // BLOCK_Y,
int((W + BLOCK_X - 1) // BLOCK_X),
int((H + BLOCK_Y - 1) // BLOCK_Y),
1,
)

Expand All @@ -656,7 +656,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)

self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = ProjectGaussians.apply(
self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = project_gaussians( # type: ignore
means_crop,
torch.exp(scales_crop),
1,
Expand All @@ -682,67 +682,75 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
rgbs = SphericalHarmonics.apply(n, viewdirs, colors_crop)
rgbs = spherical_harmonics(n, viewdirs, colors_crop)
rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore
else:
rgbs = torch.sigmoid(colors_crop[:, 0, :])

# rescale the camera back to original dimensions
camera.rescale_output_resolution(camera_downscale)

# avoid empty rasterization
num_intersects, _ = compute_cumulative_intersects(self.xys.size(0), num_tiles_hit)
assert num_intersects > 0
assert (num_tiles_hit > 0).any() # type: ignore

rgb = RasterizeGaussians.apply(
rgb = rasterize_gaussians( # type: ignore
self.xys,
depths,
self.radii,
conics,
num_tiles_hit,
num_tiles_hit, # type: ignore
rgbs,
torch.sigmoid(opacities_crop),
H,
W,
background,
background=background,
) # type: ignore
rgb = torch.clamp(rgb, max=1.0) # type: ignore
depth_im = None
if not self.training:
depth_im = RasterizeGaussians.apply( # type: ignore
depth_im = rasterize_gaussians( # type: ignore
self.xys,
depths,
self.radii,
conics,
num_tiles_hit,
num_tiles_hit, # type: ignore
depths[:, None].repeat(1, 3),
torch.sigmoid(opacities_crop),
H,
W,
torch.ones(3, device=self.device) * 10,
background=torch.ones(3, device=self.device) * 10,
)[..., 0:1] # type: ignore

return {"rgb": rgb, "depth": depth_im} # type: ignore

def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
"""Compute and returns metrics.
def get_gt_img(self, image: torch.Tensor):
"""Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
Args:
outputs: the output to compute loss dict to
batch: ground truth batch corresponding to outputs
image: tensor.Tensor in type uint8 or float32
"""
if image.dtype == torch.uint8:
image = image.float() / 255.0
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
newsize = [image.shape[0] // d, image.shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
gt_img = TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
gt_img = image
return gt_img.to(self.device)

def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
"""Compute and returns metrics.
Args:
outputs: the output to compute loss dict to
batch: ground truth batch corresponding to outputs
"""
gt_rgb = self.get_gt_img(batch["image"])
metrics_dict = {}
gt_rgb = gt_img.to(self.device) # RGB or RGBA image
predicted_rgb = outputs["rgb"]
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)

Expand All @@ -758,16 +766,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
batch: ground truth batch corresponding to outputs
metrics_dict: dictionary of metrics, some of which we can use for loss
"""
d = self._get_downscale_factor()
if d > 1:
newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
gt_img = self.get_gt_img(batch["image"])
Ll1 = torch.abs(gt_img - outputs["rgb"]).mean()
simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], outputs["rgb"].permute(2, 0, 1)[None, ...])
if self.config.use_scale_regularization and self.step % 10 == 0:
Expand Down Expand Up @@ -814,20 +813,17 @@ def get_image_metrics_and_images(
Returns:
A dictionary of metrics.
"""
gt_rgb = self.get_gt_img(batch["image"])
d = self._get_downscale_factor()
if d > 1:
# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
else:
gt_img = batch["image"]
predicted_rgb = outputs["rgb"]

gt_rgb = gt_img.to(self.device)

combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)

# Switch images from [H, W, C] to [1, C, H, W] for metrics computations
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _render_trajectory_video(
if render_nearest_camera:
assert train_dataset is not None
assert train_cameras is not None
img = train_dataset.get_image(max_idx)
img = train_dataset.get_image_float32(max_idx)
height = cameras.image_height[0]
# maintain the resolution of the img to calculate the width from the height
width = int(img.shape[1] * (height / img.shape[0]))
Expand Down
Loading

0 comments on commit 3bc6c29

Please sign in to comment.