Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam-D-Lewis committed Dec 30, 2024
1 parent 2264558 commit 747a293
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 24 deletions.
70 changes: 48 additions & 22 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tempfile
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator

from _nebari import constants
from _nebari.provider import opentofu
Expand Down Expand Up @@ -42,33 +42,56 @@ class NodeGroup(schema.Base):
instance: str
min_nodes: Annotated[int, Field(ge=0)] = 0
max_nodes: Annotated[int, Field(ge=1)] = 1
taints: Optional[List[schema.Taint]] = []
taints: Optional[List[schema.Taint]] = None

@field_validator("taints", mode="before")
def validate_taint_strings(cls, value: list[Any]):
def validate_taint_strings(cls, taints: list[Any]):
if taints is None:
return taints

TAINT_STR_REGEX = re.compile(r"(\w+)=(\w+):(\w+)")
return_value = []
for taint in value:
for taint in taints:
if not isinstance(taint, str):
return_value.append(taint)
else:
match = TAINT_STR_REGEX.match(taint)
if not match:
raise ValueError(f"Invalid taint string: {taint}")
key, value, effect = match.groups()
parsed_taint = schema.Taint(key=key, value=value, effect=effect)
key, taints, effect = match.groups()
parsed_taint = schema.Taint(key=key, value=taints, effect=effect)
return_value.append(parsed_taint)

return return_value


DEFAULT_GENERAL_TAINTS = []
DEFAULT_USER_TAINTS = [schema.Taint(key="dedicated", value="user", effect="NoSchedule")]
DEFAULT_WORKER_TAINTS = [
schema.Taint(key="dedicated", value="worker", effect="NoSchedule")
DEFAULT_GENERAL_NODE_GROUP_TAINTS = []
DEFAULT_NODE_GROUP_TAINTS = [
schema.Taint(key="dedicated", value="nebari", effect="NoSchedule")
]


def set_missing_taints_to_default_taints(node_groups: NodeGroup) -> NodeGroup:

for node_group_name, node_group in node_groups.items():
if node_group.taints is None:
if node_group_name == "general":
node_group.taints = DEFAULT_GENERAL_NODE_GROUP_TAINTS
else:
node_group.taints = DEFAULT_NODE_GROUP_TAINTS
return node_groups


class DigitalOceanInputVars(schema.Base):
name: str
environment: str
region: str
tags: List[str]
kubernetes_version: str
node_groups: Dict[str, "DigitalOceanNodeGroup"]
kubeconfig_filename: str = get_kubeconfig_filename()


class GCPNodeGroupInputVars(schema.Base):
name: str
instance_type: str
Expand Down Expand Up @@ -362,19 +385,16 @@ class GCPNodeGroup(NodeGroup):
instance="e2-standard-8",
min_nodes=1,
max_nodes=1,
taints=DEFAULT_GENERAL_TAINTS,
),
"user": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=DEFAULT_USER_TAINTS,
),
"worker": GCPNodeGroup(
instance="e2-standard-4",
min_nodes=0,
max_nodes=5,
taints=DEFAULT_WORKER_TAINTS,
),
}

Expand All @@ -388,7 +408,9 @@ class GoogleCloudPlatformProvider(schema.Base):
kubernetes_version: str
availability_zones: Optional[List[str]] = []
release_channel: str = constants.DEFAULT_GKE_RELEASE_CHANNEL
node_groups: Dict[str, GCPNodeGroup] = DEFAULT_GCP_NODE_GROUPS
node_groups: Annotated[
Dict[str, GCPNodeGroup], AfterValidator(set_missing_taints_to_default_taints)
] = Field(DEFAULT_GCP_NODE_GROUPS, validate_default=True)
tags: Optional[List[str]] = []
networking_mode: str = "ROUTE"
network: str = "default"
Expand Down Expand Up @@ -447,16 +469,16 @@ class AzureNodeGroup(NodeGroup):
instance="Standard_D8_v3",
min_nodes=1,
max_nodes=1,
taints=DEFAULT_GENERAL_TAINTS,
),
"user": AzureNodeGroup(
instance="Standard_D4_v3", min_nodes=0, max_nodes=5, taints=DEFAULT_USER_TAINTS
instance="Standard_D4_v3",
min_nodes=0,
max_nodes=5,
),
"worker": AzureNodeGroup(
instance="Standard_D4_v3",
min_nodes=0,
max_nodes=5,
taints=DEFAULT_WORKER_TAINTS,
),
}

