diff --git a/docs/nerfology/methods/splat.md b/docs/nerfology/methods/splat.md index 8e613e251b..30456d054f 100644 --- a/docs/nerfology/methods/splat.md +++ b/docs/nerfology/methods/splat.md @@ -27,11 +27,11 @@ Because gaussian splatting trains on *full images* instead of bundles of rays, t ### Running the Method -To run gaussian splatting, run `ns-train gaussian-splatting --data `. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported. +To run gaussian splatting, run `ns-train splatfacto --data `. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported. #### Quality and Regularization The default settings provided maintain a balance between speed, quality, and splat file size, but if you care more about quality than training speed or size, you can decrease the alpha cull threshold -(threshold to delete translucent gaussians) and disable culling after 15k steps like so: `ns-train gaussian-splatting --pipeline.model.cull_scale_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data ` +(threshold to delete translucent gaussians) and disable culling after 15k steps like so: `ns-train splatfacto --pipeline.model.cull_scale_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data ` A common artifact in splatting is long, spikey gaussians. [PhysGaussian](https://xpandora.github.io/PhysGaussian/) proposes an example of a scale-regularizer that encourages gaussians to be more evenly shaped. To enable this, set the `use_scale_regularization` flag to `True`. diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index f60450163b..dd842e8ab6 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -51,7 +51,6 @@ from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind from nerfstudio.fields.sdf_field import SDFFieldConfig from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig -from nerfstudio.models.gaussian_splatting import GaussianSplattingModelConfig from nerfstudio.models.generfacto import GenerfactoModelConfig from nerfstudio.models.instant_ngp import InstantNGPModelConfig from nerfstudio.models.mipnerf import MipNerfModel @@ -59,6 +58,7 @@ from nerfstudio.models.neus import NeuSModelConfig from nerfstudio.models.neus_facto import NeuSFactoModelConfig from nerfstudio.models.semantic_nerfw import SemanticNerfWModelConfig +from nerfstudio.models.splatfacto import SplatfactoModelConfig from nerfstudio.models.tensorf import TensoRFModelConfig from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig @@ -80,7 +80,7 @@ "generfacto": "Generative Text to NeRF model", "neus": "Implementation of NeuS. (slow)", "neus-facto": "Implementation of NeuS-Facto. (slow)", - "gaussian-splatting": "Gaussian Splatting model", + "splatfacto": "Gaussian Splatting model", } method_configs["nerfacto"] = TrainerConfig( @@ -588,8 +588,8 @@ vis="viewer", ) -method_configs["gaussian-splatting"] = TrainerConfig( - method_name="gaussian-splatting", +method_configs["splatfacto"] = TrainerConfig( + method_name="splatfacto", steps_per_eval_image=100, steps_per_eval_batch=0, steps_per_save=2000, @@ -601,7 +601,7 @@ datamanager=FullImageDatamanagerConfig( dataparser=NerfstudioDataParserConfig(load_3D_points=True), ), - model=GaussianSplattingModelConfig(), + model=SplatfactoModelConfig(), ), optimizers={ "xyz": { diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index dcc4ace781..e8636a77c4 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -386,7 +386,7 @@ def _generate_dataparser_outputs(self, split="train"): else: if not self.prompted_user: CONSOLE.print( - "[bold yellow]Warning: load_3D_points set to true but no point cloud found. gaussian-splatting models will use random point cloud initialization." + "[bold yellow]Warning: load_3D_points set to true but no point cloud found. splatfacto models will use random point cloud initialization." ) ply_file_path = None diff --git a/nerfstudio/models/gaussian_splatting.py b/nerfstudio/models/splatfacto.py similarity index 99% rename from nerfstudio/models/gaussian_splatting.py rename to nerfstudio/models/splatfacto.py index 2365e43f7c..6944d377c3 100644 --- a/nerfstudio/models/gaussian_splatting.py +++ b/nerfstudio/models/splatfacto.py @@ -99,10 +99,10 @@ def projection_matrix(znear, zfar, fovx, fovy, device: Union[str, torch.device] @dataclass -class GaussianSplattingModelConfig(ModelConfig): +class SplatfactoModelConfig(ModelConfig): """Gaussian Splatting Model Config""" - _target: Type = field(default_factory=lambda: GaussianSplattingModel) + _target: Type = field(default_factory=lambda: SplatfactoModel) warmup_length: int = 500 """period of steps where refinement is turned off""" refine_every: int = 100 @@ -149,14 +149,14 @@ class GaussianSplattingModelConfig(ModelConfig): """ -class GaussianSplattingModel(Model): +class SplatfactoModel(Model): """Gaussian Splatting model Args: config: Gaussian Splatting configuration to instantiate model """ - config: GaussianSplattingModelConfig + config: SplatfactoModelConfig def __init__( self, diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index 45be2e84cc..e526268890 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -40,7 +40,7 @@ from nerfstudio.exporter.exporter_utils import collect_camera_poses, generate_point_cloud, get_mesh_from_filename from nerfstudio.exporter.marching_cubes import generate_mesh_with_multires_marching_cubes from nerfstudio.fields.sdf_field import SDFField # noqa -from nerfstudio.models.gaussian_splatting import GaussianSplattingModel +from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline from nerfstudio.utils.eval_utils import eval_setup from nerfstudio.utils.rich_utils import CONSOLE @@ -488,9 +488,9 @@ def main(self) -> None: _, pipeline, _, _ = eval_setup(self.load_config) - assert isinstance(pipeline.model, GaussianSplattingModel) + assert isinstance(pipeline.model, SplatfactoModel) - model: GaussianSplattingModel = pipeline.model + model: SplatfactoModel = pipeline.model filename = self.output_dir / "splat.ply" diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 5d93a20dd8..76808a2725 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -22,7 +22,7 @@ from nerfstudio.data.scene_box import OrientedBox from nerfstudio.models.base_model import Model -from nerfstudio.models.gaussian_splatting import GaussianSplattingModel +from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.viewer.control_panel import ControlPanel @@ -32,7 +32,7 @@ def populate_export_tab( config_path: Path, viewer_model: Model, ) -> None: - viewing_gsplat = isinstance(viewer_model, GaussianSplattingModel) + viewing_gsplat = isinstance(viewer_model, SplatfactoModel) if not viewing_gsplat: crop_output = server.add_gui_checkbox("Use Crop", False) diff --git a/nerfstudio/viewer/render_state_machine.py b/nerfstudio/viewer/render_state_machine.py index 74fde38336..d116a190e4 100644 --- a/nerfstudio/viewer/render_state_machine.py +++ b/nerfstudio/viewer/render_state_machine.py @@ -28,7 +28,7 @@ from nerfstudio.cameras.cameras import Cameras from nerfstudio.model_components.renderers import background_color_override_context -from nerfstudio.models.gaussian_splatting import GaussianSplattingModel +from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.utils import colormaps, writer from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName, TimeWriter from nerfstudio.viewer.utils import CameraState, get_camera @@ -136,7 +136,7 @@ def _render_img(self, camera_state: CameraState): with TimeWriter(None, None, write=False) as vis_t: with self.viewer.train_lock if self.viewer.train_lock is not None else contextlib.nullcontext(): - if isinstance(self.viewer.get_model(), GaussianSplattingModel): + if isinstance(self.viewer.get_model(), SplatfactoModel): color = self.viewer.control_panel.background_color background_color = torch.tensor( [color[0] / 255.0, color[1] / 255.0, color[2] / 255.0], @@ -168,7 +168,7 @@ def _render_img(self, camera_state: CameraState): self.viewer.get_model().train() num_rays = (camera.height * camera.width).item() if self.viewer.control_panel.layer_depth: - if isinstance(self.viewer.get_model(), GaussianSplattingModel): + if isinstance(self.viewer.get_model(), SplatfactoModel): # Gaussians render much faster than we can send depth images, so we do some downsampling. assert len(outputs["depth"].shape) == 3 assert outputs["depth"].shape[-1] == 1 diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index 579dbb5cc4..4480bf214e 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -32,7 +32,7 @@ from nerfstudio.configs import base_config as cfg from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.models.base_model import Model -from nerfstudio.models.gaussian_splatting import GaussianSplattingModel +from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.pipelines.base_pipeline import Pipeline from nerfstudio.utils.decorators import check_main_thread, decorate_all from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName @@ -250,7 +250,7 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem # Diagnostics for Gaussian Splatting: where the points are at the start of training. # This is hidden by default, it can be shown from the Viser UI's scene tree table. - if isinstance(pipeline.model, GaussianSplattingModel): + if isinstance(pipeline.model, SplatfactoModel): self.viser_server.add_point_cloud( "/gaussian_splatting_initial_points", points=pipeline.model.means.numpy(force=True) * VISER_NERFSTUDIO_SCALE_RATIO, diff --git a/nerfstudio/viewer_legacy/server/render_state_machine.py b/nerfstudio/viewer_legacy/server/render_state_machine.py index d842598d1a..a3c0524906 100644 --- a/nerfstudio/viewer_legacy/server/render_state_machine.py +++ b/nerfstudio/viewer_legacy/server/render_state_machine.py @@ -24,7 +24,7 @@ from nerfstudio.cameras.cameras import Cameras from nerfstudio.model_components.renderers import background_color_override_context -from nerfstudio.models.gaussian_splatting import GaussianSplattingModel +from nerfstudio.models.splatfacto import SplatfactoModel from nerfstudio.utils import colormaps, writer from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName, TimeWriter from nerfstudio.viewer_legacy.server import viewer_utils @@ -130,7 +130,7 @@ def _render_img(self, cam_msg: CameraMessage): with self.viewer.train_lock if self.viewer.train_lock is not None else contextlib.nullcontext(): # TODO jake-austin: Make this check whether the model inherits from a camera based model or a ray based model # TODO Zhuoyang: First made some dummy judgements, need to be fixed later - isGaussianSplattingModel = isinstance(self.viewer.get_model(), GaussianSplattingModel) + isGaussianSplattingModel = isinstance(self.viewer.get_model(), SplatfactoModel) if isGaussianSplattingModel: # TODO fix me before ship camera_ray_bundle = camera.generate_rays(camera_indices=0, aabb_box=self.viewer.get_model().render_aabb) diff --git a/tests/test_train.py b/tests/test_train.py index 45d8e348f0..82e198373b 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -26,7 +26,7 @@ "neus", "generfacto", "neus-facto", - "gaussian-splatting", + "splatfacto", ]