diff --git a/pyproject.toml b/pyproject.toml index 9c21badb43..7bfa0a59c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ dependencies = [ "kubernetes==27.2.0", "pluggy==1.3.0", "prompt-toolkit==3.0.36", - "pydantic==1.10.12", + "pydantic==2.4.2", "pynacl==1.5.0", "python-keycloak>=3.9.0", "questionary==2.0.0", diff --git a/src/_nebari/config.py b/src/_nebari/config.py index d1b2f42944..7c27274f36 100644 --- a/src/_nebari/config.py +++ b/src/_nebari/config.py @@ -2,19 +2,19 @@ import pathlib import re import sys -import typing +from typing import Any, Dict, List, Union import pydantic from _nebari.utils import yaml -def set_nested_attribute(data: typing.Any, attrs: typing.List[str], value: typing.Any): +def set_nested_attribute(data: Any, attrs: List[str], value: Any): """Takes an arbitrary set of attributes and accesses the deep nested object config to set value """ - def _get_attr(d: typing.Any, attr: str): + def _get_attr(d: Any, attr: str): if isinstance(d, list) and re.fullmatch(r"\d+", attr): return d[int(attr)] elif hasattr(d, "__getitem__"): @@ -22,7 +22,7 @@ def _get_attr(d: typing.Any, attr: str): else: return getattr(d, attr) - def _set_attr(d: typing.Any, attr: str, value: typing.Any): + def _set_attr(d: Any, attr: str, value: Any): if isinstance(d, list) and re.fullmatch(r"\d+", attr): d[int(attr)] = value elif hasattr(d, "__getitem__"): @@ -63,6 +63,15 @@ def set_config_from_environment_variables( return config +def dump_nested_model(model_dict: Dict[str, Union[pydantic.BaseModel, str]]): + result = {} + for key, value in model_dict.items(): + result[key] = ( + value.model_dump() if isinstance(value, pydantic.BaseModel) else value + ) + return result + + def read_configuration( config_filename: pathlib.Path, config_schema: pydantic.BaseModel, @@ -77,7 +86,8 @@ def read_configuration( ) with filename.open() as f: - config = config_schema(**yaml.load(f.read())) + config_dict = yaml.load(f) + config = config_schema(**config_dict) if read_environment: config = set_config_from_environment_variables(config) @@ -87,14 +97,15 @@ def read_configuration( def write_configuration( config_filename: pathlib.Path, - config: typing.Union[pydantic.BaseModel, typing.Dict], + config: Union[pydantic.BaseModel, Dict], mode: str = "w", ): """Write the nebari configuration file to disk""" with config_filename.open(mode) as f: if isinstance(config, pydantic.BaseModel): - yaml.dump(config.dict(), f) + yaml.dump(config.model_dump(), f) else: + config = dump_nested_model(config) yaml.dump(config, f) diff --git a/src/_nebari/initialize.py b/src/_nebari/initialize.py index e7af82025b..df693ca8f0 100644 --- a/src/_nebari/initialize.py +++ b/src/_nebari/initialize.py @@ -3,6 +3,7 @@ import re import tempfile from pathlib import Path +from typing import Any, Dict import pydantic import requests @@ -52,7 +53,7 @@ def render_config( region: str = None, disable_prompt: bool = False, ssl_cert_email: str = None, -): +) -> Dict[str, Any]: config = { "provider": cloud_provider, "namespace": namespace, @@ -119,7 +120,7 @@ def render_config( if cloud_provider == ProviderEnum.do: do_region = region or constants.DO_DEFAULT_REGION do_kubernetes_versions = kubernetes_version or get_latest_kubernetes_version( - digital_ocean.kubernetes_versions(do_region) + digital_ocean.kubernetes_versions() ) config["digital_ocean"] = { "kubernetes_version": do_kubernetes_versions, @@ -200,7 +201,7 @@ def render_config( from nebari.plugins import nebari_plugin_manager try: - config_model = 
nebari_plugin_manager.config_schema.parse_obj(config) + config_model = nebari_plugin_manager.config_schema.model_validate(config) except pydantic.ValidationError as e: print(str(e)) diff --git a/src/_nebari/provider/cicd/github.py b/src/_nebari/provider/cicd/github.py index 7b58464c43..2563af6ad9 100644 --- a/src/_nebari/provider/cicd/github.py +++ b/src/_nebari/provider/cicd/github.py @@ -4,7 +4,7 @@ import requests from nacl import encoding, public -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, RootModel from _nebari.constants import LATEST_SUPPORTED_PYTHON_VERSION from _nebari.provider.cicd.common import pip_install_nebari @@ -143,49 +143,34 @@ class GHA_on_extras(BaseModel): paths: List[str] -class GHA_on(BaseModel): - # to allow for dynamic key names - __root__: Dict[str, GHA_on_extras] - - # TODO: validate __root__ values - # `push`, `pull_request`, etc. - - -class GHA_job_steps_extras(BaseModel): - # to allow for dynamic key names - __root__: Union[str, float, int] +GHA_on = RootModel[Dict[str, GHA_on_extras]] +GHA_job_steps_extras = RootModel[Union[str, float, int]] class GHA_job_step(BaseModel): name: str - uses: Optional[str] - with_: Optional[Dict[str, GHA_job_steps_extras]] = Field(alias="with") - run: Optional[str] - env: Optional[Dict[str, GHA_job_steps_extras]] - - class Config: - allow_population_by_field_name = True + uses: Optional[str] = None + with_: Optional[Dict[str, GHA_job_steps_extras]] = Field(alias="with", default=None) + run: Optional[str] = None + env: Optional[Dict[str, GHA_job_steps_extras]] = None + model_config = ConfigDict(populate_by_name=True) class GHA_job_id(BaseModel): name: str runs_on_: str = Field(alias="runs-on") - permissions: Optional[Dict[str, str]] + permissions: Optional[Dict[str, str]] = None steps: List[GHA_job_step] - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) -class GHA_jobs(BaseModel): - # to allow for dynamic key names - __root__: Dict[str, GHA_job_id] +GHA_jobs = RootModel[Dict[str, GHA_job_id]] class GHA(BaseModel): name: str on: GHA_on - env: Optional[Dict[str, str]] + env: Optional[Dict[str, str]] = None jobs: GHA_jobs @@ -204,11 +189,7 @@ def checkout_image_step(): return GHA_job_step( name="Checkout Image", uses="actions/checkout@v3", - with_={ - "token": GHA_job_steps_extras( - __root__="${{ secrets.REPOSITORY_ACCESS_TOKEN }}" - ) - }, + with_={"token": GHA_job_steps_extras("${{ secrets.REPOSITORY_ACCESS_TOKEN }}")}, ) @@ -216,11 +197,7 @@ def setup_python_step(): return GHA_job_step( name="Set up Python", uses="actions/setup-python@v4", - with_={ - "python-version": GHA_job_steps_extras( - __root__=LATEST_SUPPORTED_PYTHON_VERSION - ) - }, + with_={"python-version": GHA_job_steps_extras(LATEST_SUPPORTED_PYTHON_VERSION)}, ) @@ -242,7 +219,7 @@ def gen_nebari_ops(config): env_vars = gha_env_vars(config) push = GHA_on_extras(branches=[config.ci_cd.branch], paths=["nebari-config.yaml"]) - on = GHA_on(__root__={"push": push}) + on = GHA_on({"push": push}) step1 = checkout_image_step() step2 = setup_python_step() @@ -272,7 +249,7 @@ def gen_nebari_ops(config): ), env={ "COMMIT_MSG": GHA_job_steps_extras( - __root__="nebari-config.yaml automated commit: ${{ github.sha }}" + "nebari-config.yaml automated commit: ${{ github.sha }}" ) }, ) @@ -291,7 +268,7 @@ def gen_nebari_ops(config): }, steps=gha_steps, ) - jobs = GHA_jobs(__root__={"build": job1}) + jobs = GHA_jobs({"build": job1}) return NebariOps( name="nebari auto update", @@ 
-312,18 +289,16 @@ def gen_nebari_linter(config): pull_request = GHA_on_extras( branches=[config.ci_cd.branch], paths=["nebari-config.yaml"] ) - on = GHA_on(__root__={"pull_request": pull_request}) + on = GHA_on({"pull_request": pull_request}) step1 = checkout_image_step() step2 = setup_python_step() step3 = install_nebari_step(config.nebari_version) step4_envs = { - "PR_NUMBER": GHA_job_steps_extras(__root__="${{ github.event.number }}"), - "REPO_NAME": GHA_job_steps_extras(__root__="${{ github.repository }}"), - "GITHUB_TOKEN": GHA_job_steps_extras( - __root__="${{ secrets.REPOSITORY_ACCESS_TOKEN }}" - ), + "PR_NUMBER": GHA_job_steps_extras("${{ github.event.number }}"), + "REPO_NAME": GHA_job_steps_extras("${{ github.repository }}"), + "GITHUB_TOKEN": GHA_job_steps_extras("${{ secrets.REPOSITORY_ACCESS_TOKEN }}"), } step4 = GHA_job_step( @@ -336,7 +311,7 @@ def gen_nebari_linter(config): name="nebari", runs_on_="ubuntu-latest", steps=[step1, step2, step3, step4] ) jobs = GHA_jobs( - __root__={ + { "nebari-validate": job1, } ) diff --git a/src/_nebari/provider/cicd/gitlab.py b/src/_nebari/provider/cicd/gitlab.py index e2d02b388b..d5e944f36d 100644 --- a/src/_nebari/provider/cicd/gitlab.py +++ b/src/_nebari/provider/cicd/gitlab.py @@ -1,40 +1,34 @@ from typing import Dict, List, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, RootModel from _nebari.constants import LATEST_SUPPORTED_PYTHON_VERSION from _nebari.provider.cicd.common import pip_install_nebari - -class GLCI_extras(BaseModel): - # to allow for dynamic key names - __root__: Union[str, float, int] +GLCI_extras = RootModel[Union[str, float, int]] class GLCI_image(BaseModel): name: str - entrypoint: Optional[str] + entrypoint: Optional[str] = None class GLCI_rules(BaseModel): if_: Optional[str] = Field(alias="if") - changes: Optional[List[str]] - - class Config: - allow_population_by_field_name = True + changes: Optional[List[str]] = None + model_config = ConfigDict(populate_by_name=True) class GLCI_job(BaseModel): - image: Optional[Union[str, GLCI_image]] - variables: Optional[Dict[str, str]] - before_script: Optional[List[str]] - after_script: Optional[List[str]] + image: Optional[Union[str, GLCI_image]] = None + variables: Optional[Dict[str, str]] = None + before_script: Optional[List[str]] = None + after_script: Optional[List[str]] = None script: List[str] - rules: Optional[List[GLCI_rules]] + rules: Optional[List[GLCI_rules]] = None -class GLCI(BaseModel): - __root__: Dict[str, GLCI_job] +GLCI = RootModel[Dict[str, GLCI_job]] def gen_gitlab_ci(config): @@ -76,7 +70,7 @@ def gen_gitlab_ci(config): ) return GLCI( - __root__={ + { "render-nebari": render_nebari, } ) diff --git a/src/_nebari/provider/cloud/amazon_web_services.py b/src/_nebari/provider/cloud/amazon_web_services.py index 2bf905bfcb..1123c07fe0 100644 --- a/src/_nebari/provider/cloud/amazon_web_services.py +++ b/src/_nebari/provider/cloud/amazon_web_services.py @@ -7,25 +7,18 @@ import boto3 from botocore.exceptions import ClientError, EndpointConnectionError -from _nebari import constants +from _nebari.constants import AWS_ENV_DOCS from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version +from _nebari.utils import check_environment_variables from nebari import schema MAX_RETRIES = 5 DELAY = 5 -def check_credentials(): - """Check for AWS credentials are set in the environment.""" - for variable in { - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - }: - if variable not in 
os.environ: - raise ValueError( - f"""Missing the following required environment variable: {variable}\n - Please see the documentation for more information: {constants.AWS_ENV_DOCS}""" - ) +def check_credentials() -> None: + required_variables = {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"} + check_environment_variables(required_variables, AWS_ENV_DOCS) @functools.lru_cache() diff --git a/src/_nebari/provider/cloud/azure_cloud.py b/src/_nebari/provider/cloud/azure_cloud.py index 992e5c1362..44ebdaaee6 100644 --- a/src/_nebari/provider/cloud/azure_cloud.py +++ b/src/_nebari/provider/cloud/azure_cloud.py @@ -9,10 +9,11 @@ from azure.mgmt.containerservice import ContainerServiceClient from azure.mgmt.resource import ResourceManagementClient -from _nebari import constants +from _nebari.constants import AZURE_ENV_DOCS from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version from _nebari.utils import ( AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX, + check_environment_variables, construct_azure_resource_group_name, ) from nebari import schema @@ -24,29 +25,18 @@ RETRIES = 10 -def check_credentials(): - """Check if credentials are valid.""" - - required_variables = { - "ARM_CLIENT_ID": os.environ.get("ARM_CLIENT_ID", None), - "ARM_SUBSCRIPTION_ID": os.environ.get("ARM_SUBSCRIPTION_ID", None), - "ARM_TENANT_ID": os.environ.get("ARM_TENANT_ID", None), - } - arm_client_secret = os.environ.get("ARM_CLIENT_SECRET", None) - - if not all(required_variables.values()): - raise ValueError( - f"""Missing the following required environment variables: {required_variables}\n - Please see the documentation for more information: {constants.AZURE_ENV_DOCS}""" - ) +def check_credentials() -> DefaultAzureCredential: + required_variables = {"ARM_CLIENT_ID", "ARM_SUBSCRIPTION_ID", "ARM_TENANT_ID"} + check_environment_variables(required_variables, AZURE_ENV_DOCS) + optional_variable = "ARM_CLIENT_SECRET" + arm_client_secret = os.environ.get(optional_variable, None) if arm_client_secret: logger.info("Authenticating as a service principal.") - return DefaultAzureCredential() else: - logger.info("No ARM_CLIENT_SECRET environment variable found.") + logger.info(f"No {optional_variable} environment variable found.") logger.info("Allowing Azure SDK to authenticate using OIDC or other methods.") - return DefaultAzureCredential() + return DefaultAzureCredential() @functools.lru_cache() diff --git a/src/_nebari/provider/cloud/digital_ocean.py b/src/_nebari/provider/cloud/digital_ocean.py index d64ca4c6de..3e4a507be6 100644 --- a/src/_nebari/provider/cloud/digital_ocean.py +++ b/src/_nebari/provider/cloud/digital_ocean.py @@ -7,24 +7,20 @@ import kubernetes.config import requests -from _nebari import constants +from _nebari.constants import DO_ENV_DOCS from _nebari.provider.cloud.amazon_web_services import aws_delete_s3_bucket from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version -from _nebari.utils import set_do_environment +from _nebari.utils import check_environment_variables, set_do_environment from nebari import schema -def check_credentials(): - for variable in { +def check_credentials() -> None: + required_variables = { + "DIGITALOCEAN_TOKEN", "SPACES_ACCESS_KEY_ID", "SPACES_SECRET_ACCESS_KEY", - "DIGITALOCEAN_TOKEN", - }: - if variable not in os.environ: - raise ValueError( - f"""Missing the following required environment variable: {variable}\n - Please see the documentation for more information: {constants.DO_ENV_DOCS}""" - ) + } + 
check_environment_variables(required_variables, DO_ENV_DOCS) def digital_ocean_request(url, method="GET", json=None): @@ -63,7 +59,7 @@ def regions(): return _kubernetes_options()["options"]["regions"] -def kubernetes_versions(region) -> typing.List[str]: +def kubernetes_versions() -> typing.List[str]: """Return list of available kubernetes supported by cloud provider. Sorted from oldest to latest.""" supported_kubernetes_versions = sorted( [_["slug"].split("-")[0] for _ in _kubernetes_options()["options"]["versions"]] diff --git a/src/_nebari/provider/cloud/google_cloud.py b/src/_nebari/provider/cloud/google_cloud.py index ba95f713cf..67d0ebad7a 100644 --- a/src/_nebari/provider/cloud/google_cloud.py +++ b/src/_nebari/provider/cloud/google_cloud.py @@ -1,21 +1,17 @@ import functools import json -import os import subprocess from typing import Dict, List, Set -from _nebari import constants +from _nebari.constants import GCP_ENV_DOCS from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version +from _nebari.utils import check_environment_variables from nebari import schema -def check_credentials(): - for variable in {"GOOGLE_CREDENTIALS", "PROJECT_ID"}: - if variable not in os.environ: - raise ValueError( - f"""Missing the following required environment variable: {variable}\n - Please see the documentation for more information: {constants.GCP_ENV_DOCS}""" - ) +def check_credentials() -> None: + required_variables = {"GOOGLE_CREDENTIALS", "PROJECT_ID"} + check_environment_variables(required_variables, GCP_ENV_DOCS) @functools.lru_cache() @@ -282,7 +278,7 @@ def check_missing_service() -> None: if missing: raise ValueError( f"""Missing required services: {missing}\n - Please see the documentation for more information: {constants.GCP_ENV_DOCS}""" + Please see the documentation for more information: {GCP_ENV_DOCS}""" ) diff --git a/src/_nebari/stages/bootstrap/__init__.py b/src/_nebari/stages/bootstrap/__init__.py index 688146999b..97e754d9c8 100644 --- a/src/_nebari/stages/bootstrap/__init__.py +++ b/src/_nebari/stages/bootstrap/__init__.py @@ -96,7 +96,7 @@ def render(self) -> Dict[str, str]: for fn, workflow in gen_cicd(self.config).items(): stream = io.StringIO() schema.yaml.dump( - workflow.dict( + workflow.model_dump( by_alias=True, exclude_unset=True, exclude_defaults=True ), stream, diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 965795438a..e7bc56babd 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -5,9 +5,9 @@ import re import sys import tempfile -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Annotated, Any, Dict, List, Optional, Tuple, Type, Union -import pydantic +from pydantic import Field, field_validator, model_validator from _nebari import constants from _nebari.provider import terraform @@ -34,7 +34,7 @@ def get_kubeconfig_filename(): class LocalInputVars(schema.Base): kubeconfig_filename: str = get_kubeconfig_filename() - kube_context: Optional[str] + kube_context: Optional[str] = None class ExistingInputVars(schema.Base): @@ -180,11 +180,11 @@ def _calculate_node_groups(config: schema.Main): elif config.provider == schema.ProviderEnum.existing: return config.existing.node_selectors else: - return config.local.dict()["node_selectors"] + return config.local.model_dump()["node_selectors"] def node_groups_to_dict(node_groups): - return {ng_name: ng.dict() for ng_name, ng in 
node_groups.items()} + return {ng_name: ng.model_dump() for ng_name, ng in node_groups.items()} @contextlib.contextmanager @@ -223,8 +223,8 @@ class DigitalOceanNodeGroup(schema.Base): """ instance: str - min_nodes: pydantic.conint(ge=1) = 1 - max_nodes: pydantic.conint(ge=1) = 1 + min_nodes: Annotated[int, Field(ge=1)] = 1 + max_nodes: Annotated[int, Field(ge=1)] = 1 DEFAULT_DO_NODE_GROUPS = { @@ -236,56 +236,44 @@ class DigitalOceanNodeGroup(schema.Base): class DigitalOceanProvider(schema.Base): region: str - kubernetes_version: str + kubernetes_version: Optional[str] = None # Digital Ocean image slugs are listed here https://slugs.do-api.dev/ node_groups: Dict[str, DigitalOceanNodeGroup] = DEFAULT_DO_NODE_GROUPS tags: Optional[List[str]] = [] - @pydantic.validator("region") - def _validate_region(cls, value): + @model_validator(mode="before") + @classmethod + def _check_input(cls, data: Any) -> Any: digital_ocean.check_credentials() + # check if region is valid available_regions = set(_["slug"] for _ in digital_ocean.regions()) - if value not in available_regions: + if data["region"] not in available_regions: raise ValueError( - f"Digital Ocean region={value} is not one of {available_regions}" + f"Digital Ocean region={data['region']} is not one of {available_regions}" ) - return value - - @pydantic.validator("node_groups") - def _validate_node_group(cls, value): - digital_ocean.check_credentials() - available_instances = {_["slug"] for _ in digital_ocean.instances()} - for name, node_group in value.items(): - if node_group.instance not in available_instances: - raise ValueError( - f"Digital Ocean instance {node_group.instance} not one of available instance types={available_instances}" - ) - - return value - - @pydantic.root_validator - def _validate_kubernetes_version(cls, values): - digital_ocean.check_credentials() - - if "region" not in values: - raise ValueError("Region required in order to set kubernetes_version") - - available_kubernetes_versions = digital_ocean.kubernetes_versions( - values["region"] - ) - assert available_kubernetes_versions - if ( - values["kubernetes_version"] is not None - and values["kubernetes_version"] not in available_kubernetes_versions - ): + # check if kubernetes version is valid + available_kubernetes_versions = digital_ocean.kubernetes_versions() + if len(available_kubernetes_versions) == 0: raise ValueError( - f"\nInvalid `kubernetes-version` provided: {values['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." + "Request to Digital Ocean for available Kubernetes versions failed." + ) + if data["kubernetes_version"] is None: + data["kubernetes_version"] = available_kubernetes_versions[-1] + elif data["kubernetes_version"] not in available_kubernetes_versions: + raise ValueError( + f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." 
) - else: - values["kubernetes_version"] = available_kubernetes_versions[-1] - return values + + available_instances = {_["slug"] for _ in digital_ocean.instances()} + if "node_groups" in data: + for _, node_group in data["node_groups"].items(): + if node_group["instance"] not in available_instances: + raise ValueError( + f"Digital Ocean instance {node_group['instance']} not one of available instance types={available_instances}" + ) + return data class GCPIPAllocationPolicy(schema.Base): @@ -318,13 +306,13 @@ class GCPGuestAccelerator(schema.Base): """ name: str - count: pydantic.conint(ge=1) = 1 + count: Annotated[int, Field(ge=1)] = 1 class GCPNodeGroup(schema.Base): instance: str - min_nodes: pydantic.conint(ge=0) = 0 - max_nodes: pydantic.conint(ge=1) = 1 + min_nodes: Annotated[int, Field(ge=0)] = 0 + max_nodes: Annotated[int, Field(ge=1)] = 1 preemptible: bool = False labels: Dict[str, str] = {} guest_accelerators: List[GCPGuestAccelerator] = [] @@ -352,31 +340,24 @@ class GoogleCloudPlatformProvider(schema.Base): master_authorized_networks_config: Optional[Union[GCPCIDRBlock, None]] = None private_cluster_config: Optional[Union[GCPPrivateClusterConfig, None]] = None - @pydantic.root_validator - def validate_all(cls, values): - region = values.get("region") - project_id = values.get("project") - - if project_id is None: - raise ValueError("The `google_cloud_platform.project` field is required.") - - if region is None: - raise ValueError("The `google_cloud_platform.region` field is required.") - - # validate region - google_cloud.validate_region(region) - - # validate kubernetes version - kubernetes_version = values.get("kubernetes_version") - available_kubernetes_versions = google_cloud.kubernetes_versions(region) - if kubernetes_version is None: - values["kubernetes_version"] = available_kubernetes_versions[-1] - elif kubernetes_version not in available_kubernetes_versions: + @model_validator(mode="before") + @classmethod + def _check_input(cls, data: Any) -> Any: + google_cloud.check_credentials() + available_regions = google_cloud.regions() + if data["region"] not in available_regions: raise ValueError( - f"\nInvalid `kubernetes-version` provided: {values['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." + f"Google Cloud region={data['region']} is not one of {available_regions}" ) - return values + available_kubernetes_versions = google_cloud.kubernetes_versions(data["region"]) + if data.get("kubernetes_version") is None: + data["kubernetes_version"] = available_kubernetes_versions[-1] + elif data["kubernetes_version"] not in available_kubernetes_versions: + raise ValueError( + f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available."
+ ) + return data class AzureNodeGroup(schema.Base): @@ -394,20 +374,27 @@ class AzureNodeGroup(schema.Base): class AzureProvider(schema.Base): region: str - kubernetes_version: str + kubernetes_version: Optional[str] = None storage_account_postfix: str - resource_group_name: str = None + resource_group_name: Optional[str] = None node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS storage_account_postfix: str - vnet_subnet_id: Optional[Union[str, None]] = None + vnet_subnet_id: Optional[str] = None private_cluster_enabled: bool = False resource_group_name: Optional[str] = None tags: Optional[Dict[str, str]] = {} network_profile: Optional[Dict[str, str]] = None max_pods: Optional[int] = None - @pydantic.validator("kubernetes_version") - def _validate_kubernetes_version(cls, value): + @model_validator(mode="before") + @classmethod + def _check_credentials(cls, data: Any) -> Any: + azure_cloud.check_credentials() + return data + + @field_validator("kubernetes_version") + @classmethod + def _validate_kubernetes_version(cls, value: Optional[str]) -> str: available_kubernetes_versions = azure_cloud.kubernetes_versions() if value is None: value = available_kubernetes_versions[-1] @@ -417,7 +404,8 @@ def _validate_kubernetes_version(cls, value): ) return value - @pydantic.validator("resource_group_name") + @field_validator("resource_group_name") + @classmethod def _validate_resource_group_name(cls, value): if value is None: return value @@ -435,9 +423,10 @@ def _validate_resource_group_name(cls, value): return value - @pydantic.validator("tags") - def _validate_tags(cls, tags): - return azure_cloud.validate_tags(tags) + @field_validator("tags") + @classmethod + def _validate_tags(cls, value: Optional[Dict[str, str]]) -> Dict[str, str]: + return value if value is None else azure_cloud.validate_tags(value) class AWSNodeGroup(schema.Base): @@ -465,49 +454,66 @@ class AmazonWebServicesProvider(schema.Base): kubernetes_version: str availability_zones: Optional[List[str]] node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS - existing_subnet_ids: List[str] = None - existing_security_group_id: str = None + existing_subnet_ids: Optional[List[str]] = None + existing_security_group_id: Optional[str] = None vpc_cidr_block: str = "10.10.0.0/16" permissions_boundary: Optional[str] = None tags: Optional[Dict[str, str]] = {} - @pydantic.root_validator - def validate_all(cls, values): - region = values.get("region") - if region is None: - raise ValueError("The `amazon_web_services.region` field is required.") - - # validate region - amazon_web_services.validate_region(region) - - # validate kubernetes version - kubernetes_version = values.get("kubernetes_version") - available_kubernetes_versions = amazon_web_services.kubernetes_versions(region) - if kubernetes_version is None: - values["kubernetes_version"] = available_kubernetes_versions[-1] - elif kubernetes_version not in available_kubernetes_versions: + @model_validator(mode="before") + @classmethod + def _check_input(cls, data: Any) -> Any: + amazon_web_services.check_credentials() + + # check if region is valid + available_regions = amazon_web_services.regions(data["region"]) + if data["region"] not in available_regions: raise ValueError( - f"\nInvalid `kubernetes-version` provided: {values['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." 
+ f"Amazon Web Services region={data['region']} is not one of {available_regions}" ) - # validate node groups - node_groups = values["node_groups"] - available_instances = amazon_web_services.instances(region) - for name, node_group in node_groups.items(): - if node_group.instance not in available_instances: - raise ValueError( - f"Instance {node_group.instance} not available out of available instances {available_instances.keys()}" - ) + # check if kubernetes version is valid + available_kubernetes_versions = amazon_web_services.kubernetes_versions( + data["region"] + ) + if len(available_kubernetes_versions) == 0: + raise ValueError("Request to AWS for available Kubernetes versions failed.") + if data["kubernetes_version"] is None: + data["kubernetes_version"] = available_kubernetes_versions[-1] + elif data["kubernetes_version"] not in available_kubernetes_versions: + raise ValueError( + f"\nInvalid `kubernetes-version` provided: {data['kubernetes_version']}.\nPlease select from one of the following supported Kubernetes versions: {available_kubernetes_versions} or omit flag to use latest Kubernetes version available." + ) - if values["availability_zones"] is None: - zones = amazon_web_services.zones(region) - values["availability_zones"] = list(sorted(zones))[:2] + # check if availability zones are valid + available_zones = amazon_web_services.zones(data["region"]) + if "availability_zones" not in data: + data["availability_zones"] = list(sorted(available_zones))[:2] + else: + for zone in data["availability_zones"]: + if zone not in available_zones: + raise ValueError( + f"Amazon Web Services availability zone={zone} is not one of {available_zones}" + ) - return values + # check if instances are valid + available_instances = amazon_web_services.instances(data["region"]) + if "node_groups" in data: + for _, node_group in data["node_groups"].items(): + instance = ( + node_group["instance"] + if hasattr(node_group, "__getitem__") + else node_group.instance + ) + if instance not in available_instances: + raise ValueError( + f"Amazon Web Services instance {instance} not one of available instance types={available_instances}" + ) + return data class LocalProvider(schema.Base): - kube_context: Optional[str] + kube_context: Optional[str] = None node_selectors: Dict[str, KeyValueDict] = { "general": KeyValueDict(key="kubernetes.io/os", value="linux"), "user": KeyValueDict(key="kubernetes.io/os", value="linux"), @@ -516,7 +522,7 @@ class LocalProvider(schema.Base): class ExistingProvider(schema.Base): - kube_context: Optional[str] + kube_context: Optional[str] = None node_selectors: Dict[str, KeyValueDict] = { "general": KeyValueDict(key="kubernetes.io/os", value="linux"), "user": KeyValueDict(key="kubernetes.io/os", value="linux"), @@ -555,23 +561,24 @@ class ExistingProvider(schema.Base): class InputSchema(schema.Base): - local: Optional[LocalProvider] - existing: Optional[ExistingProvider] - google_cloud_platform: Optional[GoogleCloudPlatformProvider] - amazon_web_services: Optional[AmazonWebServicesProvider] - azure: Optional[AzureProvider] - digital_ocean: Optional[DigitalOceanProvider] - - @pydantic.root_validator(pre=True) - def check_provider(cls, values): - if "provider" in values: - provider: str = values["provider"] + local: Optional[LocalProvider] = None + existing: Optional[ExistingProvider] = None + google_cloud_platform: Optional[GoogleCloudPlatformProvider] = None + amazon_web_services: Optional[AmazonWebServicesProvider] = None + azure: Optional[AzureProvider] = None +
digital_ocean: Optional[DigitalOceanProvider] = None + + @model_validator(mode="before") + @classmethod + def check_provider(cls, data: Any) -> Any: + if "provider" in data: + provider: str = data["provider"] if hasattr(schema.ProviderEnum, provider): # TODO: all cloud providers has required fields, but local and existing don't. # And there is no way to initialize a model without user input here. # We preserve the original behavior here, but we should find a better way to do this. - if provider in ["local", "existing"] and provider not in values: - values[provider] = provider_enum_model_map[provider]() + if provider in ["local", "existing"] and provider not in data: + data[provider] = provider_enum_model_map[provider]() else: # if the provider field is invalid, it won't be set when this validator is called # so we need to check for it explicitly here, and set the `pre` to True @@ -583,16 +590,16 @@ def check_provider(cls, values): setted_providers = [ provider for provider in provider_name_abbreviation_map.keys() - if provider in values + if provider in data ] num_providers = len(setted_providers) if num_providers > 1: raise ValueError(f"Multiple providers set: {setted_providers}") elif num_providers == 1: - values["provider"] = provider_name_abbreviation_map[setted_providers[0]] + data["provider"] = provider_name_abbreviation_map[setted_providers[0]] elif num_providers == 0: - values["provider"] = schema.ProviderEnum.local.value - return values + data["provider"] = schema.ProviderEnum.local.value + return data class NodeSelectorKeyValue(schema.Base): @@ -603,20 +610,20 @@ class NodeSelectorKeyValue(schema.Base): class KubernetesCredentials(schema.Base): host: str cluster_ca_certifiate: str - token: Optional[str] - username: Optional[str] - password: Optional[str] - client_certificate: Optional[str] - client_key: Optional[str] - config_path: Optional[str] - config_context: Optional[str] + token: Optional[str] = None + username: Optional[str] = None + password: Optional[str] = None + client_certificate: Optional[str] = None + client_key: Optional[str] = None + config_path: Optional[str] = None + config_context: Optional[str] = None class OutputSchema(schema.Base): node_selectors: Dict[str, NodeSelectorKeyValue] kubernetes_credentials: KubernetesCredentials kubeconfig_filename: str - nfs_endpoint: Optional[str] + nfs_endpoint: Optional[str] = None class KubernetesInfrastructureStage(NebariTerraformStage): @@ -704,11 +711,13 @@ def tf_objects(self) -> List[Dict]: def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): if self.config.provider == schema.ProviderEnum.local: - return LocalInputVars(kube_context=self.config.local.kube_context).dict() + return LocalInputVars( + kube_context=self.config.local.kube_context + ).model_dump() elif self.config.provider == schema.ProviderEnum.existing: return ExistingInputVars( kube_context=self.config.existing.kube_context - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.do: return DigitalOceanInputVars( name=self.config.escaped_project_name, @@ -717,7 +726,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): tags=self.config.digital_ocean.tags, kubernetes_version=self.config.digital_ocean.kubernetes_version, node_groups=self.config.digital_ocean.node_groups, - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.gcp: return GCPInputVars( name=self.config.escaped_project_name, @@ -746,7 +755,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): 
ip_allocation_policy=self.config.google_cloud_platform.ip_allocation_policy, master_authorized_networks_config=self.config.google_cloud_platform.master_authorized_networks_config, private_cluster_config=self.config.google_cloud_platform.private_cluster_config, - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.azure: return AzureInputVars( name=self.config.escaped_project_name, @@ -777,7 +786,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): tags=self.config.azure.tags, network_profile=self.config.azure.network_profile, max_pods=self.config.azure.max_pods, - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.aws: return AWSInputVars( name=self.config.escaped_project_name, @@ -803,7 +812,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): vpc_cidr_block=self.config.amazon_web_services.vpc_cidr_block, permissions_boundary=self.config.amazon_web_services.permissions_boundary, tags=self.config.amazon_web_services.tags, - ).dict() + ).model_dump() else: raise ValueError(f"Unknown provider: {self.config.provider}") diff --git a/src/_nebari/stages/kubernetes_ingress/__init__.py b/src/_nebari/stages/kubernetes_ingress/__init__.py index 13cf3f9bfc..628d383830 100644 --- a/src/_nebari/stages/kubernetes_ingress/__init__.py +++ b/src/_nebari/stages/kubernetes_ingress/__init__.py @@ -3,8 +3,7 @@ import socket import sys import time -import typing -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type from _nebari import constants from _nebari.provider.dns.cloudflare import update_record @@ -128,23 +127,23 @@ def to_yaml(cls, representer, node): class Certificate(schema.Base): type: CertificateEnum = CertificateEnum.selfsigned # existing - secret_name: typing.Optional[str] + secret_name: Optional[str] = None # lets-encrypt - acme_email: typing.Optional[str] + acme_email: Optional[str] = None acme_server: str = "https://acme-v02.api.letsencrypt.org/directory" class DnsProvider(schema.Base): - provider: typing.Optional[str] - auto_provision: typing.Optional[bool] = False + provider: Optional[str] = None + auto_provision: Optional[bool] = False class Ingress(schema.Base): - terraform_overrides: typing.Dict = {} + terraform_overrides: Dict = {} class InputSchema(schema.Base): - domain: typing.Optional[str] + domain: Optional[str] = None certificate: Certificate = Certificate() ingress: Ingress = Ingress() dns: DnsProvider = DnsProvider() @@ -156,7 +155,7 @@ class IngressEndpoint(schema.Base): class OutputSchema(schema.Base): - load_balancer_address: typing.List[IngressEndpoint] + load_balancer_address: List[IngressEndpoint] domain: str diff --git a/src/_nebari/stages/kubernetes_initialize/__init__.py b/src/_nebari/stages/kubernetes_initialize/__init__.py index e40a69ed0f..7afd69b547 100644 --- a/src/_nebari/stages/kubernetes_initialize/__init__.py +++ b/src/_nebari/stages/kubernetes_initialize/__init__.py @@ -1,8 +1,7 @@ import sys -import typing -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Optional, Type -import pydantic +from pydantic import model_validator from _nebari.stages.base import NebariTerraformStage from _nebari.stages.tf_objects import ( @@ -16,37 +15,34 @@ class ExtContainerReg(schema.Base): enabled: bool = False - access_key_id: typing.Optional[str] - secret_access_key: typing.Optional[str] - extcr_account: typing.Optional[str] - extcr_region: typing.Optional[str] - - @pydantic.root_validator - def enabled_must_have_fields(cls, 
values): - if values["enabled"]: + access_key_id: Optional[str] = None + secret_access_key: Optional[str] = None + extcr_account: Optional[str] = None + extcr_region: Optional[str] = None + + @model_validator(mode="after") + def enabled_must_have_fields(self): + if self.enabled: for fldname in ( "access_key_id", "secret_access_key", "extcr_account", "extcr_region", ): - if ( - fldname not in values - or values[fldname] is None - or values[fldname].strip() == "" - ): + value = getattr(self, fldname) + if value is None or value.strip() == "": raise ValueError( f"external_container_reg must contain a non-blank {fldname} when enabled is true" ) - return values + return self class InputVars(schema.Base): name: str environment: str cloud_provider: str - aws_region: Union[str, None] = None - external_container_reg: Union[ExtContainerReg, None] = None + aws_region: Optional[str] = None + external_container_reg: Optional[ExtContainerReg] = None gpu_enabled: bool = False gpu_node_group_names: List[str] = [] @@ -78,7 +74,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): name=self.config.project_name, environment=self.config.namespace, cloud_provider=self.config.provider.value, - external_container_reg=self.config.external_container_reg.dict(), + external_container_reg=self.config.external_container_reg.model_dump(), ) if self.config.provider == schema.ProviderEnum.gcp: @@ -97,7 +93,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): ] input_vars.aws_region = self.config.amazon_web_services.region - return input_vars.dict() + return input_vars.model_dump() def check( self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False diff --git a/src/_nebari/stages/kubernetes_keycloak/__init__.py b/src/_nebari/stages/kubernetes_keycloak/__init__.py index 767c83189b..7ded0f1f57 100644 --- a/src/_nebari/stages/kubernetes_keycloak/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak/__init__.py @@ -6,11 +6,9 @@ import string import sys import time -import typing -from abc import ABC -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type, Union -import pydantic +from pydantic import Field, ValidationInfo, field_validator from _nebari.stages.base import NebariTerraformStage from _nebari.stages.tf_objects import ( @@ -62,93 +60,79 @@ def to_yaml(cls, representer, node): class GitHubConfig(schema.Base): - client_id: str = pydantic.Field( - default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID") + client_id: str = Field( + default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID"), + validate_default=True, ) - client_secret: str = pydantic.Field( - default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET") + client_secret: str = Field( + default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET"), + validate_default=True, ) - @pydantic.root_validator(allow_reuse=True) - def validate_required(cls, values): - missing = [] - for k, v in { + @field_validator("client_id", "client_secret", mode="before") + @classmethod + def validate_credentials(cls, value: Optional[str], info: ValidationInfo) -> str: + variable_mapping = { "client_id": "GITHUB_CLIENT_ID", "client_secret": "GITHUB_CLIENT_SECRET", - }.items(): - if not values.get(k): - missing.append(v) - - if len(missing) > 0: + } + if value is None: raise ValueError( - f"Missing the following required environment variable(s): {', '.join(missing)}" + f"Missing the following required environment variable: {variable_mapping[info.field_name]}" ) - - return values + return 
value class Auth0Config(schema.Base): - client_id: str = pydantic.Field( - default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID") + client_id: str = Field( + default_factory=lambda: os.environ.get("AUTH0_CLIENT_ID"), + validate_default=True, ) - client_secret: str = pydantic.Field( - default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET") + client_secret: str = Field( + default_factory=lambda: os.environ.get("AUTH0_CLIENT_SECRET"), + validate_default=True, ) - auth0_subdomain: str = pydantic.Field( - default_factory=lambda: os.environ.get("AUTH0_DOMAIN") + auth0_subdomain: str = Field( + default_factory=lambda: os.environ.get("AUTH0_DOMAIN"), + validate_default=True, ) - @pydantic.root_validator(allow_reuse=True) - def validate_required(cls, values): - missing = [] - for k, v in { + @field_validator("client_id", "client_secret", "auth0_subdomain", mode="before") + @classmethod + def validate_credentials(cls, value: Optional[str], info: ValidationInfo) -> str: + variable_mapping = { "client_id": "AUTH0_CLIENT_ID", "client_secret": "AUTH0_CLIENT_SECRET", "auth0_subdomain": "AUTH0_DOMAIN", - }.items(): - if not values.get(k): - missing.append(v) - - if len(missing) > 0: + } + if value is None: raise ValueError( - f"Missing the following required environment variable(s): {', '.join(missing)}" + f"Missing the following required environment variable: {variable_mapping[info.field_name]} " ) + return value - return values - - -class Authentication(schema.Base, ABC): - _types: typing.Dict[str, type] = {} +class BaseAuthentication(schema.Base): type: AuthenticationEnum - # Based on https://github.com/samuelcolvin/pydantic/issues/2177#issuecomment-739578307 - # This allows type field to determine which subclass of Authentication should be used for validation. +class PasswordAuthentication(BaseAuthentication): + type: AuthenticationEnum = AuthenticationEnum.password - # Used to register automatically all the submodels in `_types`. 
- def __init_subclass__(cls): - cls._types[cls._typ.value] = cls - @classmethod - def __get_validators__(cls): - yield cls.validate +class Auth0Authentication(BaseAuthentication): + type: AuthenticationEnum = AuthenticationEnum.auth0 + config: Auth0Config = Field(default_factory=lambda: Auth0Config()) - @classmethod - def validate(cls, value: typing.Dict[str, typing.Any]) -> "Authentication": - if "type" not in value: - raise ValueError("type field is missing from security.authentication") - specified_type = value.get("type") - sub_class = cls._types.get(specified_type, None) +class GitHubAuthentication(BaseAuthentication): + type: AuthenticationEnum = AuthenticationEnum.github + config: GitHubConfig = Field(default_factory=lambda: GitHubConfig()) - if not sub_class: - raise ValueError( - f"No registered Authentication type called {specified_type}" - ) - # init with right submodel - return sub_class(**value) +Authentication = Union[ + PasswordAuthentication, Auth0Authentication, GitHubAuthentication +] def random_secure_string( @@ -157,33 +141,56 @@ def random_secure_string( return "".join(secrets.choice(chars) for i in range(length)) -class PasswordAuthentication(Authentication): - _typ = AuthenticationEnum.password - +class Keycloak(schema.Base): + initial_root_password: str = Field(default_factory=random_secure_string) + overrides: Dict = {} + realm_display_name: str = "Nebari" -class Auth0Authentication(Authentication): - _typ = AuthenticationEnum.auth0 - config: Auth0Config = pydantic.Field(default_factory=lambda: Auth0Config()) +auth_enum_to_model = { + AuthenticationEnum.password: PasswordAuthentication, + AuthenticationEnum.auth0: Auth0Authentication, + AuthenticationEnum.github: GitHubAuthentication, +} -class GitHubAuthentication(Authentication): - _typ = AuthenticationEnum.github - config: GitHubConfig = pydantic.Field(default_factory=lambda: GitHubConfig()) - - -class Keycloak(schema.Base): - initial_root_password: str = pydantic.Field(default_factory=random_secure_string) - overrides: typing.Dict = {} - realm_display_name: str = "Nebari" +auth_enum_to_config = { + AuthenticationEnum.auth0: Auth0Config, + AuthenticationEnum.github: GitHubConfig, +} class Security(schema.Base): - authentication: Authentication = PasswordAuthentication( - type=AuthenticationEnum.password - ) + authentication: Authentication = PasswordAuthentication() shared_users_group: bool = True keycloak: Keycloak = Keycloak() + @field_validator("authentication", mode="before") + @classmethod + def validate_authentication(cls, value: Optional[Dict]) -> Authentication: + if value is None: + return PasswordAuthentication() + if "type" not in value: + raise ValueError( + "Authentication type must be specified if authentication is set" + ) + auth_type = value["type"] if hasattr(value, "__getitem__") else value.type + if auth_type in auth_enum_to_model: + if auth_type == AuthenticationEnum.password: + return auth_enum_to_model[auth_type]() + else: + if "config" in value: + config_dict = ( + value["config"] + if hasattr(value, "__getitem__") + else value.config + ) + config = auth_enum_to_config[auth_type](**config_dict) + else: + config = auth_enum_to_config[auth_type]() + return auth_enum_to_model[auth_type](config=config) + else: + raise ValueError(f"Unsupported authentication type {auth_type}") + class InputSchema(schema.Base): security: Security = Security() @@ -226,7 +233,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): node_group=stage_outputs["stages/02-infrastructure"]["node_selectors"][ 
"general" ], - ).dict() + ).model_dump() def check( self, stage_outputs: Dict[str, Dict[str, Any]], disable_check: bool = False diff --git a/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py b/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py index 4cb0c23aeb..1c33429e37 100644 --- a/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py +++ b/src/_nebari/stages/kubernetes_keycloak_configuration/__init__.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Type from _nebari.stages.base import NebariTerraformStage +from _nebari.stages.kubernetes_keycloak import Authentication from _nebari.stages.tf_objects import NebariTerraformState from nebari import schema from nebari.hookspecs import NebariStage, hookimpl @@ -14,7 +15,7 @@ class InputVars(schema.Base): realm: str = "nebari" realm_display_name: str - authentication: Dict[str, Any] + authentication: Authentication keycloak_groups: List[str] = ["superadmin", "admin", "developer", "analyst"] default_groups: List[str] = ["analyst"] @@ -39,7 +40,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): input_vars.keycloak_groups += users_group input_vars.default_groups += users_group - return input_vars.dict() + return input_vars.model_dump() def check( self, stage_outputs: Dict[str, Dict[str, Any]], disable_prompt: bool = False diff --git a/src/_nebari/stages/kubernetes_services/__init__.py b/src/_nebari/stages/kubernetes_services/__init__.py index 9c47fee6ec..cdc1ae9151 100644 --- a/src/_nebari/stages/kubernetes_services/__init__.py +++ b/src/_nebari/stages/kubernetes_services/__init__.py @@ -2,12 +2,10 @@ import json import sys import time -import typing -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type, Union from urllib.parse import urlencode -import pydantic -from pydantic import Field +from pydantic import ConfigDict, Field, field_validator, model_validator from _nebari import constants from _nebari.stages.base import NebariTerraformStage @@ -81,13 +79,11 @@ class Theme(schema.Base): class KubeSpawner(schema.Base): - cpu_limit: int - cpu_guarantee: int + cpu_limit: float + cpu_guarantee: float mem_limit: str mem_guarantee: str - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class JupyterLabProfile(schema.Base): @@ -95,36 +91,31 @@ class JupyterLabProfile(schema.Base): display_name: str description: str default: bool = False - users: typing.Optional[typing.List[str]] - groups: typing.Optional[typing.List[str]] - kubespawner_override: typing.Optional[KubeSpawner] - - @pydantic.root_validator - def only_yaml_can_have_groups_and_users(cls, values): - if values["access"] != AccessEnum.yaml: - if ( - values.get("users", None) is not None - or values.get("groups", None) is not None - ): + users: Optional[List[str]] = None + groups: Optional[List[str]] = None + kubespawner_override: Optional[KubeSpawner] = None + + @model_validator(mode="after") + def only_yaml_can_have_groups_and_users(self): + if self.access != AccessEnum.yaml: + if self.users is not None or self.groups is not None: raise ValueError( "Profile must not contain groups or users fields unless access = yaml" ) - return values + return self class DaskWorkerProfile(schema.Base): - worker_cores_limit: int - worker_cores: int + worker_cores_limit: float + worker_cores: float worker_memory_limit: str worker_memory: str worker_threads: int = 1 - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Profiles(schema.Base): - 
jupyterlab: typing.List[JupyterLabProfile] = [ + jupyterlab: List[JupyterLabProfile] = [ JupyterLabProfile( display_name="Small Instance", description="Stable environment with 2 cpu / 8 GB ram", default=True, kubespawner_override=KubeSpawner( cpu_limit=2, cpu_guarantee=1.5, mem_limit="8G", mem_guarantee="5G", ), ), @@ -147,7 +138,7 @@ class Profiles(schema.Base): ), ), ] - dask_worker: typing.Dict[str, DaskWorkerProfile] = { + dask_worker: Dict[str, DaskWorkerProfile] = { "Small Worker": DaskWorkerProfile( worker_cores_limit=2, worker_cores=1.5, @@ -164,25 +155,26 @@ class Profiles(schema.Base): ), } - @pydantic.validator("jupyterlab") - def check_default(cls, v, values): + @field_validator("jupyterlab") + @classmethod + def check_default(cls, value): """Check if only one default value is present.""" - default = [attrs["default"] for attrs in v if "default" in attrs] + default = [attrs.default for attrs in value] if default.count(True) > 1: raise TypeError( "Multiple default Jupyterlab profiles may cause unexpected problems." ) - return v + return value class CondaEnvironment(schema.Base): name: str - channels: typing.Optional[typing.List[str]] - dependencies: typing.List[typing.Union[str, typing.Dict[str, typing.List[str]]]] + channels: Optional[List[str]] = None + dependencies: List[Union[str, Dict[str, List[str]]]] class CondaStore(schema.Base): - extra_settings: typing.Dict[str, typing.Any] = {} + extra_settings: Dict[str, Any] = {} extra_config: str = "" image: str = "quansight/conda-store-server" image_tag: str = constants.DEFAULT_CONDA_STORE_IMAGE_TAG @@ -197,7 +189,7 @@ class NebariWorkflowController(schema.Base): class ArgoWorkflows(schema.Base): enabled: bool = True - overrides: typing.Dict = {} + overrides: Dict = {} nebari_workflow_controller: NebariWorkflowController = NebariWorkflowController() @@ -206,9 +198,9 @@ class JHubApps(schema.Base): class MonitoringOverrides(schema.Base): - loki: typing.Dict = {} - promtail: typing.Dict = {} - minio: typing.Dict = {} + loki: Dict = {} + promtail: Dict = {} + minio: Dict = {} class Monitoring(schema.Base): @@ -219,7 +211,7 @@ class Monitoring(schema.Base): class JupyterLabPioneer(schema.Base): enabled: bool = False - log_format: typing.Optional[str] = None + log_format: Optional[str] = None class Telemetry(schema.Base): @@ -227,7 +219,7 @@ class Telemetry(schema.Base): class JupyterHub(schema.Base): - overrides: typing.Dict = {} + overrides: Dict = {} class IdleCuller(schema.Base): @@ -241,10 +233,10 @@ class IdleCuller(schema.Base): class JupyterLab(schema.Base): - default_settings: typing.Dict[str, typing.Any] = {} + default_settings: Dict[str, Any] = {} idle_culler: IdleCuller = IdleCuller() - initial_repositories: typing.List[typing.Dict[str, str]] = [] - preferred_dir: typing.Optional[str] = None + initial_repositories: List[Dict[str, str]] = [] + preferred_dir: Optional[str] = None class InputSchema(schema.Base): @@ -252,7 +244,7 @@ class InputSchema(schema.Base): storage: Storage = Storage() theme: Theme = Theme() profiles: Profiles = Profiles() - environments: typing.Dict[str, CondaEnvironment] = { + environments: Dict[str, CondaEnvironment] = { "environment-dask.yaml": CondaEnvironment( name="dask", channels=["conda-forge"], @@ -374,7 +366,9 @@ class JupyterhubInputVars(schema.Base): initial_repositories: str = Field(alias="initial-repositories") jupyterhub_overrides: List[str] = Field(alias="jupyterhub-overrides") jupyterhub_stared_storage: str = Field(alias="jupyterhub-shared-storage") - jupyterhub_shared_endpoint: str = Field(None, alias="jupyterhub-shared-endpoint") + jupyterhub_shared_endpoint: Optional[str] =
Field( + alias="jupyterhub-shared-endpoint", default=None + ) jupyterhub_profiles: List[JupyterLabProfile] = Field(alias="jupyterlab-profiles") jupyterhub_image: ImageNameTag = Field(alias="jupyterhub-image") jupyterhub_hub_extraEnv: str = Field(alias="jupyterhub-hub-extraEnv") @@ -382,9 +376,7 @@ class JupyterhubInputVars(schema.Base): argo_workflows_enabled: bool = Field(alias="argo-workflows-enabled") jhub_apps_enabled: bool = Field(alias="jhub-apps-enabled") cloud_provider: str = Field(alias="cloud-provider") - jupyterlab_preferred_dir: typing.Optional[str] = Field( - alias="jupyterlab-preferred-dir" - ) + jupyterlab_preferred_dir: Optional[str] = Field(alias="jupyterlab-preferred-dir") class DaskGatewayInputVars(schema.Base): @@ -405,7 +397,7 @@ class MonitoringInputVars(schema.Base): class TelemetryInputVars(schema.Base): jupyterlab_pioneer_enabled: bool = Field(alias="jupyterlab-pioneer-enabled") - jupyterlab_pioneer_log_format: typing.Optional[str] = Field( + jupyterlab_pioneer_log_format: Optional[str] = Field( alias="jupyterlab-pioneer-log-format" ) @@ -498,7 +490,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): conda_store_vars = CondaStoreInputVars( conda_store_environments={ - k: v.dict() for k, v in self.config.environments.items() + k: v.model_dump() for k, v in self.config.environments.items() }, conda_store_default_namespace=self.config.conda_store.default_namespace, conda_store_filesystem_storage=self.config.storage.conda_store, @@ -511,14 +503,14 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): ) jupyterhub_vars = JupyterhubInputVars( - jupyterhub_theme=jupyterhub_theme.dict(), + jupyterhub_theme=jupyterhub_theme.model_dump(), jupyterlab_image=_split_docker_image_name( self.config.default_images.jupyterlab ), jupyterhub_stared_storage=self.config.storage.shared_filesystem, jupyterhub_shared_endpoint=jupyterhub_shared_endpoint, cloud_provider=cloud_provider, - jupyterhub_profiles=self.config.profiles.dict()["jupyterlab"], + jupyterhub_profiles=self.config.profiles.model_dump()["jupyterlab"], jupyterhub_image=_split_docker_image_name( self.config.default_images.jupyterhub ), @@ -526,7 +518,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): jupyterhub_hub_extraEnv=json.dumps( self.config.jupyterhub.overrides.get("hub", {}).get("extraEnv", []) ), - idle_culler_settings=self.config.jupyterlab.idle_culler.dict(), + idle_culler_settings=self.config.jupyterlab.idle_culler.model_dump(), argo_workflows_enabled=self.config.argo_workflows.enabled, jhub_apps_enabled=self.config.jhub_apps.enabled, initial_repositories=str(self.config.jupyterlab.initial_repositories), @@ -538,7 +530,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): dask_worker_image=_split_docker_image_name( self.config.default_images.dask_worker ), - dask_gateway_profiles=self.config.profiles.dict()["dask_worker"], + dask_gateway_profiles=self.config.profiles.model_dump()["dask_worker"], cloud_provider=cloud_provider, ) @@ -568,13 +560,13 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): ) return { - **kubernetes_services_vars.dict(by_alias=True), - **conda_store_vars.dict(by_alias=True), - **jupyterhub_vars.dict(by_alias=True), - **dask_gateway_vars.dict(by_alias=True), - **monitoring_vars.dict(by_alias=True), - **argo_workflows_vars.dict(by_alias=True), - **telemetry_vars.dict(by_alias=True), + **kubernetes_services_vars.model_dump(by_alias=True), + **conda_store_vars.model_dump(by_alias=True), + 
**jupyterhub_vars.model_dump(by_alias=True), + **dask_gateway_vars.model_dump(by_alias=True), + **monitoring_vars.model_dump(by_alias=True), + **argo_workflows_vars.model_dump(by_alias=True), + **telemetry_vars.model_dump(by_alias=True), } def check( diff --git a/src/_nebari/stages/nebari_tf_extensions/__init__.py b/src/_nebari/stages/nebari_tf_extensions/__init__.py index b5bfdbec4f..eaaf131117 100644 --- a/src/_nebari/stages/nebari_tf_extensions/__init__.py +++ b/src/_nebari/stages/nebari_tf_extensions/__init__.py @@ -1,5 +1,4 @@ -import typing -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Optional, Type from _nebari.stages.base import NebariTerraformStage from _nebari.stages.tf_objects import ( @@ -25,8 +24,8 @@ class NebariExtension(schema.Base): keycloakadmin: bool = False jwt: bool = False nebariconfigyaml: bool = False - logout: typing.Optional[str] - envs: typing.Optional[typing.List[NebariExtensionEnv]] + logout: Optional[str] = None + envs: Optional[List[NebariExtensionEnv]] = None class HelmExtension(schema.Base): @@ -34,12 +33,12 @@ class HelmExtension(schema.Base): repository: str chart: str version: str - overrides: typing.Dict = {} + overrides: Dict = {} class InputSchema(schema.Base): - helm_extensions: typing.List[HelmExtension] = [] - tf_extensions: typing.List[NebariExtension] = [] + helm_extensions: List[HelmExtension] = [] + tf_extensions: List[NebariExtension] = [] class OutputSchema(schema.Base): @@ -67,12 +66,12 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): "realm_id": stage_outputs["stages/06-kubernetes-keycloak-configuration"][ "realm_id" ]["value"], - "tf_extensions": [_.dict() for _ in self.config.tf_extensions], - "nebari_config_yaml": self.config.dict(), + "tf_extensions": [_.model_dump() for _ in self.config.tf_extensions], + "nebari_config_yaml": self.config.model_dump(), "keycloak_nebari_bot_password": stage_outputs[ "stages/05-kubernetes-keycloak" ]["keycloak_nebari_bot_password"]["value"], - "helm_extensions": [_.dict() for _ in self.config.helm_extensions], + "helm_extensions": [_.model_dump() for _ in self.config.helm_extensions], } diff --git a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py index 094231e967..edd4b9ed8a 100644 --- a/src/_nebari/stages/terraform_state/__init__.py +++ b/src/_nebari/stages/terraform_state/__init__.py @@ -4,10 +4,9 @@ import os import pathlib import re -import typing -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type -import pydantic +from pydantic import field_validator from _nebari.provider import terraform from _nebari.provider.cloud import azure_cloud @@ -39,10 +38,11 @@ class AzureInputVars(schema.Base): region: str storage_account_postfix: str state_resource_group_name: str - tags: Dict[str, str] = {} + tags: Dict[str, str] - @pydantic.validator("state_resource_group_name") - def _validate_resource_group_name(cls, value): + @field_validator("state_resource_group_name") + @classmethod + def _validate_resource_group_name(cls, value: str) -> str: if value is None: return value length = len(value) + len(AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX) @@ -59,9 +59,10 @@ def _validate_resource_group_name(cls, value): return value - @pydantic.validator("tags") - def _validate_tags(cls, tags): - return azure_cloud.validate_tags(tags) + @field_validator("tags") + @classmethod + def _validate_tags(cls, value: Dict[str, str]) -> Dict[str, str]: + return 
azure_cloud.validate_tags(value) class AWSInputVars(schema.Base): @@ -82,8 +83,8 @@ def to_yaml(cls, representer, node): class TerraformState(schema.Base): type: TerraformStateEnum = TerraformStateEnum.remote - backend: typing.Optional[str] - config: typing.Dict[str, str] = {} + backend: Optional[str] = None + config: Dict[str, str] = {} class InputSchema(schema.Base): @@ -192,18 +193,18 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): name=self.config.project_name, namespace=self.config.namespace, region=self.config.digital_ocean.region, - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.gcp: return GCPInputVars( name=self.config.project_name, namespace=self.config.namespace, region=self.config.google_cloud_platform.region, - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.aws: return AWSInputVars( name=self.config.project_name, namespace=self.config.namespace, - ).dict() + ).model_dump() elif self.config.provider == schema.ProviderEnum.azure: return AzureInputVars( name=self.config.project_name, @@ -217,7 +218,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): suffix=AZURE_TF_STATE_RESOURCE_GROUP_SUFFIX, ), tags=self.config.azure.tags, - ).dict() + ).model_dump() elif ( self.config.provider == schema.ProviderEnum.local or self.config.provider == schema.ProviderEnum.existing diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py index 979837b81e..9040f3d201 100644 --- a/src/_nebari/subcommands/init.py +++ b/src/_nebari/subcommands/init.py @@ -419,7 +419,7 @@ def check_cloud_provider_kubernetes_version( f"Invalid Kubernetes version `{kubernetes_version}`. Please refer to the GCP docs for a list of valid versions: {versions}" ) elif cloud_provider == ProviderEnum.do.value.lower(): - versions = digital_ocean.kubernetes_versions(region) + versions = digital_ocean.kubernetes_versions() if not kubernetes_version or kubernetes_version == LATEST: kubernetes_version = get_latest_kubernetes_version(versions) @@ -921,7 +921,11 @@ def if_used(key, model=inputs, ignore_list=["cloud_provider"]): return b.format(key=key, value=value).replace("_", "-") cmds = " ".join( - [_ for _ in [if_used(_) for _ in inputs.dict().keys()] if _ is not None] + [ + _ + for _ in [if_used(_) for _ in inputs.model_dump().keys()] + if _ is not None + ] ) rich.print( diff --git a/src/_nebari/upgrade.py b/src/_nebari/upgrade.py index 8129188d8e..64e593be66 100644 --- a/src/_nebari/upgrade.py +++ b/src/_nebari/upgrade.py @@ -9,7 +9,7 @@ import rich from packaging.version import Version -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError from rich.prompt import Prompt from _nebari.config import backup_configuration diff --git a/src/_nebari/utils.py b/src/_nebari/utils.py index 3488c6f0e8..3ae4ad4bd8 100644 --- a/src/_nebari/utils.py +++ b/src/_nebari/utils.py @@ -11,7 +11,7 @@ import time import warnings from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Set from ruamel.yaml import YAML @@ -338,3 +338,18 @@ def get_provider_config_block_name(provider): return PROVIDER_CONFIG_NAMES[provider] else: return provider + + +def check_environment_variables(variables: Set[str], reference: str) -> None: + """Check that environment variables are set.""" + required_variables = { + variable: os.environ.get(variable, None) for variable in variables + } + missing_variables = { + variable for variable, value in required_variables.items() if value is None + } 
+    if missing_variables:
+        raise ValueError(
+            f"""Missing the following required environment variables: {missing_variables}\n
+            Please see the documentation for more information: {reference}"""
+        )
diff --git a/src/nebari/schema.py b/src/nebari/schema.py
index ab90f8ebc7..70b9589e6f 100644
--- a/src/nebari/schema.py
+++ b/src/nebari/schema.py
@@ -1,6 +1,8 @@
 import enum
+from typing import Annotated

 import pydantic
+from pydantic import ConfigDict, Field, StringConstraints, field_validator
 from ruamel.yaml import yaml_object

 from _nebari.utils import escape_string, yaml
@@ -8,26 +10,23 @@
 # Regex for suitable project names
 project_name_regex = r"^[A-Za-z][A-Za-z0-9\-_]{1,14}[A-Za-z0-9]$"
-project_name_pydantic = pydantic.constr(regex=project_name_regex)
+project_name_pydantic = Annotated[str, StringConstraints(pattern=project_name_regex)]

 # Regex for suitable namespaces
 namespace_regex = r"^[A-Za-z][A-Za-z\-_]*[A-Za-z]$"
-namespace_pydantic = pydantic.constr(regex=namespace_regex)
+namespace_pydantic = Annotated[str, StringConstraints(pattern=namespace_regex)]

 email_regex = "^[^ @]+@[^ @]+\\.[^ @]+$"
-email_pydantic = pydantic.constr(regex=email_regex)
+email_pydantic = Annotated[str, StringConstraints(pattern=email_regex)]

 github_url_regex = "^(https://)?github.com/([^/]+)/([^/]+)/?$"
-github_url_pydantic = pydantic.constr(regex=github_url_regex)
+github_url_pydantic = Annotated[str, StringConstraints(pattern=github_url_regex)]

 class Base(pydantic.BaseModel):
-    ...
-
-    class Config:
-        extra = "forbid"
-        validate_assignment = True
-        allow_population_by_field_name = True
+    model_config = ConfigDict(
+        extra="forbid", validate_assignment=True, populate_by_name=True
+    )

 @yaml_object(yaml)
@@ -49,7 +48,7 @@ class Main(Base):
     namespace: namespace_pydantic = "dev"
     provider: ProviderEnum = ProviderEnum.local
     # In nebari_version only use major.minor.patch version - drop any pre/post/dev suffixes
-    nebari_version: str = __version__
+    nebari_version: Annotated[str, Field(validate_default=True)] = __version__

     prevent_deploy: bool = (
         False  # Optional, but will be given default value if not present
     )

     # If the nebari_version in the schema is old
     # we must tell the user to first run nebari upgrade
-    @pydantic.validator("nebari_version", pre=True, always=True)
-    def check_default(cls, v):
-        """
-        Always called even if nebari_version is not supplied at all (so defaults to ''). That way we can give a more helpful error message.
-        """
-        if not cls.is_version_accepted(v):
-            if v == "":
-                v = "not supplied"
-            raise ValueError(
-                f"nebari_version in the config file must be equivalent to {__version__} to be processed by this version of nebari (your config file version is {v})."
-                " Install a different version of nebari or run nebari upgrade to ensure your config file is compatible."
-            )
-        return v
+    @field_validator("nebari_version")
+    @classmethod
+    def check_default(cls, value):
+        assert cls.is_version_accepted(
+            value
+        ), f"nebari_version={value} is not an accepted version, it must be equivalent to {__version__}.\nInstall a different version of nebari or run nebari upgrade to ensure your config file is compatible."
+        return value

     @classmethod
     def is_version_accepted(cls, v):
diff --git a/tests/tests_integration/deployment_fixtures.py b/tests/tests_integration/deployment_fixtures.py
index 1709bd7262..f5752d4c24 100644
--- a/tests/tests_integration/deployment_fixtures.py
+++ b/tests/tests_integration/deployment_fixtures.py
@@ -167,7 +167,7 @@ def deploy(request):
         config = add_preemptible_node_group(config, cloud=cloud)

     print("*" * 100)
-    pprint.pprint(config.dict())
+    pprint.pprint(config.model_dump())
     print("*" * 100)

     # render
diff --git a/tests/tests_unit/cli_validate/local.error.authentication-type-called-custom.yaml b/tests/tests_unit/cli_validate/local.error.authentication-type-custom.yaml
similarity index 100%
rename from tests/tests_unit/cli_validate/local.error.authentication-type-called-custom.yaml
rename to tests/tests_unit/cli_validate/local.error.authentication-type-custom.yaml
diff --git a/tests/tests_unit/cli_validate/local.error.extra-fields.yaml b/tests/tests_unit/cli_validate/local.error.extra-inputs.yaml
similarity index 100%
rename from tests/tests_unit/cli_validate/local.error.extra-fields.yaml
rename to tests/tests_unit/cli_validate/local.error.extra-inputs.yaml
diff --git a/tests/tests_unit/conftest.py b/tests/tests_unit/conftest.py
index e98661c214..d78dfdf1ec 100644
--- a/tests/tests_unit/conftest.py
+++ b/tests/tests_unit/conftest.py
@@ -172,7 +172,7 @@ def nebari_config_options(request) -> schema.Main:

 @pytest.fixture
 def nebari_config(nebari_config_options):
-    return nebari_plugin_manager.config_schema.parse_obj(
+    return nebari_plugin_manager.config_schema.model_validate(
         render_config(**nebari_config_options)
     )

diff --git a/tests/tests_unit/test_cli_validate.py b/tests/tests_unit/test_cli_validate.py
index 00c46c2cd6..faf2efa8a1 100644
--- a/tests/tests_unit/test_cli_validate.py
+++ b/tests/tests_unit/test_cli_validate.py
@@ -134,7 +134,13 @@ def test_cli_validate_from_env():
     "key, value, provider, expected_message, addl_config",
     [
         ("NEBARI_SECRET__project_name", "123invalid", "local", "validation error", {}),
-        ("NEBARI_SECRET__this_is_an_error", "true", "local", "object has no field", {}),
+        (
+            "NEBARI_SECRET__this_is_an_error",
+            "true",
+            "local",
+            "Object has no attribute",
+            {},
+        ),
         (
             "NEBARI_SECRET__amazon_web_services__kubernetes_version",
             "1.0",
diff --git a/tests/tests_unit/test_config.py b/tests/tests_unit/test_config.py
index ccc52543d7..f20eb3f671 100644
--- a/tests/tests_unit/test_config.py
+++ b/tests/tests_unit/test_config.py
@@ -97,7 +97,7 @@ def test_read_configuration_non_existent_file(nebari_config):

 def test_write_configuration_with_dict(nebari_config, tmp_path):
     config_file = tmp_path / "nebari-config-dict.yaml"
-    config_dict = nebari_config.dict()
+    config_dict = nebari_config.model_dump()
     write_configuration(config_file, config_dict)
     read_config = read_configuration(config_file, nebari_config.__class__)

diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py
index b4fb58bc62..446b6d1085 100644
--- a/tests/tests_unit/test_schema.py
+++ b/tests/tests_unit/test_schema.py
@@ -1,7 +1,7 @@
 from contextlib import nullcontext

 import pytest
-from pydantic.error_wrappers import ValidationError
+from pydantic import ValidationError

 from nebari import schema
 from nebari.plugins import nebari_plugin_manager
@@ -125,7 +125,7 @@ def test_no_provider(config_schema, provider, full_name, default_fields):
     }
     config = config_schema(**config_dict)
     assert config.provider == provider
-    assert full_name in config.dict()
+    assert full_name in config.model_dump()

 def test_multiple_providers(config_schema):
@@ -164,6 +164,6 @@ def test_setted_provider(config_schema, provider):
     }
     config = config_schema(**config_dict)
     assert config.provider == provider
-    result_config_dict = config.dict()
+    result_config_dict = config.model_dump()
     assert provider in result_config_dict
     assert result_config_dict[provider]["kube_context"] == "some_context"
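Not part of the patch above: a minimal, self-contained sketch of the pydantic v2 idioms this migration leans on (RootModel in place of __root__, ConfigDict(populate_by_name=True) in place of allow_population_by_field_name, field_validator in place of validator, model_dump() in place of .dict()), assuming pydantic 2.4.2 as pinned in pyproject.toml. The Step/Steps names are hypothetical and exist only for illustration.

# sketch.py - illustrative only, not nebari code
from typing import Dict, Optional

from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator


class Step(BaseModel):
    # populate_by_name replaces v1's allow_population_by_field_name
    model_config = ConfigDict(populate_by_name=True)

    runs_on_: str = Field(alias="runs-on")
    uses: Optional[str] = None  # v2 requires an explicit default for Optional fields

    @field_validator("uses")  # replaces @pydantic.validator; runs in "after" mode by default
    @classmethod
    def check_uses(cls, value):
        return value


# RootModel replaces the v1 "__root__" pattern for dynamic key names
Steps = RootModel[Dict[str, Step]]

step = Step(runs_on_="ubuntu-latest")                      # construct by field name...
same = Step.model_validate({"runs-on": "ubuntu-latest"})   # ...or by alias
print(step.model_dump(by_alias=True))                      # model_dump() replaces .dict()
print(Steps({"build": step}).model_dump())                 # root value passed positionally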