diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py
index 92da20ccb2..ef86a6b320 100644
--- a/nerfstudio/configs/method_configs.py
+++ b/nerfstudio/configs/method_configs.py
@@ -31,7 +31,6 @@
 from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
 from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
 from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
-from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig
 from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
 from nerfstudio.data.dataparsers.instant_ngp_dataparser import InstantNGPDataParserConfig
 from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
@@ -648,7 +647,7 @@
         gradient_accumulation_steps={"camera_opt": 100},
         pipeline=VanillaPipelineConfig(
             datamanager=FullImageDatamanagerConfig(
-                dataparser=ColmapDataParserConfig(load_3D_points=True),
+                dataparser=NerfstudioDataParserConfig(load_3D_points=True),
             ),
             model=GaussianSplattingModelConfig(),
         ),
diff --git a/nerfstudio/data/dataparsers/colmap_dataparser.py b/nerfstudio/data/dataparsers/colmap_dataparser.py
index a5a9685cf6..5f5917e703 100644
--- a/nerfstudio/data/dataparsers/colmap_dataparser.py
+++ b/nerfstudio/data/dataparsers/colmap_dataparser.py
@@ -67,8 +67,8 @@ class ColmapDataParserConfig(DataParserConfig):
     assume_colmap_world_coordinate_convention: bool = True
     """Colmap optimized world often have y direction of the first camera pointing towards down direction,
     while nerfstudio world set z direction to be up direction for viewer. Therefore, we usually need to apply an extra
-    transform when orientation_method=none. This parameter has no effects if orientation_method is set other than none.
-    When this parameter is set to False, no extra transform is applied when reading data from colmap.
+    transform when orientation_method=none. This parameter has no effect if orientation_method is set to anything
+    other than none. When this parameter is set to False, no extra transform is applied when reading data from colmap.
     """
     eval_mode: Literal["fraction", "filename", "interval", "all"] = "interval"
     """
@@ -93,8 +93,9 @@ class ColmapDataParserConfig(DataParserConfig):
     """Path to depth maps directory. If not set, depths are not loaded."""
     colmap_path: Path = Path("colmap/sparse/0")
     """Path to the colmap reconstruction directory relative to the data path."""
-    load_3D_points: bool = False
-    """Whether to load the 3D points from the colmap reconstruction."""
+    load_3D_points: bool = True
+    """Whether to load the 3D points from the colmap reconstruction. This is helpful for Gaussian splatting and
+    generally unused otherwise, but it's typically harmless so we default to True."""
     max_2D_matches_per_3D_point: int = 0
     """Maximum number of 2D matches per 3D point. If set to -1, all 2D matches are loaded. If set to 0, no 2D matches are loaded."""
diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py
index 2d007d8e9a..c13efbfe7a 100644
--- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py
+++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py
@@ -73,6 +73,8 @@ class NerfstudioDataParserConfig(DataParserConfig):
     """The interval between frames to use for eval. Only used when eval_mode is eval-interval."""
     depth_unit_scale_factor: float = 1e-3
     """Scales the depth values to meters. Default value is 0.001 for a millimeter to meter conversion."""
+    load_3D_points: bool = False
+    """Whether to load the 3D points from the colmap reconstruction."""
 
 
 @dataclass
@@ -305,20 +307,90 @@ def _generate_dataparser_outputs(self, split="train"):
             assert self.downscale_factor is not None
             cameras.rescale_output_resolution(scaling_factor=1.0 / self.downscale_factor)
 
-        if "applied_transform" in meta:
-            applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype)
-            transform_matrix = transform_matrix @ torch.cat(
-                [applied_transform, torch.tensor([[0, 0, 0, 1]], dtype=transform_matrix.dtype)], 0
-            )
+        # The naming is somewhat confusing, but:
+        # - transform_matrix contains the transformation to dataparser output coordinates from saved coordinates.
+        # - dataparser_transform_matrix contains the transformation to dataparser output coordinates from original data coordinates.
+        # - applied_transform contains the transformation to saved coordinates from original data coordinates.
+        if "applied_transform" not in meta:
+            # For converting from colmap, this was the effective value of applied_transform that was being
+            # used before we added the applied_transform field to the output data format.
+            meta["applied_transform"] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, -1, 0]]
+
+        applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype)
+
+        dataparser_transform_matrix = transform_matrix @ torch.cat(
+            [applied_transform, torch.tensor([[0, 0, 0, 1]], dtype=transform_matrix.dtype)], 0
+        )
+
         if "applied_scale" in meta:
             applied_scale = float(meta["applied_scale"])
             scale_factor *= applied_scale
 
-        # Load 3D points
+        # Reinitialize metadata for dataparser_outputs
         metadata = {}
-        if "ply_file_path" in meta:
-            ply_file_path = data_dir / meta["ply_file_path"]
-            metadata.update(self._load_3D_points(ply_file_path, transform_matrix, scale_factor))
+
+        # _generate_dataparser_outputs might be called more than once, so track whether the user has already been prompted.
+        try:
+            self.prompted_user
+        except AttributeError:
+            self.prompted_user = False
+
+        # Load 3D points
+        if self.config.load_3D_points:
+            colmap_path = self.config.data / "colmap/sparse/0"
+
+            if "ply_file_path" in meta:
+                ply_file_path = data_dir / meta["ply_file_path"]
+
+            elif colmap_path.exists():
+                from rich.prompt import Confirm
+
+                # Check if the user wants to make a point cloud from the colmap points.
+                if not self.prompted_user:
+                    self.create_pc = Confirm.ask(
+                        "load_3D_points is true, but the dataset was processed with an outdated ns-process-data that didn't convert colmap points to .ply! Update the colmap dataset automatically?"
+                    )
+
+                if self.create_pc:
+                    import json
+
+                    from nerfstudio.process_data.colmap_utils import create_ply_from_colmap
+
+                    with open(self.config.data / "transforms.json") as f:
+                        transforms = json.load(f)
+
+                    # Update the dataset if it is missing the applied_transform field.
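+                    # Recording the value assumed above keeps the regenerated .ply consistent with
+                    # the transform that transforms.json claims was applied.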
+ if "applied_transform" not in transforms: + transforms["applied_transform"] = meta["applied_transform"] + + ply_filename = "sparse_pc.ply" + create_ply_from_colmap( + filename=ply_filename, + recon_dir=colmap_path, + output_dir=self.config.data, + applied_transform=applied_transform, + ) + ply_file_path = data_dir / ply_filename + transforms["ply_file_path"] = ply_filename + + # This was the applied_transform value + + with open(self.config.data / "transforms.json", "w", encoding="utf-8") as f: + json.dump(transforms, f, indent=4) + else: + ply_file_path = None + else: + if not self.prompted_user: + CONSOLE.print( + "[bold yellow]Warning: load_3D_points set to true but no point cloud found. gaussian-splatting models will use random point cloud initialization." + ) + ply_file_path = None + + if ply_file_path: + sparse_points = self._load_3D_points(ply_file_path, transform_matrix, scale_factor) + if sparse_points is not None: + metadata.update(sparse_points) + self.prompted_user = True dataparser_outputs = DataparserOutputs( image_filenames=image_filenames, @@ -326,7 +398,7 @@ def _generate_dataparser_outputs(self, split="train"): scene_box=scene_box, mask_filenames=mask_filenames if len(mask_filenames) > 0 else None, dataparser_scale=scale_factor, - dataparser_transform=transform_matrix, + dataparser_transform=dataparser_transform_matrix, metadata={ "depth_filenames": depth_filenames if len(depth_filenames) > 0 else None, "depth_unit_scale_factor": self.config.depth_unit_scale_factor, @@ -336,10 +408,24 @@ def _generate_dataparser_outputs(self, split="train"): return dataparser_outputs def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float): + """Loads point clouds positions and colors from .ply + + Args: + ply_file_path: Path to .ply file + transform_matrix: Matrix to transform world coordinates + scale_factor: How much to scale the camera origins by. + + Returns: + A dictionary of points: points3D_xyz and colors: points3D_rgb + """ import open3d as o3d # Importing open3d is slow, so we only do it if we need it. pcd = o3d.io.read_point_cloud(str(ply_file_path)) + # if no points found don't read in an initial point cloud + if len(pcd.points) == 0: + return None + points3D = torch.from_numpy(np.asarray(pcd.points, dtype=np.float32)) points3D = ( torch.cat( diff --git a/nerfstudio/process_data/colmap_utils.py b/nerfstudio/process_data/colmap_utils.py index 348f2b31a6..2f2ac3021a 100644 --- a/nerfstudio/process_data/colmap_utils.py +++ b/nerfstudio/process_data/colmap_utils.py @@ -18,7 +18,7 @@ import json from pathlib import Path -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional, Union import appdirs import cv2 @@ -34,6 +34,7 @@ read_cameras_binary, read_images_binary, read_points3D_binary, + read_points3D_text, ) from nerfstudio.process_data.process_data_utils import CameraModel from nerfstudio.utils import colormaps @@ -391,6 +392,7 @@ def colmap_to_json( camera_mask_path: Optional[Path] = None, image_id_to_depth_path: Optional[Dict[int, Path]] = None, image_rename_map: Optional[Dict[str, str]] = None, + ply_filename="sparse_pc.ply", keep_original_world_coordinate: bool = False, ) -> int: """Converts COLMAP's cameras.bin and images.bin to a JSON file. 
@@ -459,12 +461,23 @@ def colmap_to_json(
     out = parse_colmap_camera_params(cam_id_to_camera[1])
     out["frames"] = frames
 
+    applied_transform = None
     if not keep_original_world_coordinate:
         applied_transform = np.eye(4)[:3, :]
         applied_transform = applied_transform[np.array([0, 2, 1]), :]
         applied_transform[2, :] *= -1
         out["applied_transform"] = applied_transform.tolist()
 
+    # Create a ply from the colmap reconstruction.
+    assert ply_filename.endswith(".ply"), f"ply_filename: {ply_filename} does not end with '.ply'"
+    create_ply_from_colmap(
+        ply_filename,
+        recon_dir,
+        output_dir,
+        torch.from_numpy(applied_transform).float() if applied_transform is not None else None,
+    )
+    out["ply_file_path"] = ply_filename
+
     with open(output_dir / "transforms.json", "w", encoding="utf-8") as f:
         json.dump(out, f, indent=4)
@@ -643,3 +656,49 @@ def get_matching_summary(num_initial_frames: int, num_matched_frames: int) -> str:
         result += " or large exposure changes."
         return result
     return f"[bold green]COLMAP found poses for {num_matched_frames / num_initial_frames * 100:.2f}% of the images."
+
+
+def create_ply_from_colmap(
+    filename: str, recon_dir: Path, output_dir: Path, applied_transform: Union[torch.Tensor, None]
+) -> None:
+    """Writes a ply file from colmap.
+
+    Args:
+        filename: file name for the .ply
+        recon_dir: Directory of the colmap reconstruction to read points from
+        output_dir: Directory to output the .ply
+        applied_transform: Optional (3, 4) transform applied to the points before writing;
+            if None, points are written in the original colmap frame.
+    """
+    if (recon_dir / "points3D.bin").exists():
+        colmap_points = read_points3D_binary(recon_dir / "points3D.bin")
+    elif (recon_dir / "points3D.txt").exists():
+        colmap_points = read_points3D_text(recon_dir / "points3D.txt")
+    else:
+        raise ValueError(f"Could not find points3D.txt or points3D.bin in {recon_dir}")
+
+    # Load point positions
+    points3D = torch.from_numpy(np.array([p.xyz for p in colmap_points.values()], dtype=np.float32))
+    if applied_transform is not None:
+        assert applied_transform.shape == (3, 4)
+        points3D = torch.einsum("ij,bj->bi", applied_transform[:3, :3], points3D) + applied_transform[:3, 3]
+
+    # Load point colors
+    points3D_rgb = torch.from_numpy(np.array([p.rgb for p in colmap_points.values()], dtype=np.uint8))
+
+    # Write the ply
+    with open(output_dir / filename, "w") as f:
+        # Header ("uchar" is the standard PLY type name for 8-bit color channels)
+        f.write("ply\n")
+        f.write("format ascii 1.0\n")
+        f.write(f"element vertex {len(points3D)}\n")
+        f.write("property float x\n")
+        f.write("property float y\n")
+        f.write("property float z\n")
+        f.write("property uchar red\n")
+        f.write("property uchar green\n")
+        f.write("property uchar blue\n")
+        f.write("end_header\n")
+
+        for coord, color in zip(points3D, points3D_rgb):
+            x, y, z = coord
+            r, g, b = color
+            f.write(f"{x:8f} {y:8f} {z:8f} {r} {g} {b}\n")
diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py
index c7a881b80d..579dbb5cc4 100644
--- a/nerfstudio/viewer/viewer.py
+++ b/nerfstudio/viewer/viewer.py
@@ -32,6 +32,7 @@
 from nerfstudio.configs import base_config as cfg
 from nerfstudio.data.datasets.base_dataset import InputDataset
 from nerfstudio.models.base_model import Model
+from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
 from nerfstudio.pipelines.base_pipeline import Pipeline
 from nerfstudio.utils.decorators import check_main_thread, decorate_all
 from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName
@@ -247,6 +248,17 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem
         for c in self.viewer_controls:
             c._setup(self)
 
+        # Diagnostics for Gaussian splatting: show where the points are at the start of training.
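+        # The cloud is registered once at viewer setup and is not updated as training moves the gaussians.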
+        # This is hidden by default; it can be shown from the Viser UI's scene tree table.
+        if isinstance(pipeline.model, GaussianSplattingModel):
+            self.viser_server.add_point_cloud(
+                "/gaussian_splatting_initial_points",
+                points=pipeline.model.means.numpy(force=True) * VISER_NERFSTUDIO_SCALE_RATIO,
+                colors=(255, 0, 0),
+                point_size=0.01,
+                point_shape="circle",
+                visible=False,  # Hidden by default.
+            )
         self.ready = True
 
     def toggle_pause_button(self) -> None:
diff --git a/tests/process_data/test_process_images.py b/tests/process_data/test_process_images.py
index eed454b11f..8482676c53 100644
--- a/tests/process_data/test_process_images.py
+++ b/tests/process_data/test_process_images.py
@@ -12,9 +12,11 @@
 from nerfstudio.data.utils.colmap_parsing_utils import (
     Camera,
     Image as ColmapImage,
+    Point3D,
     qvec2rotmat,
     write_cameras_binary,
     write_images_binary,
+    write_points3D_binary,
 )
 from nerfstudio.process_data.images_to_nerfstudio_dataset import ImagesToNerfstudioDataset
@@ -50,6 +52,19 @@ def test_process_images_skip_colmap(tmp_path: Path):
         {1: Camera(1, "OPENCV", width, height, [110, 110, 50, 75, 0, 0, 0, 0, 0, 0])},
         sparse_path / "cameras.bin",
     )
+    write_points3D_binary(
+        {
+            1: Point3D(
+                id=1,
+                xyz=np.array([0, 0, 0]),
+                rgb=np.array([0, 0, 0]),
+                error=np.array([0]),
+                image_ids=np.array([1]),
+                point2D_idxs=np.array([0]),
+            ),
+        },
+        sparse_path / "points3D.bin",
+    )
     frames = {}
     num_frames = 10
     qvecs = random_quaternion(num_frames)
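For reference, a minimal sanity check (not part of the patch) for the sparse_pc.ply that create_ply_from_colmap writes next to transforms.json. It assumes Open3D is installed and uses a hypothetical dataset root data/; _load_3D_points above reads the file the same way:

    import json
    from pathlib import Path

    import numpy as np
    import open3d as o3d

    data = Path("data")  # hypothetical dataset root produced by ns-process-data
    with open(data / "transforms.json") as f:
        meta = json.load(f)

    # transforms.json now records the .ply written by create_ply_from_colmap.
    pcd = o3d.io.read_point_cloud(str(data / meta["ply_file_path"]))
    xyz = np.asarray(pcd.points)  # (N, 3) positions; applied_transform is already baked in
    rgb = np.asarray(pcd.colors)  # (N, 3) floats in [0, 1]; open3d normalizes the 0-255 values

    if len(xyz) == 0:
        # Empty cloud: the case where _load_3D_points returns None and gaussian
        # splatting falls back to random point initialization.
        print("sparse_pc.ply contains no points")
    else:
        print(f"{len(xyz)} points, color range [{rgb.min():.2f}, {rgb.max():.2f}]")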