diff --git a/docs/extensions/blender_addon.md b/docs/extensions/blender_addon.md
index 677f8b005c..5b24d72335 100644
--- a/docs/extensions/blender_addon.md
+++ b/docs/extensions/blender_addon.md
@@ -6,7 +6,7 @@
 ## Overview
 
-This Blender add-on allows for compositing with a Nerfstudio render as a background layer by generating a camera path JSON file from the Blender camera path, as well as a way to import Nerfstudio JSON files as a Blender camera baked with the Nerfstudio camera path. This add-on also allows compositing multiple NeRF objects into a NeRF scene. This is achieved by importing a mesh or point-cloud representation of the NeRF scene from Nerfstudio into Blender and getting the camera coordinates relative to the transformations of the NeRF representation. Dynamic FOV from the Blender camera is supported and will match the Nerfstudio render. Perspective, equirectangular, VR180, and omnidirectional stereo (VR 360) cameras are supported. This add-on also supports Gaussian Splatting scenes as well, however equirectangular and VR video rendering is not currently supported.
+This Blender add-on allows for compositing with a Nerfstudio render as a background layer by generating a camera path JSON file from the Blender camera path, as well as a way to import Nerfstudio JSON files as a Blender camera baked with the Nerfstudio camera path. This add-on also allows compositing multiple NeRF objects into a NeRF scene. This is achieved by importing a mesh or point-cloud representation of the NeRF scene from Nerfstudio into Blender and getting the camera coordinates relative to the transformations of the NeRF representation. Dynamic FOV from the Blender camera is supported and will match the Nerfstudio render. Perspective, equirectangular, VR180, and omnidirectional stereo (VR 360) cameras are supported. This add-on also supports Gaussian Splatting scenes; however, equirectangular and VR video rendering is not currently supported for splats.
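Since the add-on's main output is a Nerfstudio camera path JSON, it helps to know roughly what that file contains. A minimal sketch of writing one from Python follows; the key names mirror the format the Nerfstudio viewer exports, but the single keyframe and all values here are illustrative placeholders, not actual add-on output:

```python
# Minimal sketch of a Nerfstudio-style camera path JSON.
# Key names follow the viewer's export format; values are placeholders.
import json

camera_path = {
    "camera_type": "perspective",  # the add-on also emits equirectangular / VR180 / ODS paths
    "render_height": 1080,
    "render_width": 1920,
    "camera_path": [
        {
            # 4x4 camera-to-world matrix, flattened row-major (identity here)
            "camera_to_world": [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            "fov": 50.0,  # per-keyframe FOV in degrees, so dynamic FOV can be matched
            "aspect": 1.777,
        }
    ],
    "seconds": 4.0,
    "fps": 24,
}

with open("camera_path.json", "w") as f:
    json.dump(camera_path, f, indent=2)
```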
 [image]
@@ -109,6 +109,7 @@ This Blender add-on allows for compositing with a Nerfstudio render as a backgro
 [image]
 - Fisheye and orthographic cameras are not supported.
+- Renders with Gaussian Splats are supported, but the point cloud or mesh representation must be generated by training a NeRF on the same dataset.
 - A walkthrough of this section is included in the tutorial video.
 
 ## Create Blender Camera from Nerfstudio JSON Camera Path
diff --git a/docs/nerfology/methods/splat.md b/docs/nerfology/methods/splat.md
index 802dbe8619..8e613e251b 100644
--- a/docs/nerfology/methods/splat.md
+++ b/docs/nerfology/methods/splat.md
@@ -1,29 +1,55 @@
 # Gaussian Splatting
-[3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) was proposed in SIGGRAPH 2023 from INRIA, and is a completely
-different method of representing radiance fields by explicitly storing a collection of 3D volumetric gaussians. These can be "splatted", or projected, onto a 2D image
-provided a camera pose, and rasterized to obtain per-pixel colors. Because rasterization is very fast on GPUs, this method can render much faster than neural representations
-of radiance fields.
+<h4>Real-Time Radiance Field Rendering</h4>
+
+
+```{button-link} https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/
+:color: primary
+:outline:
+Paper Website
+```
+
+[3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) was proposed in SIGGRAPH 2023 from INRIA, and is a completely different method of representing radiance fields by explicitly storing a collection of 3D volumetric gaussians. These can be "splatted", or projected, onto a 2D image provided a camera pose, and rasterized to obtain per-pixel colors. Because rasterization is very fast on GPUs, this method can render much faster than neural representations of radiance fields.
 
 ### Installation
-Nerfstudio uses [gsplat](https://github.com/nerfstudio-project/gsplat) as its gaussian rasterization backend, an in-house re-implementation which is meant to be more developer friendly. This can be installed with `pip install gsplat`. The associated CUDA code will be compiled the first time gaussian splatting is executed. Some users with PyTorch 2.0 have experienced issues with this, which can be resolved by either installing gsplat from source, or upgrading torch to 2.1.
+
+```{button-link} https://docs.gsplat.studio/
+:color: primary
+:outline:
+GSplat
+```
+
+Nerfstudio uses [gsplat](https://github.com/nerfstudio-project/gsplat) as its gaussian rasterization backend, an in-house re-implementation which is designed to be more developer friendly. This can be installed with `pip install gsplat`. The associated CUDA code will be compiled the first time gaussian splatting is executed. Some users with PyTorch 2.0 have experienced issues with this, which can be resolved by either installing gsplat from source or upgrading torch to 2.1.
 
 ### Data
-Gaussian Splatting works much better if you initialize it from pre-existing geometry, such as SfM points rom COLMAP. COLMAP datasets or datasets from `ns-process-data` will automatically save these points and initialize gaussians on them. Other datasets currently do not support initialization, and will initialize gaussians randomly. Initializing from other data inputs (i.e. depth from phone app scanners) may be supported in the future.
+Gaussian Splatting works much better if you initialize it from pre-existing geometry, such as SfM points from COLMAP. COLMAP datasets or datasets from `ns-process-data` will automatically save these points and initialize gaussians on them. Other datasets currently do not support initialization, and will initialize gaussians randomly. Initializing from other data inputs (e.g. depth from phone app scanners) may be supported in the future.
 
 Because gaussian splatting trains on *full images* instead of bundles of rays, there is a new datamanager in `full_images_datamanager.py` which undistorts input images, caches them, and provides single images at each train step.
 
 ### Running the Method
-To run gaussian splatting, run `ns-train gaussian-splatting --data <data>`. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, rendered, and exported.
+To run gaussian splatting, run `ns-train gaussian-splatting --data <data>`. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported.
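The checkpoint loading mentioned above can also be done from Python; a minimal sketch, assuming Nerfstudio's standard `eval_setup` helper and a placeholder config path:

```python
# Minimal sketch of loading a trained gaussian-splatting checkpoint programmatically.
# Assumes Nerfstudio's eval_setup helper; the config path below is a placeholder.
from pathlib import Path

from nerfstudio.utils.eval_utils import eval_setup

config_path = Path("outputs/my-scene/gaussian-splatting/2024-01-01_000000/config.yml")
config, pipeline, checkpoint_path, step = eval_setup(config_path)
# num_points is the gaussian-count property on the splatting model
print(f"Loaded {checkpoint_path} at step {step}; model has {pipeline.model.num_points} gaussians")
```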
+
+#### Quality and Regularization
+The default settings balance training speed, quality, and splat file size; if you care more about quality than training speed or file size, you can decrease the alpha cull threshold
+(threshold to delete translucent gaussians) and disable culling after 15k steps like so: `ns-train gaussian-splatting --pipeline.model.cull_alpha_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data <data>`
+
+A common artifact in splatting is long, spiky gaussians. [PhysGaussian](https://xpandora.github.io/PhysGaussian/) proposes a scale regularizer that encourages gaussians to be more evenly shaped. To enable it, set the `use_scale_regularization` flag to `True` (a standalone sketch of the idea follows the FAQ below).
 
 ### Details
 For more details on the method, see the [original paper](https://arxiv.org/abs/2308.04079). Additionally, for a detailed derivation of the gradients used in the gsplat library, see [here](https://arxiv.org/abs/2312.02121).
 
 ### Exporting splats
-Gaussian splats can be exported as a `.ply` file which are ingestable by a variety of online web viewers. You can do this via the viewer, or `ns-export gaussian-splat`. Currently splats can only be exported from trained splats, not from nerfacto.
+Gaussian splats can be exported as a `.ply` file, which is ingestible by a variety of online web viewers. You can do this via the viewer, or with `ns-export gaussian-splat --load-config <config> --output-dir exports/splat`. Currently splats can only be exported from trained splats, not from nerfacto.
+
+Nerfstudio gaussian splat exports can currently be opened in multiple third-party splat viewers:
+- [Polycam Viewer](https://poly.cam/tools/gaussian-splatting)
+- [Playcanvas SuperSplat](https://playcanvas.com/super-splat)
+- [WebGL Viewer by antimatter15](https://antimatter15.com/splat/)
+- [Spline](https://spline.design/)
+- [Three.js Viewer by mkkellogg](https://github.com/mkkellogg/GaussianSplats3D)
 
 ### FAQ
 - Can I export a mesh or pointcloud?
 
-Currently these export options are not supported, but may in the future and contributions are always welcome!
+Currently these export options are not supported, but may be in the future. Contributions are always welcome!
 
 - Can I render fisheye, equirectangular, orthographic images?
 
 Currently, no. Gaussian splatting assumes a perspective camera for its rasterization pipeline. Implementing other camera models is of interest but not currently planned.
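A standalone sketch of the scale-regularization idea referenced in the Quality and Regularization section above (a minimal illustration of the ratio penalty using the `max_gauss_ratio` threshold from the model config, not the verbatim library code):

```python
import torch

def scale_regularization_loss(log_scales: torch.Tensor, max_gauss_ratio: float = 10.0) -> torch.Tensor:
    """Penalize gaussians whose longest axis exceeds max_gauss_ratio times their shortest axis.

    Args:
        log_scales: (N, 3) per-gaussian log scales, as stored by the splatting model.
        max_gauss_ratio: anisotropy ratio above which the penalty kicks in.
    """
    scales = torch.exp(log_scales)  # (N, 3) positive axis lengths
    ratio = scales.amax(dim=-1) / scales.amin(dim=-1)  # anisotropy per gaussian
    # zero penalty while ratio <= max_gauss_ratio, linear growth beyond it
    return (torch.clamp(ratio, min=max_gauss_ratio) - max_gauss_ratio).mean()
```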
If "gpu", caches on device.""" + cache_images_type: Literal["uint8", "float32"] = "float32" + """The image type returned from manager, caching images in uint8 saves memory""" class FullImageDatamanager(DataManager, Generic[TDataset]): @@ -126,7 +128,7 @@ def cache_images(self, cache_images_option): CONSOLE.log("Caching / undistorting train images") for i in tqdm(range(len(self.train_dataset)), leave=False): # cv2.undistort the images / cameras - data = self.train_dataset.get_data(i) + data = self.train_dataset.get_data(i, image_type=self.config.cache_images_type) camera = self.train_dataset.cameras[i].reshape(()) K = camera.get_intrinsics_matrices().numpy() if camera.distortion_params is None: @@ -202,7 +204,7 @@ def cache_images(self, cache_images_option): CONSOLE.log("Caching / undistorting eval images") for i in tqdm(range(len(self.eval_dataset)), leave=False): # cv2.undistort the images / cameras - data = self.eval_dataset.get_data(i) + data = self.eval_dataset.get_data(i, image_type=self.config.cache_images_type) camera = self.eval_dataset.cameras[i].reshape(()) K = camera.get_intrinsics_matrices().numpy() if camera.distortion_params is None: diff --git a/nerfstudio/data/dataparsers/colmap_dataparser.py b/nerfstudio/data/dataparsers/colmap_dataparser.py index a0ee580e35..ce3cf8919f 100644 --- a/nerfstudio/data/dataparsers/colmap_dataparser.py +++ b/nerfstudio/data/dataparsers/colmap_dataparser.py @@ -486,7 +486,7 @@ def get_fname(parent: Path, filepath: Path) -> Path: max_res = max(h, w) df = 0 while True: - if (max_res / 2 ** (df)) < MAX_AUTO_RESOLUTION: + if (max_res / 2 ** (df)) <= MAX_AUTO_RESOLUTION: break df += 1 diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index 554e88dac0..2d007d8e9a 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -376,7 +376,7 @@ def _get_fname(self, filepath: Path, data_dir: Path, downsample_folder_prefix="i max_res = max(h, w) df = 0 while True: - if (max_res / 2 ** (df)) < MAX_AUTO_RESOLUTION: + if (max_res / 2 ** (df)) <= MAX_AUTO_RESOLUTION: break if not (data_dir / f"{downsample_folder_prefix}{2**(df+1)}" / filepath.name).exists(): break diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index 6532d5d509..0e40f312d2 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -19,12 +19,12 @@ from copy import deepcopy from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Literal import numpy as np import numpy.typing as npt import torch -from jaxtyping import Float +from jaxtyping import Float, UInt8 from PIL import Image from torch import Tensor from torch.utils.data import Dataset @@ -77,24 +77,51 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct." return image - def get_image(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]: - """Returns a 3 channel image. + def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]: + """Returns a 3 channel image in float32 torch.Tensor. Args: image_idx: The image index in the dataset. 
""" image = torch.from_numpy(self.get_numpy_image(image_idx).astype("float32") / 255.0) if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: + assert (self._dataparser_outputs.alpha_color >= 0).all() and ( + self._dataparser_outputs.alpha_color <= 1 + ).all(), "alpha color given is out of range between [0, 1]." image = image[:, :, :3] * image[:, :, -1:] + self._dataparser_outputs.alpha_color * (1.0 - image[:, :, -1:]) return image - def get_data(self, image_idx: int) -> Dict: + def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_width num_channels"]: + """Returns a 3 channel image in uint8 torch.Tensor. + + Args: + image_idx: The image index in the dataset. + """ + image = torch.from_numpy(self.get_numpy_image(image_idx)) + if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: + assert (self._dataparser_outputs.alpha_color >= 0).all() and ( + self._dataparser_outputs.alpha_color <= 1 + ).all(), "alpha color given is out of range between [0, 1]." + image = image[:, :, :3] * image[:, :, -1:] / 255.0 + 255.0 * self._dataparser_outputs.alpha_color * ( + 1.0 - image[:, :, -1:] / 255.0 + ) + image = torch.clamp(image, min=0, max=255).to(torch.uint8) + return image + + def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "float32") -> Dict: """Returns the ImageDataset data as a dictionary. Args: image_idx: The image index in the dataset. + image_type: the type of images returned """ - image = self.get_image(image_idx) + if image_type == "float32": + image = self.get_image_float32(image_idx) + elif image_type == "uint8": + image = self.get_image_uint8(image_idx) + else: + raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32") + data = {"image_idx": image_idx, "image": image} if self._dataparser_outputs.mask_filenames is not None: mask_filepath = self._dataparser_outputs.mask_filenames[image_idx] diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index 0fc58bde01..144e405a57 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -17,6 +17,7 @@ """ import random +import warnings from dataclasses import dataclass, field from typing import Dict, Optional, Type, Union @@ -42,6 +43,10 @@ class PixelSamplerConfig(InstantiateConfig): """List of whether or not camera i is equirectangular.""" fisheye_crop_radius: Optional[float] = None """Set to the radius (in pixels) for fisheye cameras.""" + rejection_sample_mask: bool = True + """Whether or not to use rejection sampling when sampling images with masks""" + max_num_iterations: int = 100 + """If rejection sampling masks, the maximum number of times to sample""" class PixelSampler: @@ -88,15 +93,44 @@ def sample_method( num_images: number of images to sample over mask: mask of possible pixels in an image to sample from. 
""" + indices = ( + torch.rand((batch_size, 3), device=device) + * torch.tensor([num_images, image_height, image_width], device=device) + ).long() + if isinstance(mask, torch.Tensor): - nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False) - chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size) - indices = nonzero_indices[chosen_indices] - else: - indices = ( - torch.rand((batch_size, 3), device=device) - * torch.tensor([num_images, image_height, image_width], device=device) - ).long() + if self.config.rejection_sample_mask: + num_valid = 0 + for _ in range(self.config.max_num_iterations): + c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1)) + chosen_indices_validity = mask[..., 0][c, y, x].bool() + num_valid = int(torch.sum(chosen_indices_validity).item()) + if num_valid == batch_size: + break + else: + replacement_indices = ( + torch.rand((batch_size - num_valid, 3), device=device) + * torch.tensor([num_images, image_height, image_width], device=device) + ).long() + indices[~chosen_indices_validity] = replacement_indices + + if num_valid != batch_size: + warnings.warn( + """ + Masked sampling failed, mask is either empty or mostly empty. + Reverting behavior to non-rejection sampling. Consider setting + pipeline.datamanager.pixel-sampler.rejection-sample-mask to False + or increasing pipeline.datamanager.pixel-sampler.max-num-iterations + """ + ) + self.config.rejection_sample_mask = False + nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False) + chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size) + indices = nonzero_indices[chosen_indices] + else: + nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False) + chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size) + indices = nonzero_indices[chosen_indices] return indices diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py index 0029d86274..f7280d1831 100644 --- a/nerfstudio/models/gaussian_splatting.py +++ b/nerfstudio/models/gaussian_splatting.py @@ -27,14 +27,12 @@ import numpy as np import torch from gsplat._torch_impl import quat_to_rotmat -from gsplat.compute_cumulative_intersects import compute_cumulative_intersects -from gsplat.project_gaussians import ProjectGaussians -from gsplat.rasterize import RasterizeGaussians -from gsplat.sh import SphericalHarmonics, num_sh_bases +from gsplat.project_gaussians import project_gaussians +from gsplat.rasterize import rasterize_gaussians +from gsplat.sh import num_sh_bases, spherical_harmonics from pytorch_msssim import SSIM from torch.nn import Parameter -from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.cameras import Cameras from nerfstudio.data.scene_box import OrientedBox from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation @@ -147,8 +145,6 @@ class GaussianSplattingModelConfig(ModelConfig): """stop splitting at this step""" sh_degree: int = 3 """maximum degree of spherical harmonics to use""" - camera_optimizer: CameraOptimizerConfig = field(default_factory=CameraOptimizerConfig) - """camera optimizer config""" use_scale_regularization: bool = False """If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians.""" max_gauss_ratio: float = 10.0 @@ -219,10 +215,6 @@ def populate_modules(self): else: self.back_color = 
diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py
index 0029d86274..f7280d1831 100644
--- a/nerfstudio/models/gaussian_splatting.py
+++ b/nerfstudio/models/gaussian_splatting.py
@@ -27,14 +27,12 @@
 import numpy as np
 import torch
 from gsplat._torch_impl import quat_to_rotmat
-from gsplat.compute_cumulative_intersects import compute_cumulative_intersects
-from gsplat.project_gaussians import ProjectGaussians
-from gsplat.rasterize import RasterizeGaussians
-from gsplat.sh import SphericalHarmonics, num_sh_bases
+from gsplat.project_gaussians import project_gaussians
+from gsplat.rasterize import rasterize_gaussians
+from gsplat.sh import num_sh_bases, spherical_harmonics
 from pytorch_msssim import SSIM
 from torch.nn import Parameter
 
-from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
 from nerfstudio.cameras.cameras import Cameras
 from nerfstudio.data.scene_box import OrientedBox
 from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
@@ -147,8 +145,6 @@ class GaussianSplattingModelConfig(ModelConfig):
     """stop splitting at this step"""
     sh_degree: int = 3
     """maximum degree of spherical harmonics to use"""
-    camera_optimizer: CameraOptimizerConfig = field(default_factory=CameraOptimizerConfig)
-    """camera optimizer config"""
     use_scale_regularization: bool = False
     """If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians."""
     max_gauss_ratio: float = 10.0
@@ -219,10 +215,6 @@ def populate_modules(self):
         else:
             self.back_color = get_color(self.config.background_color)
 
-        self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
-            num_cameras=self.num_train_data, device="cpu"
-        )
-
     @property
     def colors(self):
         if self.config.sh_degree > 0:
@@ -331,7 +323,8 @@ def after_train(self, step: int):
         with torch.no_grad():
             # keep track of a moving average of grad norms
             visible_mask = (self.radii > 0).flatten()
-            grads = self.xys.grad.detach().norm(dim=-1)  # TODO fill in
+            assert self.xys.grad is not None
+            grads = self.xys.grad.detach().norm(dim=-1)
             # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}")
             if self.xys_grad_norm is None:
                 self.xys_grad_norm = grads
@@ -576,8 +569,6 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
             Mapping of different parameter groups
         """
         gps = self.get_gaussian_param_groups()
-        # add camera optimizer param groups
-        self.camera_optimizer.get_param_groups(gps)
         return gps
 
     def _get_downscale_factor(self):
@@ -603,8 +594,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
-        if self.training:
-            # currently relies on the branch vickie/camera-grads
-            self.camera_optimizer.apply_to_camera(camera)
-
+        # get the background color
         if self.training:
             if self.config.background_color == "random":
                 background = torch.rand(3, device=self.device)
@@ -641,13 +632,13 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         cy = camera.cy.item()
         fovx = 2 * math.atan(camera.width / (2 * camera.fx))
         fovy = 2 * math.atan(camera.height / (2 * camera.fy))
-        W, H = camera.width.item(), camera.height.item()
+        W, H = int(camera.width.item()), int(camera.height.item())
         self.last_size = (H, W)
         projmat = projection_matrix(0.001, 1000, fovx, fovy, device=self.device)
         BLOCK_X, BLOCK_Y = 16, 16
         tile_bounds = (
-            (W + BLOCK_X - 1) // BLOCK_X,
-            (H + BLOCK_Y - 1) // BLOCK_Y,
+            int((W + BLOCK_X - 1) // BLOCK_X),
+            int((H + BLOCK_Y - 1) // BLOCK_Y),
             1,
         )
@@ -668,7 +659,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
 
-        self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = ProjectGaussians.apply(
+        self.xys, depths, self.radii, conics, num_tiles_hit, cov3d = project_gaussians(  # type: ignore
             means_crop,
             torch.exp(scales_crop),
             1,
@@ -694,7 +685,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3]  # (N, 3)
             viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
             n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
-            rgbs = SphericalHarmonics.apply(n, viewdirs, colors_crop)
+            rgbs = spherical_harmonics(n, viewdirs, colors_crop)
             rgbs = torch.clamp(rgbs + 0.5, min=0.0)  # type: ignore
         else:
             rgbs = torch.sigmoid(colors_crop[:, 0, :])
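The `sh_degree_interval` schedule used above unlocks one spherical-harmonics band at a time during training; a small sketch (the interval of 1000 is an illustrative value, `sh_degree=3` matches the config default):

```python
# How the active spherical-harmonics degree ramps up with training step.
def active_sh_degree(step: int, sh_degree_interval: int = 1000, sh_degree: int = 3) -> int:
    return min(step // sh_degree_interval, sh_degree)

for step in (0, 500, 1000, 2500, 10000):
    print(step, active_sh_degree(step))  # 0->0, 500->0, 1000->1, 2500->2, 10000->3
```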
@@ -702,63 +693,70 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             # rescale the camera back to original dimensions
             camera.rescale_output_resolution(camera_downscale)
 
-        # avoid empty rasterization
-        num_intersects, _ = compute_cumulative_intersects(self.xys.size(0), num_tiles_hit)
-        assert num_intersects > 0
+        assert (num_tiles_hit > 0).any()  # type: ignore
 
-        rgb = RasterizeGaussians.apply(
+        rgb = rasterize_gaussians(  # type: ignore
             self.xys,
             depths,
             self.radii,
             conics,
-            num_tiles_hit,
+            num_tiles_hit,  # type: ignore
             rgbs,
             torch.sigmoid(opacities_crop),
             H,
             W,
-            background,
+            background=background,
         )  # type: ignore
         rgb = torch.clamp(rgb, max=1.0)  # type: ignore
         depth_im = None
         if not self.training:
-            depth_im = RasterizeGaussians.apply(  # type: ignore
+            depth_im = rasterize_gaussians(  # type: ignore
                 self.xys,
                 depths,
                 self.radii,
                 conics,
-                num_tiles_hit,
+                num_tiles_hit,  # type: ignore
                 depths[:, None].repeat(1, 3),
                 torch.sigmoid(opacities_crop),
                 H,
                 W,
-                torch.ones(3, device=self.device) * 10,
+                background=torch.ones(3, device=self.device) * 10,
             )[..., 0:1]  # type: ignore
         return {"rgb": rgb, "depth": depth_im}  # type: ignore
 
-    def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
-        """Compute and returns metrics.
+    def get_gt_img(self, image: torch.Tensor):
+        """Compute the ground-truth image at the iteration-dependent downscale factor.
 
         Args:
-            outputs: the output to compute loss dict to
-            batch: ground truth batch corresponding to outputs
+            image: torch.Tensor of dtype uint8 or float32
         """
+        if image.dtype == torch.uint8:
+            image = image.float() / 255.0
         d = self._get_downscale_factor()
         if d > 1:
-            newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
+            newsize = [image.shape[0] // d, image.shape[1] // d]
 
             # torchvision can be slow to import, so we do it lazily.
             import torchvision.transforms.functional as TF
 
-            gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
+            gt_img = TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
         else:
-            gt_img = batch["image"]
+            gt_img = image
+        return gt_img.to(self.device)
+
+    def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
+        """Compute and return metrics.
+
+        Args:
+            outputs: the output to compute loss dict to
+            batch: ground truth batch corresponding to outputs
+        """
+        gt_rgb = self.get_gt_img(batch["image"])
         metrics_dict = {}
-        gt_rgb = gt_img.to(self.device)  # RGB or RGBA image
         predicted_rgb = outputs["rgb"]
         metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)
 
-        self.camera_optimizer.get_metrics_dict(metrics_dict)
         metrics_dict["gaussian_count"] = self.num_points
         return metrics_dict
 
@@ -770,16 +768,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             batch: ground truth batch corresponding to outputs
             metrics_dict: dictionary of metrics, some of which we can use for loss
         """
-        d = self._get_downscale_factor()
-        if d > 1:
-            newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
-
-            # torchvision can be slow to import, so we do it lazily.
-            import torchvision.transforms.functional as TF
-
-            gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
-        else:
-            gt_img = batch["image"]
+        gt_img = self.get_gt_img(batch["image"])
         Ll1 = torch.abs(gt_img - outputs["rgb"]).mean()
         simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], outputs["rgb"].permute(2, 0, 1)[None, ...])
         if self.config.use_scale_regularization and self.step % 10 == 0:
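For reference, the `Ll1` and `simloss` terms computed here are typically blended into the standard 3DGS photometric objective; a minimal sketch (the 0.2 weight mirrors a common `ssim_lambda` default and is an assumption here, as is the SSIM configuration):

```python
import torch
from pytorch_msssim import SSIM

ssim = SSIM(data_range=1.0, size_average=True, channel=3)

def splat_loss(gt_img: torch.Tensor, pred_rgb: torch.Tensor, ssim_lambda: float = 0.2) -> torch.Tensor:
    """L1/DSSIM photometric loss on (H, W, 3) images with values in [0, 1]."""
    l1 = torch.abs(gt_img - pred_rgb).mean()
    # SSIM expects (N, C, H, W), so move channels first and add a batch dim
    simloss = 1 - ssim(gt_img.permute(2, 0, 1)[None, ...], pred_rgb.permute(2, 0, 1)[None, ...])
    return (1 - ssim_lambda) * l1 + ssim_lambda * simloss
```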
@@ -826,20 +815,17 @@ def get_image_metrics_and_images(
         Returns:
             A dictionary of metrics.
         """
+        gt_rgb = self.get_gt_img(batch["image"])
         d = self._get_downscale_factor()
         if d > 1:
             # torchvision can be slow to import, so we do it lazily.
             import torchvision.transforms.functional as TF
 
             newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
-            gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
             predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
         else:
-            gt_img = batch["image"]
             predicted_rgb = outputs["rgb"]
 
-        gt_rgb = gt_img.to(self.device)
-
         combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
 
         # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py
index bcf050df21..d47af785c3 100644
--- a/nerfstudio/scripts/render.py
+++ b/nerfstudio/scripts/render.py
@@ -234,7 +234,7 @@ def _render_trajectory_video(
         if render_nearest_camera:
             assert train_dataset is not None
             assert train_cameras is not None
-            img = train_dataset.get_image(max_idx)
+            img = train_dataset.get_image_float32(max_idx)
             height = cameras.image_height[0]
             # maintain the resolution of the img to calculate the width from the height
             width = int(img.shape[1] * (height / img.shape[0]))
diff --git a/pyproject.toml b/pyproject.toml
index 231b8aed4f..03acd7f2ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,7 +62,7 @@ dependencies = [
     "xatlas",
     "trimesh>=3.20.2",
     "timm==0.6.7",
-    "gsplat==0.1.0",
+    "gsplat==0.1.2.1",
     "pytorch-msssim",
     "pathos"
 ]
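The pin to `gsplat==0.1.2.1` tracks the library's switch from `torch.autograd.Function` classes invoked via `.apply(...)` to plain functions, which is what the model changes earlier in this diff migrate to. Summarized, with names taken from the diff itself:

```python
# gsplat 0.1.0 exposed autograd.Function classes, invoked via .apply(...):
#     from gsplat.project_gaussians import ProjectGaussians
#     xys, depths, radii, conics, num_tiles_hit, cov3d = ProjectGaussians.apply(...)
# gsplat 0.1.2.1 wraps them as plain functions (note background= is now a keyword argument):
from gsplat.project_gaussians import project_gaussians
from gsplat.rasterize import rasterize_gaussians
from gsplat.sh import num_sh_bases, spherical_harmonics
```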