Skip to content

Commit

Permalink
Merge branch 'main' into opentofu/stages/sync-versions
Browse files Browse the repository at this point in the history
  • Loading branch information
viniciusdc authored Jan 7, 2025
2 parents c4d3786 + 4a85e87 commit 84601d6
Show file tree
Hide file tree
Showing 19 changed files with 539 additions and 122 deletions.
262 changes: 148 additions & 114 deletions docs-sphinx/cli.html

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/_nebari/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def common(
[],
"--import-plugin",
help="Import nebari plugin",
callback=import_plugin,
),
excluded_stages: typing.List[str] = typer.Option(
[],
Expand Down
54 changes: 54 additions & 0 deletions src/_nebari/config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import pathlib
from typing import Optional

from packaging.requirements import SpecifierSet
from pydantic import BaseModel, ConfigDict, field_validator

from _nebari._version import __version__
from _nebari.utils import yaml

logger = logging.getLogger(__name__)


class ConfigSetMetadata(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True)
name: str # for use with guided init
description: Optional[str] = None
nebari_version: str | SpecifierSet

@field_validator("nebari_version")
@classmethod
def validate_version_requirement(cls, version_req):
if isinstance(version_req, str):
version_req = SpecifierSet(version_req, prereleases=True)

return version_req

def check_version(self, version):
if not self.nebari_version.contains(version, prereleases=True):
raise ValueError(
f'Nebari version "{version}" is not compatible with '
f'version requirement {self.nebari_version} for "{self.name}" config set.'
)


class ConfigSet(BaseModel):
metadata: ConfigSetMetadata
config: dict


def read_config_set(config_set_filepath: str):
"""Read a config set from a config file."""

filename = pathlib.Path(config_set_filepath)

with filename.open() as f:
config_set_yaml = yaml.load(f)

config_set = ConfigSet(**config_set_yaml)

# validation
config_set.metadata.check_version(__version__)

return config_set
10 changes: 8 additions & 2 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pydantic
import requests

from _nebari import constants
from _nebari import constants, utils
from _nebari.config_set import read_config_set
from _nebari.provider import git
from _nebari.provider.cicd import github
from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud
Expand Down Expand Up @@ -47,6 +48,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
config_set: str = None,
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
Expand Down Expand Up @@ -176,13 +178,17 @@ def render_config(
config["certificate"] = {"type": CertificateEnum.letsencrypt.value}
config["certificate"]["acme_email"] = ssl_cert_email

if config_set:
config_set = read_config_set(config_set)
config = utils.deep_merge(config, config_set.config)

# validate configuration and convert to model
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))
raise e

if repository_auto_provision:
match = re.search(github_url_regex, repository)
Expand Down
18 changes: 17 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -107,6 +108,7 @@ class AzureInputVars(schema.Base):
tags: Dict[str, str] = {}
max_pods: Optional[int] = None
network_profile: Optional[Dict[str, str]] = None
azure_policy_enabled: bool = None
workload_identity_enabled: bool = False


Expand Down Expand Up @@ -616,11 +618,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand All @@ -634,6 +648,7 @@ def check_provider(cls, data: Any) -> Any:
data["provider"] = provider_name_abbreviation_map[set_providers[0]]
elif num_providers == 0:
data["provider"] = schema.ProviderEnum.local.value

return data


