From 842de7bf66e485753f69abc665fb943f5d5f152b Mon Sep 17 00:00:00 2001 From: Fangchen Li Date: Sat, 4 Nov 2023 20:26:15 -0700 Subject: [PATCH] update --- src/_nebari/config.py | 18 +++++++++++++----- src/_nebari/initialize.py | 3 ++- src/_nebari/stages/infrastructure/__init__.py | 10 +++++----- src/_nebari/subcommands/init.py | 13 +++++++------ 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/_nebari/config.py b/src/_nebari/config.py index 80b7a64a18..ba48fcd7ff 100644 --- a/src/_nebari/config.py +++ b/src/_nebari/config.py @@ -2,19 +2,19 @@ import pathlib import re import sys -import typing +from typing import Any, Dict, List, Union import pydantic from _nebari.utils import yaml -def set_nested_attribute(data: typing.Any, attrs: typing.List[str], value: typing.Any): +def set_nested_attribute(data: Any, attrs: List[str], value: Any): """Takes an arbitrary set of attributes and accesses the deep nested object config to set value """ - def _get_attr(d: typing.Any, attr: str): + def _get_attr(d: Any, attr: str): if isinstance(d, list) and re.fullmatch(r"\d+", attr): return d[int(attr)] elif hasattr(d, "__getitem__"): @@ -22,7 +22,7 @@ def _get_attr(d: typing.Any, attr: str): else: return getattr(d, attr) - def _set_attr(d: typing.Any, attr: str, value: typing.Any): + def _set_attr(d: Any, attr: str, value: Any): if isinstance(d, list) and re.fullmatch(r"\d+", attr): d[int(attr)] = value elif hasattr(d, "__getitem__"): @@ -63,6 +63,13 @@ def set_config_from_environment_variables( return config +def dump_nested_model(model_dict: Dict[str, Union[pydantic.BaseModel, str]]): + result = {} + for key, value in model_dict.items(): + result[key] = value.model_dump() if isinstance(value, pydantic.BaseModel) else value + return result + + def read_configuration( config_filename: pathlib.Path, config_schema: pydantic.BaseModel, @@ -88,7 +95,7 @@ def read_configuration( def write_configuration( config_filename: pathlib.Path, - config: typing.Union[pydantic.BaseModel, typing.Dict], + config: Union[pydantic.BaseModel, Dict], mode: str = "w", ): """Write the nebari configuration file to disk""" @@ -96,6 +103,7 @@ def write_configuration( if isinstance(config, pydantic.BaseModel): yaml.dump(config.model_dump(), f) else: + config = dump_nested_model(config) yaml.dump(config, f) diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index 44974a9788..a24cd5ddcc 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -3,6 +3,7 @@ import re import tempfile from pathlib import Path +from typing import Any, Dict import pydantic import requests @@ -45,7 +46,7 @@ def render_config( region: str = None, disable_prompt: bool = False, ssl_cert_email: str = None, -): +) -> Dict[str, Any]: config = { "provider": cloud_provider.value, "namespace": namespace, diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index aebe84a42f..c35d8178df 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -503,8 +503,8 @@ def _check_input(cls, data: Any) -> Any: class LocalProvider(schema.Base): - kube_context: typing.Optional[str] = None - node_selectors: typing.Dict[str, KeyValueDict] = { + kube_context: Optional[str] = None + node_selectors: Dict[str, KeyValueDict] = { "general": KeyValueDict(key="kubernetes.io/os", value="linux"), "user": KeyValueDict(key="kubernetes.io/os", value="linux"), "worker": KeyValueDict(key="kubernetes.io/os", value="linux"), @@ -512,8 +512,8 @@ class LocalProvider(schema.Base): class ExistingProvider(schema.Base): - kube_context: typing.Optional[str] = None - node_selectors: typing.Dict[str, KeyValueDict] = { + kube_context: Optional[str] = None + node_selectors: Dict[str, KeyValueDict] = { "general": KeyValueDict(key="kubernetes.io/os", value="linux"), "user": KeyValueDict(key="kubernetes.io/os", value="linux"), "worker": KeyValueDict(key="kubernetes.io/os", value="linux"), @@ -694,7 +694,7 @@ def tf_objects(self) -> List[Dict]: def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): if self.config.provider == schema.ProviderEnum.local: - return LocalInputVars(kube_context=self.config.local.kube_context).dict() + return LocalInputVars(kube_context=self.config.local.kube_context).model_dump() elif self.config.provider == schema.ProviderEnum.existing: return ExistingInputVars( kube_context=self.config.existing.kube_context diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py index b4276438b3..e7c79aee88 100644 --- a/src/_nebari/subcommands/init.py +++ b/src/_nebari/subcommands/init.py @@ -3,6 +3,7 @@ import pathlib import re import typing +from typing import Optional import questionary import rich @@ -84,17 +85,17 @@ class GitRepoEnum(str, enum.Enum): class InitInputs(schema.Base): cloud_provider: ProviderEnum = ProviderEnum.local project_name: schema.project_name_pydantic = "" - domain_name: typing.Optional[str] = None - namespace: typing.Optional[schema.namespace_pydantic] = "dev" + domain_name: Optional[str] = None + namespace: Optional[schema.namespace_pydantic] = "dev" auth_provider: AuthenticationEnum = AuthenticationEnum.password auth_auto_provision: bool = False - repository: typing.Optional[schema.github_url_pydantic] = None + repository: Optional[schema.github_url_pydantic] = None repository_auto_provision: bool = False ci_provider: CiEnum = CiEnum.none terraform_state: TerraformStateEnum = TerraformStateEnum.remote - kubernetes_version: typing.Union[str, None] = None - region: typing.Union[str, None] = None - ssl_cert_email: typing.Union[schema.email_pydantic, None] = None + kubernetes_version: Optional[str] = None + region: Optional[str] = None + ssl_cert_email: Optional[schema.email_pydantic] = None disable_prompt: bool = False output: pathlib.Path = pathlib.Path("nebari-config.yaml")