From 5bc7f79f769df4714b9018cb1f7d442bc837987c Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 12 Jan 2024 23:20:49 -0800 Subject: [PATCH] Clean up `ns-train {method} --help` for not-yet-installed external methods --- nerfstudio/configs/external_methods.py | 54 ++++++++++++++------------ 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/nerfstudio/configs/external_methods.py b/nerfstudio/configs/external_methods.py index d530cf5bc6..3ab210a95d 100644 --- a/nerfstudio/configs/external_methods.py +++ b/nerfstudio/configs/external_methods.py @@ -14,15 +14,16 @@ """This file contains the configuration for external methods which are not included in this repository.""" +import inspect import subprocess import sys -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, cast - -from rich.prompt import Confirm +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, cast +import tyro from nerfstudio.engine.trainer import TrainerConfig from nerfstudio.utils.rich_utils import CONSOLE +from rich.prompt import Confirm @dataclass @@ -177,21 +178,30 @@ class ExternalMethod: @dataclass -class ExternalMethodTrainerConfig(TrainerConfig): - """ - Trainer config for external methods which does not have an implementation in this repository. +class ExternalMethodDummyTrainerConfig: + """Dummy trainer config for external methods (a) which do not have an + implementation in this repository, and (b) are not yet installed. When this + config is instantiated, we give the user the option to install the method. """ - _method: ExternalMethod = field(default=cast(ExternalMethod, None)) + # tyro.conf.Suppress will prevent these fields from appearing as CLI arguments. + method_name: tyro.conf.Suppress[str] + method: tyro.conf.Suppress[ExternalMethod] - def handle_print_information(self, *_args, **_kwargs): - """Prints the method information and exits.""" - CONSOLE.print(self._method.instructions) - if self._method.pip_package and Confirm.ask( + def __post_init__(self): + """Offer to install an external method.""" + + # Don't trigger install message from get_external_methods() below; only + # if this dummy object is instantiated from the CLI. + if inspect.stack()[2].function == "get_external_methods": + return + + CONSOLE.print(self.method.instructions) + if self.method.pip_package and Confirm.ask( "\nWould you like to run the install it now?", default=False, console=CONSOLE ): # Install the method - install_command = f"{sys.executable} -m pip install {self._method.pip_package}" + install_command = f"{sys.executable} -m pip install {self.method.pip_package}" CONSOLE.print(f"Running: [cyan]{install_command}[/cyan]") result = subprocess.run(install_command, shell=True, check=False) if result.returncode != 0: @@ -200,20 +210,16 @@ def handle_print_information(self, *_args, **_kwargs): sys.exit(0) - def __getattribute__(self, __name: str) -> Any: - out = object.__getattribute__(self, __name) - if callable(out) and __name not in {"handle_print_information"} and not __name.startswith("__"): - # We exit early, displaying the message - return self.handle_print_information - return out - def get_external_methods() -> Tuple[Dict[str, TrainerConfig], Dict[str, str]]: """Returns the external methods trainer configs and the descriptions.""" - method_configs = {} - descriptions = {} + method_configs: Dict[str, TrainerConfig] = {} + descriptions: Dict[str, str] = {} for external_method in external_methods: for config_slug, config_description in external_method.configurations: - method_configs[config_slug] = ExternalMethodTrainerConfig(method_name=config_slug, _method=external_method) - descriptions[config_slug] = f"""[External] {config_description}""" + method_configs[config_slug] = cast( # Need a cast because this is not a real TrainerConfig. + TrainerConfig, + ExternalMethodDummyTrainerConfig(method_name=config_slug, method=external_method), + ) + descriptions[config_slug] = f"""[External, run to install] {config_description}""" return method_configs, descriptions