From 6f2b87e8c279f96312975727978fad8a0bd45bca Mon Sep 17 00:00:00 2001
From: "J.Y" <132313008+jb-ye@users.noreply.github.com>
Date: Thu, 18 Jan 2024 14:03:11 -0500
Subject: [PATCH] Add option of caching image in bytes rather than float32
 (#2741)

Co-authored-by: Justin Kerr
---
 .../datamanagers/full_images_datamanager.py |  6 ++-
 nerfstudio/data/datasets/base_dataset.py    | 39 ++++++++++++++---
 nerfstudio/models/gaussian_splatting.py     | 42 +++++++++----------
 nerfstudio/scripts/render.py                |  2 +-
 4 files changed, 58 insertions(+), 31 deletions(-)

diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py
index bfc720d126..ec90b4b34c 100644
--- a/nerfstudio/data/datamanagers/full_images_datamanager.py
+++ b/nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -61,6 +61,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     """Specifies the image indices to use during eval; if None, uses all."""
     cache_images: Literal["cpu", "gpu"] = "cpu"
     """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device."""
+    cache_images_type: Literal["uint8", "float32"] = "float32"
+    """The image type returned from the manager; caching images as uint8 saves memory."""
 
 
 class FullImageDatamanager(DataManager, Generic[TDataset]):
@@ -126,7 +128,7 @@ def cache_images(self, cache_images_option):
         CONSOLE.log("Caching / undistorting train images")
         for i in tqdm(range(len(self.train_dataset)), leave=False):
             # cv2.undistort the images / cameras
-            data = self.train_dataset.get_data(i)
+            data = self.train_dataset.get_data(i, image_type=self.config.cache_images_type)
             camera = self.train_dataset.cameras[i].reshape(())
             K = camera.get_intrinsics_matrices().numpy()
             if camera.distortion_params is None:
@@ -201,7 +203,7 @@ def cache_images(self, cache_images_option):
         CONSOLE.log("Caching / undistorting eval images")
         for i in tqdm(range(len(self.eval_dataset)), leave=False):
             # cv2.undistort the images / cameras
-            data = self.eval_dataset.get_data(i)
+            data = self.eval_dataset.get_data(i, image_type=self.config.cache_images_type)
             camera = self.eval_dataset.cameras[i].reshape(())
             K = camera.get_intrinsics_matrices().numpy()
             if camera.distortion_params is None:
diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py
index 6532d5d509..0e40f312d2 100644
--- a/nerfstudio/data/datasets/base_dataset.py
+++ b/nerfstudio/data/datasets/base_dataset.py
@@ -19,12 +19,12 @@
 
 from copy import deepcopy
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Literal
 
 import numpy as np
 import numpy.typing as npt
 import torch
-from jaxtyping import Float
+from jaxtyping import Float, UInt8
 from PIL import Image
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -77,24 +77,51 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]:
         assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct."
         return image
 
-    def get_image(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]:
-        """Returns a 3 channel image.
+    def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]:
+        """Returns a 3 channel image in float32 torch.Tensor.
 
         Args:
             image_idx: The image index in the dataset.
""" image = torch.from_numpy(self.get_numpy_image(image_idx).astype("float32") / 255.0) if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: + assert (self._dataparser_outputs.alpha_color >= 0).all() and ( + self._dataparser_outputs.alpha_color <= 1 + ).all(), "alpha color given is out of range between [0, 1]." image = image[:, :, :3] * image[:, :, -1:] + self._dataparser_outputs.alpha_color * (1.0 - image[:, :, -1:]) return image - def get_data(self, image_idx: int) -> Dict: + def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_width num_channels"]: + """Returns a 3 channel image in uint8 torch.Tensor. + + Args: + image_idx: The image index in the dataset. + """ + image = torch.from_numpy(self.get_numpy_image(image_idx)) + if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: + assert (self._dataparser_outputs.alpha_color >= 0).all() and ( + self._dataparser_outputs.alpha_color <= 1 + ).all(), "alpha color given is out of range between [0, 1]." + image = image[:, :, :3] * image[:, :, -1:] / 255.0 + 255.0 * self._dataparser_outputs.alpha_color * ( + 1.0 - image[:, :, -1:] / 255.0 + ) + image = torch.clamp(image, min=0, max=255).to(torch.uint8) + return image + + def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "float32") -> Dict: """Returns the ImageDataset data as a dictionary. Args: image_idx: The image index in the dataset. + image_type: the type of images returned """ - image = self.get_image(image_idx) + if image_type == "float32": + image = self.get_image_float32(image_idx) + elif image_type == "uint8": + image = self.get_image_uint8(image_idx) + else: + raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32") + data = {"image_idx": image_idx, "image": image} if self._dataparser_outputs.mask_filenames is not None: mask_filepath = self._dataparser_outputs.mask_filenames[image_idx] diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py index 00f8c82705..b8102e7368 100644 --- a/nerfstudio/models/gaussian_splatting.py +++ b/nerfstudio/models/gaussian_splatting.py @@ -722,25 +722,35 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: return {"rgb": rgb, "depth": depth_im} # type: ignore - def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: - """Compute and returns metrics. + def get_gt_img(self, image: torch.Tensor): + """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose Args: - outputs: the output to compute loss dict to - batch: ground truth batch corresponding to outputs + image: tensor.Tensor in type uint8 or float32 """ + if image.dtype == torch.uint8: + image = image.float() / 255.0 d = self._get_downscale_factor() if d > 1: - newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d] + newsize = [image.shape[0] // d, image.shape[1] // d] # torchvision can be slow to import, so we do it lazily. import torchvision.transforms.functional as TF - gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0) + gt_img = TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0) else: - gt_img = batch["image"] + gt_img = image + return gt_img.to(self.device) + + def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: + """Compute and returns metrics. 
+
+        Args:
+            outputs: the output to compute loss dict to
+            batch: ground truth batch corresponding to outputs
+        """
+        gt_rgb = self.get_gt_img(batch["image"])
         metrics_dict = {}
-        gt_rgb = gt_img.to(self.device)  # RGB or RGBA image
         predicted_rgb = outputs["rgb"]
 
         metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)
@@ -756,16 +766,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             batch: ground truth batch corresponding to outputs
             metrics_dict: dictionary of metrics, some of which we can use for loss
         """
-        d = self._get_downscale_factor()
-        if d > 1:
-            newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
-
-            # torchvision can be slow to import, so we do it lazily.
-            import torchvision.transforms.functional as TF
-
-            gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
-        else:
-            gt_img = batch["image"]
+        gt_img = self.get_gt_img(batch["image"])
         Ll1 = torch.abs(gt_img - outputs["rgb"]).mean()
         simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], outputs["rgb"].permute(2, 0, 1)[None, ...])
         if self.config.use_scale_regularization and self.step % 10 == 0:
@@ -812,20 +813,17 @@ def get_image_metrics_and_images(
         Returns:
             A dictionary of metrics.
         """
+        gt_rgb = self.get_gt_img(batch["image"])
         d = self._get_downscale_factor()
         if d > 1:
             # torchvision can be slow to import, so we do it lazily.
             import torchvision.transforms.functional as TF
 
             newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
-            gt_img = TF.resize(batch["image"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
             predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
         else:
-            gt_img = batch["image"]
             predicted_rgb = outputs["rgb"]
 
-        gt_rgb = gt_img.to(self.device)
-
         combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
 
         # Switch images from [H, W, C] to [1, C, H, W] for metrics computations
diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py
index bcf050df21..d47af785c3 100644
--- a/nerfstudio/scripts/render.py
+++ b/nerfstudio/scripts/render.py
@@ -234,7 +234,7 @@ def _render_trajectory_video(
         if render_nearest_camera:
             assert train_dataset is not None
             assert train_cameras is not None
-            img = train_dataset.get_image(max_idx)
+            img = train_dataset.get_image_float32(max_idx)
             height = cameras.image_height[0]
             # maintain the resolution of the img to calculate the width from the height
             width = int(img.shape[1] * (height / img.shape[0]))
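
A minimal usage sketch (not part of the patch; the surrounding training setup is assumed): the new dtype is selected on the datamanager config and forwarded to the dataset getters added above. For scale, a 1920x1080 RGB frame occupies about 6 MB as uint8 versus about 24 MB as float32, so uint8 caching cuts image memory roughly 4x. With nerfstudio's tyro-generated CLI this would typically surface as something like --pipeline.datamanager.cache-images-type uint8, though the exact flag path depends on the method config and is an assumption here.

from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

config = FullImageDatamanagerConfig(
    cache_images="cpu",         # keep the undistorted image cache in CPU memory
    cache_images_type="uint8",  # cache images as uint8 instead of float32
)

# Inside FullImageDatamanager.cache_images, the choice is forwarded per image:
#     data = self.train_dataset.get_data(i, image_type=self.config.cache_images_type)
# Consumers that always need floats (e.g. render.py above) call get_image_float32 directly.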
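
The uint8 getter performs the same alpha blend as the float32 getter, just carried out in [0, 255] space before clamping and casting back to uint8. Below is a standalone sanity check of that equivalence (illustrative only, not part of the patch; the intermediate math is promoted to float here, since multiplying two uint8 tensors directly would wrap around modulo 256):

import torch

rgba_u8 = torch.randint(0, 256, (4, 4, 4), dtype=torch.uint8)  # random RGBA image
alpha_color = torch.tensor([1.0, 1.0, 1.0])  # background color in [0, 1], e.g. white

# float32 path: normalize to [0, 1], then alpha-blend against the background.
f = rgba_u8.float() / 255.0
blended_f = f[:, :, :3] * f[:, :, -1:] + alpha_color * (1.0 - f[:, :, -1:])

# uint8 path: the same blend in [0, 255] space, then clamp and cast back.
u = rgba_u8.float()  # promote before multiplying to avoid uint8 wraparound
blended_u = u[:, :, :3] * u[:, :, -1:] / 255.0 + 255.0 * alpha_color * (1.0 - u[:, :, -1:] / 255.0)
blended_u = torch.clamp(blended_u, min=0, max=255).to(torch.uint8)

# The two paths agree up to uint8 quantization (at most one grey level apart).
assert ((blended_f * 255.0).round() - blended_u.float()).abs().max() <= 1.0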