Expand Down Expand Up @@ -828,6 +843,7 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]):
network_profile=self.config.azure.network_profile,
max_pods=self.config.azure.max_pods,
workload_identity_enabled=self.config.azure.workload_identity_enabled,
azure_policy_enabled=self.config.azure.azure_policy_enabled,
).model_dump()
elif self.config.provider == schema.ProviderEnum.aws:
return AWSInputVars(
Expand Down
1 change: 1 addition & 0 deletions src/_nebari/stages/infrastructure/template/azure/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ module "kubernetes" {
vnet_subnet_id = var.vnet_subnet_id
private_cluster_enabled = var.private_cluster_enabled
workload_identity_enabled = var.workload_identity_enabled
azure_policy_enabled = var.azure_policy_enabled
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ resource "azurerm_kubernetes_cluster" "main" {
# Azure requires that a new, non-existent Resource Group is used, as otherwise the provisioning of the Kubernetes Service will fail.
node_resource_group = var.node_resource_group_name
private_cluster_enabled = var.private_cluster_enabled
# https://learn.microsoft.com/en-ie/azure/governance/policy/concepts/policy-for-kubernetes
azure_policy_enabled = var.azure_policy_enabled


dynamic "network_profile" {
for_each = var.network_profile != null ? [var.network_profile] : []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,9 @@ variable "workload_identity_enabled" {
type = bool
default = false
}

variable "azure_policy_enabled" {
description = "Enable Azure Policy"
type = bool
default = false
}
5 changes: 5 additions & 0 deletions src/_nebari/stages/infrastructure/template/azure/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,8 @@ variable "workload_identity_enabled" {
type = bool
default = false
}

variable "azure_policy_enabled" {
description = "Enable Azure Policy"
type = bool
}
15 changes: 13 additions & 2 deletions src/_nebari/subcommands/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@

@hookimpl
def nebari_subcommand(cli: typer.Typer):
EXTERNAL_PLUGIN_STYLE = "cyan"

@cli.command()
def info(ctx: typer.Context):
from nebari.plugins import nebari_plugin_manager

rich.print(f"Nebari version: {__version__}")

external_plugins = nebari_plugin_manager.get_external_plugins()

hooks = collections.defaultdict(list)
for plugin in nebari_plugin_manager.plugin_manager.get_plugins():
for hook in nebari_plugin_manager.plugin_manager.get_hookcallers(plugin):
Expand All @@ -27,7 +31,8 @@ def info(ctx: typer.Context):

for hook_name, modules in hooks.items():
for module in modules:
table.add_row(hook_name, module)
style = EXTERNAL_PLUGIN_STYLE if module in external_plugins else None
table.add_row(hook_name, module, style=style)

rich.print(table)

Expand All @@ -36,8 +41,14 @@ def info(ctx: typer.Context):
table.add_column("priority")
table.add_column("module")
for stage in nebari_plugin_manager.ordered_stages:
style = (
EXTERNAL_PLUGIN_STYLE if stage.__module__ in external_plugins else None
)
table.add_row(
stage.name, str(stage.priority), f"{stage.__module__}.{stage.__name__}"
stage.name,
str(stage.priority),
f"{stage.__module__}.{stage.__name__}",
style=style,
)

rich.print(table)
9 changes: 9 additions & 0 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class InitInputs(schema.Base):
region: Optional[str] = None
ssl_cert_email: Optional[schema.email_pydantic] = None
disable_prompt: bool = False
config_set: Optional[str] = None
output: pathlib.Path = pathlib.Path("nebari-config.yaml")
explicit: int = 0

Expand Down Expand Up @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel):
terraform_state=inputs.terraform_state,
ssl_cert_email=inputs.ssl_cert_email,
disable_prompt=inputs.disable_prompt,
config_set=inputs.config_set,
)

try:
Expand Down Expand Up @@ -496,6 +498,12 @@ def init(
False,
is_eager=True,
),
config_set: str = typer.Option(
None,
"--config-set",
"-s",
help="Apply a pre-defined set of nebari configuration options.",
),
output: str = typer.Option(
pathlib.Path("nebari-config.yaml"),
"--output",
Expand Down Expand Up @@ -554,6 +562,7 @@ def init(
inputs.terraform_state = terraform_state
inputs.ssl_cert_email = ssl_cert_email
inputs.disable_prompt = disable_prompt
inputs.config_set = config_set
inputs.output = output
inputs.explicit = explicit

Expand Down
42 changes: 42 additions & 0 deletions src/_nebari/subcommands/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from importlib.metadata import version

import rich
import typer
from rich.table import Table

from nebari.hookspecs import hookimpl


@hookimpl
def nebari_subcommand(cli: typer.Typer):
plugin_cmd = typer.Typer(
add_completion=False,
no_args_is_help=True,
rich_markup_mode="rich",
context_settings={"help_option_names": ["-h", "--help"]},
)

cli.add_typer(
plugin_cmd,
name="plugin",
help="Interact with nebari plugins",
rich_help_panel="Additional Commands",
)

@plugin_cmd.command()
def list(ctx: typer.Context):
"""
List installed plugins
"""
from nebari.plugins import nebari_plugin_manager

external_plugins = nebari_plugin_manager.get_external_plugins()

table = Table(title="Plugins")
table.add_column("name", justify="left", no_wrap=True)
table.add_column("version", justify="left", no_wrap=True)

for plugin in external_plugins:
table.add_row(plugin, version(plugin))

rich.print(table)
4 changes: 2 additions & 2 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def modified_environ(*remove: List[str], **update: Dict[str, str]):


def deep_merge(*args):
"""Deep merge multiple dictionaries.
"""Deep merge multiple dictionaries. Preserves order in dicts and lists.
>>> value_1 = {
'a': [1, 2],
Expand Down Expand Up @@ -190,7 +190,7 @@ def deep_merge(*args):

if isinstance(d1, dict) and isinstance(d2, dict):
d3 = {}
for key in d1.keys() | d2.keys():
for key in tuple(d1.keys()) + tuple(d2.keys()):
if key in d1 and key in d2:
d3[key] = deep_merge(d1[key], d2[key])
elif key in d1:
Expand Down
9 changes: 9 additions & 0 deletions src/nebari/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"_nebari.subcommands.deploy",
"_nebari.subcommands.destroy",
"_nebari.subcommands.keycloak",
"_nebari.subcommands.plugin",
"_nebari.subcommands.render",
"_nebari.subcommands.support",
"_nebari.subcommands.upgrade",
Expand Down Expand Up @@ -121,6 +122,14 @@ def read_config(self, config_path: typing.Union[str, Path], **kwargs):

return read_configuration(config_path, self.config_schema, **kwargs)

def get_external_plugins(self):
external_plugins = []
all_plugins = DEFAULT_SUBCOMMAND_PLUGINS + DEFAULT_STAGES_PLUGINS
for plugin in self.plugin_manager.get_plugins():
if plugin.__name__ not in all_plugins:
external_plugins.append(plugin.__name__)
return external_plugins

@property
def ordered_stages(self):
return self.get_available_stages()
Expand Down
Loading

0 comments on commit 84601d6

Please sign in to comment.