Skip to content

Commit

Permalink
Clean up ns-train {method} --help for not-yet-installed external me…
Browse files Browse the repository at this point in the history
…thods
  • Loading branch information
brentyi committed Jan 13, 2024
1 parent 114c5f7 commit 5bc7f79
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions nerfstudio/configs/external_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@


"""This file contains the configuration for external methods which are not included in this repository."""
import inspect
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, cast

from rich.prompt import Confirm
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, cast

import tyro
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.utils.rich_utils import CONSOLE
from rich.prompt import Confirm


@dataclass
Expand Down Expand Up @@ -177,21 +178,30 @@ class ExternalMethod:


@dataclass
class ExternalMethodTrainerConfig(TrainerConfig):
"""
Trainer config for external methods which does not have an implementation in this repository.
class ExternalMethodDummyTrainerConfig:
"""Dummy trainer config for external methods (a) which do not have an
implementation in this repository, and (b) are not yet installed. When this
config is instantiated, we give the user the option to install the method.
"""

_method: ExternalMethod = field(default=cast(ExternalMethod, None))
# tyro.conf.Suppress will prevent these fields from appearing as CLI arguments.
method_name: tyro.conf.Suppress[str]
method: tyro.conf.Suppress[ExternalMethod]

def handle_print_information(self, *_args, **_kwargs):
"""Prints the method information and exits."""
CONSOLE.print(self._method.instructions)
if self._method.pip_package and Confirm.ask(
def __post_init__(self):
"""Offer to install an external method."""

# Don't trigger install message from get_external_methods() below; only
# if this dummy object is instantiated from the CLI.
if inspect.stack()[2].function == "get_external_methods":
return

CONSOLE.print(self.method.instructions)
if self.method.pip_package and Confirm.ask(
"\nWould you like to run the install it now?", default=False, console=CONSOLE
):
# Install the method
install_command = f"{sys.executable} -m pip install {self._method.pip_package}"
install_command = f"{sys.executable} -m pip install {self.method.pip_package}"
CONSOLE.print(f"Running: [cyan]{install_command}[/cyan]")
result = subprocess.run(install_command, shell=True, check=False)
if result.returncode != 0:
Expand All @@ -200,20 +210,16 @@ def handle_print_information(self, *_args, **_kwargs):

sys.exit(0)

def __getattribute__(self, __name: str) -> Any:
out = object.__getattribute__(self, __name)
if callable(out) and __name not in {"handle_print_information"} and not __name.startswith("__"):
# We exit early, displaying the message
return self.handle_print_information
return out


def get_external_methods() -> Tuple[Dict[str, TrainerConfig], Dict[str, str]]:
"""Returns the external methods trainer configs and the descriptions."""
method_configs = {}
descriptions = {}
method_configs: Dict[str, TrainerConfig] = {}
descriptions: Dict[str, str] = {}
for external_method in external_methods:
for config_slug, config_description in external_method.configurations:
method_configs[config_slug] = ExternalMethodTrainerConfig(method_name=config_slug, _method=external_method)
descriptions[config_slug] = f"""[External] {config_description}"""
method_configs[config_slug] = cast( # Need a cast because this is not a real TrainerConfig.
TrainerConfig,
ExternalMethodDummyTrainerConfig(method_name=config_slug, method=external_method),
)
descriptions[config_slug] = f"""[External, run to install] {config_description}"""
return method_configs, descriptions

0 comments on commit 5bc7f79

Please sign in to comment.