From 747a2933b5d32b02d18d16c2a1b866e1ee6280aa Mon Sep 17 00:00:00 2001 From: Adam Lewis <23342526+Adam-D-Lewis@users.noreply.github.com> Date: Mon, 30 Dec 2024 15:03:12 -0600 Subject: [PATCH] add test --- src/_nebari/stages/infrastructure/__init__.py | 70 +++++++++++++------ src/nebari/schema.py | 1 - tests/tests_unit/test_cli_init.py | 2 +- tests/tests_unit/test_schema.py | 49 +++++++++++++ 4 files changed, 98 insertions(+), 24 deletions(-) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index aa8036fbeb..8e7fc584ed 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -8,7 +8,7 @@ import tempfile from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union -from pydantic import ConfigDict, Field, field_validator, model_validator +from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator from _nebari import constants from _nebari.provider import opentofu @@ -42,33 +42,56 @@ class NodeGroup(schema.Base): instance: str min_nodes: Annotated[int, Field(ge=0)] = 0 max_nodes: Annotated[int, Field(ge=1)] = 1 - taints: Optional[List[schema.Taint]] = [] + taints: Optional[List[schema.Taint]] = None @field_validator("taints", mode="before") - def validate_taint_strings(cls, value: list[Any]): + def validate_taint_strings(cls, taints: list[Any]): + if taints is None: + return taints + TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)") return_value = [] - for taint in value: + for taint in taints: if not isinstance(taint, str): return_value.append(taint) else: match = TAINT_STR_REGEX.match(taint) if not match: raise ValueError(f"Invalid taint string: {taint}") - key, value, effect = match.groups() - parsed_taint = schema.Taint(key=key, value=value, effect=effect) + key, taints, effect = match.groups() + parsed_taint = schema.Taint(key=key, value=taints, effect=effect) return_value.append(parsed_taint) return return_value -DEFAULT_GENERAL_TAINTS = [] -DEFAULT_USER_TAINTS = [schema.Taint(key="dedicated", value="user", effect="NoSchedule")] -DEFAULT_WORKER_TAINTS = [ - schema.Taint(key="dedicated", value="worker", effect="NoSchedule") +DEFAULT_GENERAL_NODE_GROUP_TAINTS = [] +DEFAULT_NODE_GROUP_TAINTS = [ + schema.Taint(key="dedicated", value="nebari", effect="NoSchedule") ] +def set_missing_taints_to_default_taints(node_groups: NodeGroup) -> NodeGroup: + + for node_group_name, node_group in node_groups.items(): + if node_group.taints is None: + if node_group_name == "general": + node_group.taints = DEFAULT_GENERAL_NODE_GROUP_TAINTS + else: + node_group.taints = DEFAULT_NODE_GROUP_TAINTS + return node_groups + + +class DigitalOceanInputVars(schema.Base): + name: str + environment: str + region: str + tags: List[str] + kubernetes_version: str + node_groups: Dict[str, "DigitalOceanNodeGroup"] + kubeconfig_filename: str = get_kubeconfig_filename() + + class GCPNodeGroupInputVars(schema.Base): name: str instance_type: str @@ -362,19 +385,16 @@ class GCPNodeGroup(NodeGroup): instance="e2-standard-8", min_nodes=1, max_nodes=1, - taints=DEFAULT_GENERAL_TAINTS, ), "user": GCPNodeGroup( instance="e2-standard-4", min_nodes=0, max_nodes=5, - taints=DEFAULT_USER_TAINTS, ), "worker": GCPNodeGroup( instance="e2-standard-4", min_nodes=0, max_nodes=5, - taints=DEFAULT_WORKER_TAINTS, ), } @@ -388,7 +408,9 @@ class GoogleCloudPlatformProvider(schema.Base): kubernetes_version: str availability_zones: Optional[List[str]] = [] release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL - node_groups: Dict[str, GCPNodeGroup] = DEFAULT_GCP_NODE_GROUPS + node_groups: Annotated[ + Dict[str, GCPNodeGroup], AfterValidator(set_missing_taints_to_default_taints) + ] = Field(DEFAULT_GCP_NODE_GROUPS, validate_default=True) tags: Optional[List[str]] = [] networking_mode: str = "ROUTE" network: str = "default" @@ -447,16 +469,16 @@ class AzureNodeGroup(NodeGroup): instance="Standard_D8_v3", min_nodes=1, max_nodes=1, - taints=DEFAULT_GENERAL_TAINTS, ), "user": AzureNodeGroup( - instance="Standard_D4_v3", min_nodes=0, max_nodes=5, taints=DEFAULT_USER_TAINTS + instance="Standard_D4_v3", + min_nodes=0, + max_nodes=5, ), "worker": AzureNodeGroup( instance="Standard_D4_v3", min_nodes=0, max_nodes=5, - taints=DEFAULT_WORKER_TAINTS, ), } @@ -466,7 +488,9 @@ class AzureProvider(schema.Base): kubernetes_version: Optional[str] = None storage_account_postfix: str resource_group_name: Optional[str] = None - node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS + node_groups: Annotated[ + Dict[str, AzureNodeGroup], AfterValidator(set_missing_taints_to_default_taints) + ] = Field(DEFAULT_AZURE_NODE_GROUPS, validate_default=True) storage_account_postfix: str vnet_subnet_id: Optional[str] = None private_cluster_enabled: bool = False @@ -537,21 +561,21 @@ def check_launch_template(cls, values): DEFAULT_AWS_NODE_GROUPS = { "general": AWSNodeGroup( - instance="m5.2xlarge", min_nodes=1, max_nodes=1, taints=DEFAULT_GENERAL_TAINTS + instance="m5.2xlarge", + min_nodes=1, + max_nodes=1, ), "user": AWSNodeGroup( instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False, - taints=DEFAULT_USER_TAINTS, ), "worker": AWSNodeGroup( instance="m5.xlarge", min_nodes=0, max_nodes=5, single_subnet=False, - taints=DEFAULT_WORKER_TAINTS, ), } @@ -560,7 +584,9 @@ class AmazonWebServicesProvider(schema.Base): region: str kubernetes_version: str availability_zones: Optional[List[str]] - node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS + node_groups: Annotated[ + Dict[str, AWSNodeGroup], AfterValidator(set_missing_taints_to_default_taints) + ] = Field(DEFAULT_AWS_NODE_GROUPS, validate_default=True) eks_endpoint_access: Optional[ Literal["private", "public", "public_and_private"] ] = "public" diff --git a/src/nebari/schema.py b/src/nebari/schema.py index e58f8f0f9d..bba6e9de11 100644 --- a/src/nebari/schema.py +++ b/src/nebari/schema.py @@ -129,5 +129,4 @@ class Taint(Base): ProviderEnum.gcp: "google_cloud_platform", ProviderEnum.aws: "amazon_web_services", ProviderEnum.azure: "azure", - ProviderEnum.do: "digital_ocean", } diff --git a/tests/tests_unit/test_cli_init.py b/tests/tests_unit/test_cli_init.py index 03b22557ae..25cfcdbe0d 100644 --- a/tests/tests_unit/test_cli_init.py +++ b/tests/tests_unit/test_cli_init.py @@ -209,7 +209,7 @@ def assert_nebari_init_args( app, args + ["--output", tmp_file.resolve()], input=input ) - assert not result.exception + assert not result.exception, result.output assert 0 == result.exit_code assert tmp_file.exists() is True diff --git a/tests/tests_unit/test_schema.py b/tests/tests_unit/test_schema.py index 5c21aef8d6..7025c98b0a 100644 --- a/tests/tests_unit/test_schema.py +++ b/tests/tests_unit/test_schema.py @@ -3,6 +3,10 @@ import pytest from pydantic import ValidationError +from _nebari.stages.infrastructure import ( + DEFAULT_GENERAL_NODE_GROUP_TAINTS, + DEFAULT_NODE_GROUP_TAINTS, +) from nebari import schema from nebari.plugins import nebari_plugin_manager @@ -82,6 +86,51 @@ def test_provider_validation(config_schema, provider, exception): assert config.provider == provider +@pytest.mark.parametrize( + "provider, full_name, default_fields", + [ + ( + "aws", + "amazon_web_services", + {"region": "us-east-1", "kubernetes_version": "1.18"}, + ), + ( + "gcp", + "google_cloud_platform", + { + "region": "us-east1", + "project": "test-project", + "kubernetes_version": "1.18", + }, + ), + ( + "azure", + "azure", + { + "region": "eastus", + "kubernetes_version": "1.18", + "storage_account_postfix": "test", + }, + ), + ], +) +def test_node_group_default_taints_set( + config_schema, provider, full_name, default_fields +): + config_dict = { + "project_name": "test", + "provider": f"{provider}", + f"{full_name}": default_fields, + } + config = config_schema(**config_dict) + ng = getattr(config, schema.provider_enum_name_map[config.provider]).node_groups + for ng_name in ng: + if ng_name == "general": + assert ng[ng_name].taints == DEFAULT_GENERAL_NODE_GROUP_TAINTS + else: + assert ng[ng_name].taints == DEFAULT_NODE_GROUP_TAINTS + + @pytest.mark.parametrize( "provider, full_name, default_fields", [