Expand All @@ -466,7 +488,9 @@ class AzureProvider(schema.Base):
kubernetes_version: Optional[str] = None
storage_account_postfix: str
resource_group_name: Optional[str] = None
node_groups: Dict[str, AzureNodeGroup] = DEFAULT_AZURE_NODE_GROUPS
node_groups: Annotated[
Dict[str, AzureNodeGroup], AfterValidator(set_missing_taints_to_default_taints)
] = Field(DEFAULT_AZURE_NODE_GROUPS, validate_default=True)
storage_account_postfix: str
vnet_subnet_id: Optional[str] = None
private_cluster_enabled: bool = False
Expand Down Expand Up @@ -537,21 +561,21 @@ def check_launch_template(cls, values):

DEFAULT_AWS_NODE_GROUPS = {
"general": AWSNodeGroup(
instance="m5.2xlarge", min_nodes=1, max_nodes=1, taints=DEFAULT_GENERAL_TAINTS
instance="m5.2xlarge",
min_nodes=1,
max_nodes=1,
),
"user": AWSNodeGroup(
instance="m5.xlarge",
min_nodes=0,
max_nodes=5,
single_subnet=False,
taints=DEFAULT_USER_TAINTS,
),
"worker": AWSNodeGroup(
instance="m5.xlarge",
min_nodes=0,
max_nodes=5,
single_subnet=False,
taints=DEFAULT_WORKER_TAINTS,
),
}

Expand All @@ -560,7 +584,9 @@ class AmazonWebServicesProvider(schema.Base):
region: str
kubernetes_version: str
availability_zones: Optional[List[str]]
node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS
node_groups: Annotated[
Dict[str, AWSNodeGroup], AfterValidator(set_missing_taints_to_default_taints)
] = Field(DEFAULT_AWS_NODE_GROUPS, validate_default=True)
eks_endpoint_access: Optional[
Literal["private", "public", "public_and_private"]
] = "public"
Expand Down
1 change: 0 additions & 1 deletion src/nebari/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,4 @@ class Taint(Base):
ProviderEnum.gcp: "google_cloud_platform",
ProviderEnum.aws: "amazon_web_services",
ProviderEnum.azure: "azure",
ProviderEnum.do: "digital_ocean",
}
2 changes: 1 addition & 1 deletion tests/tests_unit/test_cli_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def assert_nebari_init_args(
app, args + ["--output", tmp_file.resolve()], input=input
)

assert not result.exception
assert not result.exception, result.output
assert 0 == result.exit_code
assert tmp_file.exists() is True

Expand Down
49 changes: 49 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import pytest
from pydantic import ValidationError

from _nebari.stages.infrastructure import (
DEFAULT_GENERAL_NODE_GROUP_TAINTS,
DEFAULT_NODE_GROUP_TAINTS,
)
from nebari import schema
from nebari.plugins import nebari_plugin_manager

Expand Down Expand Up @@ -82,6 +86,51 @@ def test_provider_validation(config_schema, provider, exception):
assert config.provider == provider


@pytest.mark.parametrize(
"provider, full_name, default_fields",
[
(
"aws",
"amazon_web_services",
{"region": "us-east-1", "kubernetes_version": "1.18"},
),
(
"gcp",
"google_cloud_platform",
{
"region": "us-east1",
"project": "test-project",
"kubernetes_version": "1.18",
},
),
(
"azure",
"azure",
{
"region": "eastus",
"kubernetes_version": "1.18",
"storage_account_postfix": "test",
},
),
],
)
def test_node_group_default_taints_set(
config_schema, provider, full_name, default_fields
):
config_dict = {
"project_name": "test",
"provider": f"{provider}",
f"{full_name}": default_fields,
}
config = config_schema(**config_dict)
ng = getattr(config, schema.provider_enum_name_map[config.provider]).node_groups
for ng_name in ng:
if ng_name == "general":
assert ng[ng_name].taints == DEFAULT_GENERAL_NODE_GROUP_TAINTS
else:
assert ng[ng_name].taints == DEFAULT_NODE_GROUP_TAINTS


@pytest.mark.parametrize(
"provider, full_name, default_fields",
[
Expand Down

0 comments on commit 747a293

Please sign in to comment.