From 51b54ebbd28c4b84843434ae36867be4c3f733bf Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 18 Jan 2024 00:24:56 -0800 Subject: [PATCH 1/6] Gaussian splatting support for Aria --- nerfstudio/cameras/camera_utils.py | 2 +- nerfstudio/cameras/cameras.py | 2 +- .../datamanagers/full_images_datamanager.py | 255 ++++++++++-------- .../data/dataparsers/nerfstudio_dataparser.py | 1 + nerfstudio/models/gaussian_splatting.py | 29 +- .../scripts/datasets/process_project_aria.py | 17 ++ nerfstudio/utils/tensor_dataclass.py | 2 + 7 files changed, 185 insertions(+), 123 deletions(-) diff --git a/nerfstudio/cameras/camera_utils.py b/nerfstudio/cameras/camera_utils.py index 7c98ef080c..9c1ee02200 100644 --- a/nerfstudio/cameras/camera_utils.py +++ b/nerfstudio/cameras/camera_utils.py @@ -720,7 +720,7 @@ def fisheye624_unproject_helper(uv, params, max_iters: int = 5): function so this solves an optimization problem using Newton's method to get the inverse. Inputs: - uv: BxNx3 tensor of 2D pixels to be projected + uv: BxNx2 tensor of 2D pixels to be unprojected params: Bx16 tensor of Fisheye624 parameters formatted like this: [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] or Bx15 tensor of Fisheye624 parameters formatted like this: diff --git a/nerfstudio/cameras/cameras.py b/nerfstudio/cameras/cameras.py index 4202c8c273..e390360b5e 100644 --- a/nerfstudio/cameras/cameras.py +++ b/nerfstudio/cameras/cameras.py @@ -864,7 +864,7 @@ def _compute_rays_for_vr180( assert distortion_params is not None masked_coords = pcoord_stack[coord_mask, :] - # The fisheye unprojection does not rely on planar/pinhold unprojection, thus the method needs + # The fisheye unprojection does not rely on planar/pinhole unprojection, thus the method needs # to access the focal length and principle points directly. camera_params = torch.cat( [ diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index bfc720d126..9fe19f43ae 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -34,6 +34,7 @@ from torch.nn import Parameter from tqdm import tqdm +from nerfstudio.cameras.camera_utils import fisheye624_project from nerfstudio.cameras.cameras import Cameras, CameraType from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset @@ -133,63 +134,11 @@ def cache_images(self, cache_images_option): continue distortion_params = camera.distortion_params.numpy() image = data["image"].numpy() - if camera.camera_type.item() == CameraType.PERSPECTIVE.value: - distortion_params = np.array( - [ - distortion_params[0], - distortion_params[1], - distortion_params[4], - distortion_params[5], - distortion_params[2], - distortion_params[3], - 0, - 0, - ] - ) - if np.any(distortion_params): - newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) - image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore - else: - newK = K - roi = 0, 0, image.shape[1], image.shape[0] - # crop the image and update the intrinsics accordingly - x, y, w, h = roi - image = image[y : y + h, x : x + w] - if "depth_image" in data: - data["depth_image"] = data["depth_image"][y : y + h, x : x + w] - # update the width, height - self.train_dataset.cameras.width[i] = w - self.train_dataset.cameras.height[i] = h - if "mask" in data: - mask = data["mask"].numpy() - mask = mask.astype(np.uint8) * 255 - if np.any(distortion_params): - mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore - mask = mask[y : y + h, x : x + w] - data["mask"] = torch.from_numpy(mask).bool() - K = newK - - elif camera.camera_type.item() == CameraType.FISHEYE.value: - distortion_params = np.array( - [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]] - ) - newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( - K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0 - ) - map1, map2 = cv2.fisheye.initUndistortRectifyMap( - K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1 - ) - # and then remap: - image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) - if "mask" in data: - mask = data["mask"].numpy() - mask = mask.astype(np.uint8) * 255 - mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK) - data["mask"] = torch.from_numpy(mask).bool() - K = newK - else: - raise NotImplementedError("Only perspective and fisheye cameras are supported") + + K, image, mask = _undistort_image(camera, distortion_params, data, image, K) data["image"] = torch.from_numpy(image) + if mask is not None: + data["mask"] = mask cached_train.append(data) @@ -197,6 +146,8 @@ def cache_images(self, cache_images_option): self.train_dataset.cameras.fy[i] = float(K[1, 1]) self.train_dataset.cameras.cx[i] = float(K[0, 2]) self.train_dataset.cameras.cy[i] = float(K[1, 2]) + self.train_dataset.cameras.width[i] = image.shape[1] + self.train_dataset.cameras.height[i] = image.shape[0] CONSOLE.log("Caching / undistorting eval images") for i in tqdm(range(len(self.eval_dataset)), leave=False): @@ -208,61 +159,11 @@ def cache_images(self, cache_images_option): continue distortion_params = camera.distortion_params.numpy() image = data["image"].numpy() - if camera.camera_type.item() == CameraType.PERSPECTIVE.value: - distortion_params = np.array( - [ - distortion_params[0], - distortion_params[1], - distortion_params[4], - distortion_params[5], - distortion_params[2], - distortion_params[3], - 0, - 0, - ] - ) - if np.any(distortion_params): - newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) - image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore - else: - newK = K - roi = 0, 0, image.shape[1], image.shape[0] - # crop the image and update the intrinsics accordingly - x, y, w, h = roi - image = image[y : y + h, x : x + w] - # update the width, height - self.eval_dataset.cameras.width[i] = w - self.eval_dataset.cameras.height[i] = h - if "mask" in data: - mask = data["mask"].numpy() - mask = mask.astype(np.uint8) * 255 - if np.any(distortion_params): - mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore - mask = mask[y : y + h, x : x + w] - data["mask"] = torch.from_numpy(mask).bool() - K = newK - - elif camera.camera_type.item() == CameraType.FISHEYE.value: - distortion_params = np.array( - [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]] - ) - newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( - K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0 - ) - map1, map2 = cv2.fisheye.initUndistortRectifyMap( - K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1 - ) - # and then remap: - image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) - if "mask" in data: - mask = data["mask"].numpy() - mask = mask.astype(np.uint8) * 255 - mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK) - data["mask"] = torch.from_numpy(mask).bool() - K = newK - else: - raise NotImplementedError("Only perspective and fisheye cameras are supported") + + K, image, mask = _undistort_image(camera, distortion_params, data, image, K) data["image"] = torch.from_numpy(image) + if mask is not None: + data["mask"] = mask cached_eval.append(data) @@ -414,3 +315,137 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension" camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device) return camera, data + + +def _undistort_image( + camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]: + mask = None + if camera.camera_type.item() == CameraType.PERSPECTIVE.value: + distortion_params = np.array( + [ + distortion_params[0], + distortion_params[1], + distortion_params[4], + distortion_params[5], + distortion_params[2], + distortion_params[3], + 0, + 0, + ] + ) + if np.any(distortion_params): + newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) + image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore + else: + newK = K + roi = 0, 0, image.shape[1], image.shape[0] + # crop the image and update the intrinsics accordingly + x, y, w, h = roi + image = image[y : y + h, x : x + w] + if "depth_image" in data: + data["depth_image"] = data["depth_image"][y : y + h, x : x + w] + if "mask" in data: + mask = data["mask"].numpy() + mask = mask.astype(np.uint8) * 255 + if np.any(distortion_params): + mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore + mask = mask[y : y + h, x : x + w] + mask = torch.from_numpy(mask).bool() + K = newK + + elif camera.camera_type.item() == CameraType.FISHEYE.value: + distortion_params = np.array( + [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]] + ) + newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( + K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0 + ) + map1, map2 = cv2.fisheye.initUndistortRectifyMap( + K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1 + ) + # and then remap: + image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) + if "mask" in data: + mask = data["mask"].numpy() + mask = mask.astype(np.uint8) * 255 + mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK) + mask = torch.from_numpy(mask).bool() + K = newK + elif camera.camera_type.item() == CameraType.FISHEYE624.value: + fisheye624_params = np.concatenate([camera.fx, camera.fy, camera.cx, camera.cy, distortion_params], axis=0) + assert fisheye624_params.shape == (16,) + + # Desired parameters of the undistorted image. + import warnings + + warnings.warn("Fisheye624 support in the full images datamanager currently assumes data from Project Aria.") + undist_h = 500 + undist_w = 500 + undist_K = torch.eye(3) + undist_K[0, 0] = 150.0 # fx + undist_K[1, 1] = 150.0 # fy + undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0). + undist_K[1, 2] = (undist_h - 1) / 2.0 # cy + + undistort_uv_homog = torch.stack( + [ + *torch.meshgrid( + torch.arange(undist_w, dtype=torch.float32), + torch.arange(undist_h, dtype=torch.float32), + ), + torch.ones((undist_w, undist_h), dtype=torch.float32), + ], + dim=-1, + ) + assert undistort_uv_homog.shape == (undist_w, undist_h, 3) + + unproj = ( + fisheye624_project( + xyz=( + torch.einsum( + "ij,bj->bi", + torch.linalg.inv(undist_K), + undistort_uv_homog.reshape((undist_w * undist_h, 3)), + )[None] + ), + params=torch.from_numpy(fisheye624_params[None, :]), + ) + .reshape((undist_w, undist_h, 2)) + .numpy() + ) + map1 = unproj[..., 1] + map2 = unproj[..., 0] + + image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) + + dist_h = camera.height.item() + dist_w = camera.width.item() + + # Compute distorted mask. + assert ( + "mask" not in data + and camera.metadata is not None + and "fisheye_crop_radius" in camera.metadata + and isinstance(camera.metadata["fisheye_crop_radius"], float) + ) + mask = np.mgrid[:dist_h, :dist_w] + mask[0, ...] -= dist_h // 2 + mask[1, ...] -= dist_w // 2 + mask = np.linalg.norm(mask, axis=0) < camera.metadata["fisheye_crop_radius"] + mask = torch.from_numpy( + cv2.remap( + mask.astype(np.uint8) * 255, + map1, + map2, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0, + ) + / 255.0 + ).bool() + K = undist_K.numpy() + else: + raise NotImplementedError("Only perspective and fisheye cameras are supported") + + return K, image, mask diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index 554e88dac0..2e28c1da65 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -319,6 +319,7 @@ def _generate_dataparser_outputs(self, split="train"): if "ply_file_path" in meta: ply_file_path = data_dir / meta["ply_file_path"] metadata.update(self._load_3D_points(ply_file_path, transform_matrix, scale_factor)) + print("loaded points!") dataparser_outputs = DataparserOutputs( image_filenames=image_filenames, diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py index ab17c6a37c..f5b8e75a1e 100644 --- a/nerfstudio/models/gaussian_splatting.py +++ b/nerfstudio/models/gaussian_splatting.py @@ -162,16 +162,18 @@ class GaussianSplattingModel(Model): config: GaussianSplattingModelConfig - def __init__(self, *args, **kwargs): - if "seed_points" in kwargs: - self.seed_pts = kwargs["seed_points"] - else: - self.seed_pts = None + def __init__( + self, + *args, + seed_points: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ): + self.seed_points = seed_points super().__init__(*args, **kwargs) def populate_modules(self): - if self.seed_pts is not None and not self.config.random_init: - self.means = torch.nn.Parameter(self.seed_pts[0]) # (Location, Color) + if self.seed_points is not None and not self.config.random_init: + self.means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color) else: self.means = torch.nn.Parameter((torch.rand((500000, 3)) - 0.5) * 10) self.xys_grad_norm = None @@ -184,14 +186,19 @@ def populate_modules(self): self.quats = torch.nn.Parameter(random_quat_tensor(self.num_points)) dim_sh = num_sh_bases(self.config.sh_degree) - if self.seed_pts is not None and not self.config.random_init: - shs = torch.zeros((self.seed_pts[1].shape[0], dim_sh, 3)).float().cuda() + if ( + self.seed_points is not None + and not self.config.random_init + # We can have colors without points. + and self.seed_points[1].shape[0] > 0 + ): + shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda() if self.config.sh_degree > 0: - shs[:, 0, :3] = RGB2SH(self.seed_pts[1] / 255) + shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255) shs[:, 1:, 3:] = 0.0 else: CONSOLE.log("use color only optimization with sigmoid activation") - shs[:, 0, :3] = torch.logit(self.seed_pts[1] / 255, eps=1e-10) + shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10) self.features_dc = torch.nn.Parameter(shs[:, 0, :]) self.features_rest = torch.nn.Parameter(shs[:, 1:, :]) else: diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index 10f26653fd..c0493cbf5b 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -20,6 +20,8 @@ from typing import Dict, List import numpy as np +import open3d as o3d +import pandas as pd import tyro from PIL import Image @@ -219,6 +221,21 @@ def main(self) -> None: "fisheye_crop_radius": rgb_valid_radius, } + # save global point cloud, which is useful for Gaussian Splatting. + points_path = self.mps_data_dir / "global_points.csv.gz" + if points_path.exists(): + print("Found global points, saving to PLY...") + df = pd.read_csv(points_path, compression="gzip") + points = df[["px_world", "py_world", "pz_world"]].values + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + ply_file_path = self.output_dir / "global_points.ply" + o3d.io.write_point_cloud(str(ply_file_path), pcd) + + nerfstudio_frames["ply_file_path"] = "global_points.ply" + else: + print("No global points found!") + # write the json out to disk as transforms.json print("Writing transforms.json") transform_file = self.output_dir / "transforms.json" diff --git a/nerfstudio/utils/tensor_dataclass.py b/nerfstudio/utils/tensor_dataclass.py index a2b8d1dadb..1710881adc 100644 --- a/nerfstudio/utils/tensor_dataclass.py +++ b/nerfstudio/utils/tensor_dataclass.py @@ -141,6 +141,8 @@ def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict: new_dict[k] = v.broadcast_to(batch_shape) elif isinstance(v, Dict): new_dict[k] = self._broadcast_dict_fields(v, batch_shape) + else: + new_dict[k] = v return new_dict def __getitem__(self: TensorDataclassT, indices) -> TensorDataclassT: From 736ce81160cf70a80c773beffd63855b59d2a83c Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 18 Jan 2024 00:37:09 -0800 Subject: [PATCH 2/6] Fix undistorted eval image height/width --- nerfstudio/data/datamanagers/full_images_datamanager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 9fe19f43ae..976cb6ebf1 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -171,6 +171,8 @@ def cache_images(self, cache_images_option): self.eval_dataset.cameras.fy[i] = float(K[1, 1]) self.eval_dataset.cameras.cx[i] = float(K[0, 2]) self.eval_dataset.cameras.cy[i] = float(K[1, 2]) + self.eval_dataset.cameras.width[i] = image.shape[1] + self.eval_dataset.cameras.height[i] = image.shape[0] if cache_images_option == "gpu": for cache in cached_train: From 0b80a90a79869bc8b757bb091c2e04a038cc7948 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 19 Jan 2024 00:10:04 -0800 Subject: [PATCH 3/6] MPS api for reading point cloud --- nerfstudio/scripts/datasets/process_project_aria.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index c0493cbf5b..23d304cf6a 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -17,17 +17,17 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import Dict, List +from typing import Any, Dict, List, cast import numpy as np import open3d as o3d -import pandas as pd import tyro from PIL import Image try: from projectaria_tools.core import mps from projectaria_tools.core.data_provider import VrsDataProvider, create_vrs_data_provider + from projectaria_tools.core.mps.utils import filter_points_from_confidence from projectaria_tools.core.sophus import SE3 except ImportError: print("projectaria_tools import failed, please install with pip3 install projectaria-tools'[all]'") @@ -225,10 +225,10 @@ def main(self) -> None: points_path = self.mps_data_dir / "global_points.csv.gz" if points_path.exists(): print("Found global points, saving to PLY...") - df = pd.read_csv(points_path, compression="gzip") - points = df[["px_world", "py_world", "pz_world"]].values + points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore + points_data = filter_points_from_confidence(points_data) pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points) + pcd.points = o3d.utility.Vector3dVector(np.array([cast(Any, it).position_world for it in points_data])) ply_file_path = self.output_dir / "global_points.ply" o3d.io.write_point_cloud(str(ply_file_path), pcd) From 2832b20627c56c2ec1b50878bb714c2557a0b2fa Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 19 Jan 2024 00:15:10 -0800 Subject: [PATCH 4/6] Hotfix non-colmap nerfstudio datasets --- .../data/dataparsers/nerfstudio_dataparser.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index c13efbfe7a..b569508523 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -311,16 +311,20 @@ def _generate_dataparser_outputs(self, split="train"): # - transform_matrix contains the transformation to dataparser output coordinates from saved coordinates. # - dataparser_transform_matrix contains the transformation to dataparser output coordinates from original data coordinates. # - applied_transform contains the transformation to saved coordinates from original data coordinates. - if "applied_transform" not in meta: + applied_transform = None + colmap_path = self.config.data / "colmap/sparse/0" + if "applied_transform" not in meta and colmap_path.exists(): # For converting from colmap, this was the effective value of applied_transform that was being # used before we added the applied_transform field to the output dataformat. meta["applied_transform"] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, -1, 0]] + applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype) - applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype) - - dataparser_transform_matrix = transform_matrix @ torch.cat( - [applied_transform, torch.tensor([[0, 0, 0, 1]], dtype=transform_matrix.dtype)], 0 - ) + if applied_transform is not None: + dataparser_transform_matrix = transform_matrix @ torch.cat( + [applied_transform, torch.tensor([[0, 0, 0, 1]], dtype=transform_matrix.dtype)], 0 + ) + else: + dataparser_transform_matrix = transform_matrix if "applied_scale" in meta: applied_scale = float(meta["applied_scale"]) @@ -337,8 +341,6 @@ def _generate_dataparser_outputs(self, split="train"): # Load 3D points if self.config.load_3D_points: - colmap_path = self.config.data / "colmap/sparse/0" - if "ply_file_path" in meta: ply_file_path = data_dir / meta["ply_file_path"] From d1783dfb5baf880395fae11f25a8f73517bc09df Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 19 Jan 2024 01:32:12 -0800 Subject: [PATCH 5/6] Respect masks in splatting loss function --- .../data/datamanagers/full_images_datamanager.py | 5 +++-- nerfstudio/models/gaussian_splatting.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 0971be40ad..81a7295e80 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -31,6 +31,9 @@ import cv2 import numpy as np import torch +from torch.nn import Parameter +from tqdm import tqdm + from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper from nerfstudio.cameras.cameras import Cameras, CameraType from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion @@ -40,8 +43,6 @@ from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE -from torch.nn import Parameter -from tqdm import tqdm @dataclass diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/gaussian_splatting.py index 18210ce612..2365e43f7c 100644 --- a/nerfstudio/models/gaussian_splatting.py +++ b/nerfstudio/models/gaussian_splatting.py @@ -761,8 +761,18 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te metrics_dict: dictionary of metrics, some of which we can use for loss """ gt_img = self.get_gt_img(batch["image"]) - Ll1 = torch.abs(gt_img - outputs["rgb"]).mean() - simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], outputs["rgb"].permute(2, 0, 1)[None, ...]) + pred_img = outputs["rgb"] + + # Set masked part of both ground-truth and rendered image to black. + # This is a little bit sketchy for the SSIM loss. + if "mask" in batch: + assert batch["mask"].shape == gt_img.shape[:2] == pred_img.shape[:2] + mask = batch["mask"][..., None].to(self.device) + gt_img = gt_img * mask + pred_img = pred_img * mask + + Ll1 = torch.abs(gt_img - pred_img).mean() + simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...]) if self.config.use_scale_regularization and self.step % 10 == 0: scale_exp = torch.exp(self.scales) scale_reg = ( From 4803937a767ae76b80241edc84d8922fd42ef75b Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 19 Jan 2024 01:46:03 -0800 Subject: [PATCH 6/6] Nits --- .../datamanagers/full_images_datamanager.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 81a7295e80..ea0df36e3d 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -389,11 +389,6 @@ def _undistort_image( ) fisheye_crop_radius = camera.metadata["fisheye_crop_radius"] - # Desired parameters of the undistorted image. - import warnings - - warnings.warn("Fisheye624 support in the full images datamanager currently assumes data from Project Aria.") - # Approximate the FOV of the unmasked region of the camera. upper, lower, left, right = fisheye624_unproject_helper( torch.tensor( @@ -412,7 +407,7 @@ def _undistort_image( torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))), ) - # Heuristic to determine the parameters of the undistorted image. + # Heuristics to determine parameters of an undistorted image. undist_h = int(fisheye_crop_radius * 2) undist_w = int(fisheye_crop_radius * 2) undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0)) @@ -422,7 +417,8 @@ def _undistort_image( undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0). undist_K[1, 2] = (undist_h - 1) / 2.0 # cy - undistort_uv_homog = torch.stack( + # Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates. + undist_uv_homog = torch.stack( [ *torch.meshgrid( torch.arange(undist_w, dtype=torch.float32), @@ -432,15 +428,14 @@ def _undistort_image( ], dim=-1, ) - assert undistort_uv_homog.shape == (undist_w, undist_h, 3) - - unproj = ( + assert undist_uv_homog.shape == (undist_w, undist_h, 3) + dist_uv = ( fisheye624_project( xyz=( torch.einsum( "ij,bj->bi", torch.linalg.inv(undist_K), - undistort_uv_homog.reshape((undist_w * undist_h, 3)), + undist_uv_homog.reshape((undist_w * undist_h, 3)), )[None] ), params=fisheye624_params[None, :], @@ -448,15 +443,15 @@ def _undistort_image( .reshape((undist_w, undist_h, 2)) .numpy() ) - map1 = unproj[..., 1] - map2 = unproj[..., 0] + map1 = dist_uv[..., 1] + map2 = dist_uv[..., 0] + # Use correspondence to undistort image. image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) + # Compute undistorted mask as well. dist_h = camera.height.item() dist_w = camera.width.item() - - # Compute distorted mask. mask = np.mgrid[:dist_h, :dist_w] mask[0, ...] -= dist_h // 2 mask[1, ...] -= dist_w // 2 @@ -473,7 +468,6 @@ def _undistort_image( / 255.0 ).bool() K = undist_K.numpy() - else: raise NotImplementedError("Only perspective and fisheye cameras are supported")