diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7284f9f7bf..192d17db92 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,7 +22,7 @@ ci:
 repos:
   # general
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: end-of-file-fixer
         exclude: "^docs-sphinx/cli.html"
@@ -53,13 +53,13 @@ repos:
 
   # python
   - repo: https://github.com/psf/black
-    rev: 23.9.1
+    rev: 23.10.1
     hooks:
       - id: black
         args: ["--line-length=88", "--exclude=/src/_nebari/template/"]
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.292
+    rev: v0.1.4
     hooks:
       - id: ruff
         args: ["--fix"]
@@ -75,7 +75,7 @@ repos:
 
   # terraform
   - repo: https://github.com/antonbabenko/pre-commit-terraform
-    rev: v1.83.4
+    rev: v1.83.5
     hooks:
       - id: terraform_fmt
         args:
diff --git a/pyproject.toml b/pyproject.toml
index 09f5d0b756..29721f918f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ dependencies = [
   "rich==13.5.1",
   "ruamel.yaml==0.17.32",
   "typer==0.9.0",
+  "packaging==23.2",
 ]
 
 [project.optional-dependencies]
diff --git a/src/_nebari/provider/cloud/google_cloud.py b/src/_nebari/provider/cloud/google_cloud.py
index 561c0a2ff9..010ec1c2c3 100644
--- a/src/_nebari/provider/cloud/google_cloud.py
+++ b/src/_nebari/provider/cloud/google_cloud.py
@@ -2,7 +2,7 @@
 import json
 import os
 import subprocess
-from typing import Dict, List
+from typing import Dict, List, Set
 
 from _nebari import constants
 from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
@@ -33,14 +33,14 @@ def projects() -> Dict[str, str]:
 
 
 @functools.lru_cache()
-def regions(project: str) -> Dict[str, str]:
-    """Return a dict of available regions."""
+def regions() -> Set[str]:
+    """Return a set of available regions."""
     check_credentials()
     output = subprocess.check_output(
-        ["gcloud", "compute", "regions", "list", "--project", project, "--format=json"]
+        ["gcloud", "compute", "regions", "list", "--format=json(name)"]
     )
-    data = json.loads(output.decode("utf-8"))
-    return {_["description"]: _["name"] for _ in data}
+    data = json.loads(output)
+    return {_["name"] for _ in data}
 
 
 @functools.lru_cache()
@@ -92,6 +92,22 @@ def instances(project: str) -> Dict[str, str]:
     return {_["description"]: _["name"] for _ in data}
 
 
+def activated_services() -> Set[str]:
+    """Return a set of activated services."""
+    check_credentials()
+    output = subprocess.check_output(
+        [
+            "gcloud",
+            "services",
+            "list",
+            "--enabled",
+            "--format=json(config.title)",
+        ]
+    )
+    data = json.loads(output)
+    return {service["config"]["title"] for service in data}
+
+
 def cluster_exists(cluster_name: str, project_id: str, region: str) -> bool:
     """Check if a GKE cluster exists."""
     try:
@@ -253,6 +269,26 @@ def gcp_cleanup(config: schema.Main):
     delete_service_account(service_account_name, project_id)
 
 
+def check_missing_service() -> None:
+    """Check if all required services are activated."""
+    required = {
+        "Compute Engine API",
+        "Kubernetes Engine API",
+        "Cloud Monitoring API",
+        "Cloud Autoscaling API",
+        "Identity and Access Management (IAM) API",
+        "Cloud Resource Manager API",
+    }
+    activated = activated_services()
+    common = required.intersection(activated)
+    missing = required.difference(common)
+    if missing:
+        raise ValueError(
+            f"""Missing required services: {missing}\n
+            Please see the documentation for more information: {constants.GCP_ENV_DOCS}"""
+        )
+
+
 # Getting pricing data could come from here
 # https://cloudpricingcalculator.appspot.com/static/data/pricelist.json
 
@@ -260,9 +296,9 @@ def gcp_cleanup(config: schema.Main):
 ### PYDANTIC VALIDATORS ###
 
 
-def validate_region(project_id: str, region: str) -> str:
+def validate_region(region: str) -> str:
     """Validate the GCP region is valid."""
-    available_regions = regions(project_id)
+    available_regions = regions()
     if region not in available_regions:
         raise ValueError(
             f"Region {region} is not one of available regions {available_regions}"
diff --git a/src/_nebari/subcommands/init.py b/src/_nebari/subcommands/init.py
index e7c79aee88..7e0427511d 100644
--- a/src/_nebari/subcommands/init.py
+++ b/src/_nebari/subcommands/init.py
@@ -449,7 +449,7 @@ def check_cloud_provider_region(region: str, cloud_provider: str) -> str:
         if not region:
            region = GCP_DEFAULT_REGION
            rich.print(DEFAULT_REGION_MSG.format(region=region))
-        if region not in google_cloud.regions(os.environ["PROJECT_ID"]):
+        if region not in google_cloud.regions():
            raise ValueError(
                f"Invalid region `{region}`. Please refer to the GCP docs for a list of valid regions: {GCP_REGIONS}"
            )
diff --git a/src/_nebari/upgrade.py b/src/_nebari/upgrade.py
index 4cf276dc26..896fab7236 100644
--- a/src/_nebari/upgrade.py
+++ b/src/_nebari/upgrade.py
@@ -5,6 +5,7 @@
 import string
 from abc import ABC
 from pathlib import Path
+from typing import Any, ClassVar, Dict
 
 import rich
 from pydantic import ValidationError
@@ -80,17 +81,29 @@ def do_upgrade(config_filename, attempt_fixes=False):
 
 
 class UpgradeStep(ABC):
-    _steps = {}
-
-    version = ""  # Each subclass must have a version - these should be full release versions (not dev/prerelease)
+    _steps: ClassVar[Dict[str, Any]] = {}
+    version: ClassVar[str] = ""
 
     def __init_subclass__(cls):
-        assert cls.version != ""
+        try:
+            parsed_version = Version(cls.version)
+        except ValueError as exc:
+            raise ValueError(f"Invalid version string {cls.version}") from exc
+
+        cls.parsed_version = parsed_version
+        assert (
+            rounded_ver_parse(cls.version) == parsed_version
+        ), f"Invalid version {cls.version}: must be a full release version, not a dev/prerelease/postrelease version"
         assert (
             cls.version not in cls._steps
-        )  # Would mean multiple upgrades for the same step
+        ), f"Duplicate UpgradeStep version {cls.version}"
         cls._steps[cls.version] = cls
 
+    @classmethod
+    def clear_steps_registry(cls):
+        """Clears the steps registry. Useful for testing."""
+        cls._steps.clear()
+
     @classmethod
     def has_step(cls, version):
         return version in cls._steps
@@ -157,9 +170,7 @@ def upgrade_step(self, config, start_version, config_filename, *args, **kwargs):
         for any actions that are only required for the particular upgrade you are creating.
""" finish_version = self.get_version() - __rounded_finish_version__ = ".".join( - [str(c) for c in rounded_ver_parse(finish_version)] - ) + __rounded_finish_version__ = str(rounded_ver_parse(finish_version)) rich.print( f"\n---> Starting upgrade from [green]{start_version or 'old version'}[/green] to [green]{finish_version}[/green]\n" ) @@ -636,7 +647,7 @@ def _version_specific_upgrade( return config -__rounded_version__ = ".".join([str(c) for c in rounded_ver_parse(__version__)]) +__rounded_version__ = str(rounded_ver_parse(__version__)) # Manually-added upgrade steps must go above this line if not UpgradeStep.has_step(__rounded_version__): diff --git a/src/_nebari/version.py b/src/_nebari/version.py index 7af6817cbe..fcfa649cec 100644 --- a/src/_nebari/version.py +++ b/src/_nebari/version.py @@ -1,26 +1,25 @@ """a backport for the nebari version references.""" -import re from importlib.metadata import distribution +from packaging.version import Version + __version__ = distribution("nebari").version -def rounded_ver_parse(versionstr): +def rounded_ver_parse(version: str) -> Version: """ - Take a package version string and return an int tuple of only (major,minor,patch), - ignoring and post/dev etc. + Rounds a version string to the nearest patch version. + + Parameters + ---------- + version : str + A version string. - So: - rounded_ver_parse("0.1.2") returns (0,1,2) - rounded_ver_parse("0.1.2.dev65+g2de53174") returns (0,1,2) - rounded_ver_parse("0.1") returns (0,1,0) + Returns + ------- + packaging.version.Version + A version object. """ - m = re.match( - "^(?P[0-9]+)(\\.(?P[0-9]+)(\\.(?P[0-9]+))?)?", versionstr - ) - assert m is not None - major = int(m.group("major") or 0) - minor = int(m.group("minor") or 0) - patch = int(m.group("patch") or 0) - return (major, minor, patch) + base_version = Version(version).base_version + return Version(base_version) diff --git a/tests/tests_unit/conftest.py b/tests/tests_unit/conftest.py index 4c1ed02bfe..9840fad7be 100644 --- a/tests/tests_unit/conftest.py +++ b/tests/tests_unit/conftest.py @@ -151,6 +151,18 @@ def nebari_render(nebari_config, nebari_stages, tmp_path): return tmp_path, config_filename +@pytest.fixture +def new_upgrade_cls(): + from _nebari.upgrade import UpgradeStep + + assert UpgradeStep._steps + steps_cache = UpgradeStep._steps.copy() + UpgradeStep.clear_steps_registry() + assert not UpgradeStep._steps + yield UpgradeStep + UpgradeStep._steps = steps_cache + + @pytest.fixture def config_schema(): return nebari_plugin_manager.config_schema diff --git a/tests/tests_unit/test_provider.py b/tests/tests_unit/test_provider.py new file mode 100644 index 0000000000..3c4f35a1d0 --- /dev/null +++ b/tests/tests_unit/test_provider.py @@ -0,0 +1,54 @@ +from contextlib import nullcontext + +import pytest + +from _nebari.provider.cloud.google_cloud import check_missing_service + + +@pytest.mark.parametrize( + "activated_services, exception", + [ + ( + { + "Compute Engine API", + "Kubernetes Engine API", + "Cloud Monitoring API", + "Cloud Autoscaling API", + "Identity and Access Management (IAM) API", + "Cloud Resource Manager API", + }, + nullcontext(), + ), + ( + { + "Compute Engine API", + "Kubernetes Engine API", + "Cloud Monitoring API", + "Cloud Autoscaling API", + "Identity and Access Management (IAM) API", + "Cloud Resource Manager API", + "Cloud SQL Admin API", + }, + nullcontext(), + ), + ( + { + "Compute Engine API", + "Kubernetes Engine API", + "Cloud Monitoring API", + "Cloud Autoscaling API", + "Cloud SQL Admin API", + }, + 
+            pytest.raises(ValueError, match=r"Missing required services:.*"),
+        ),
+    ],
+)
+def test_gcp_missing_service(monkeypatch, activated_services, exception):
+    def mock_return():
+        return activated_services
+
+    monkeypatch.setattr(
+        "_nebari.provider.cloud.google_cloud.activated_services", mock_return
+    )
+    with exception:
+        check_missing_service()
diff --git a/tests/tests_unit/test_upgrade.py b/tests/tests_unit/test_upgrade.py
index 0946dcd99d..4871a1fe07 100644
--- a/tests/tests_unit/test_upgrade.py
+++ b/tests/tests_unit/test_upgrade.py
@@ -1,3 +1,4 @@
+from contextlib import nullcontext
 from pathlib import Path
 
 import pytest
@@ -76,7 +77,7 @@ def test_upgrade_4_0(
     assert not hasattr(config.security, "users")
     assert not hasattr(config.security, "groups")
 
-    __rounded_version__ = ".".join([str(c) for c in rounded_ver_parse(__version__)])
+    __rounded_version__ = rounded_ver_parse(__version__)
 
     # Check image versions have been bumped up
     assert (
@@ -99,3 +100,49 @@ def test_upgrade_4_0(
 
     tmp_qhub_config_backup = Path(tmp_path, f"{old_qhub_config_path.name}.old.backup")
     assert orig_contents == tmp_qhub_config_backup.read_text()
+
+
+@pytest.mark.parametrize(
+    "version_str, exception",
+    [
+        ("1.0.0", nullcontext()),
+        ("1.cool.0", pytest.raises(ValueError, match=r"Invalid version string .*")),
+        ("0,1.0", pytest.raises(ValueError, match=r"Invalid version string .*")),
+        ("", pytest.raises(ValueError, match=r"Invalid version string .*")),
+        (
+            "1.0.0-rc1",
+            pytest.raises(
+                AssertionError,
+                match=r"Invalid version .*: must be a full release version, not a dev/prerelease/postrelease version",
+            ),
+        ),
+        (
+            "1.0.0dev1",
+            pytest.raises(
+                AssertionError,
+                match=r"Invalid version .*: must be a full release version, not a dev/prerelease/postrelease version",
+            ),
+        ),
+    ],
+)
+def test_version_string(new_upgrade_cls, version_str, exception):
+    with exception:
+
+        class DummyUpgrade(new_upgrade_cls):
+            version = version_str
+
+
+def test_duplicated_version(new_upgrade_cls):
+    duplicated_version = "1.2.3"
+    with pytest.raises(
+        AssertionError, match=rf"Duplicate UpgradeStep version {duplicated_version}"
+    ):
+
+        class DummyUpgrade(new_upgrade_cls):
+            version = duplicated_version
+
+        class DummyUpgrade2(new_upgrade_cls):
+            version = duplicated_version
+
+        class DummyUpgrade3(new_upgrade_cls):
+            version = "1.2.4"
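
Not part of the diff above: a minimal sketch of how the reworked rounded_ver_parse behaves after this change, assuming packaging==23.2 (pinned in pyproject.toml above) and a checkout with nebari installed; the example inputs come from the removed docstring and the new test cases.

# Illustration only: new rounded_ver_parse returns a packaging Version, not an int tuple.
from packaging.version import Version

from _nebari.version import rounded_ver_parse

# dev/post/local segments are stripped, mirroring the old (major, minor, patch) rounding
assert rounded_ver_parse("0.1.2.dev65+g2de53174") == Version("0.1.2")
# prereleases also round to the base release, which is what lets UpgradeStep.__init_subclass__
# reject versions such as "1.0.0-rc1": the rounded value no longer equals the parsed one
assert rounded_ver_parse("1.0.0-rc1") == Version("1.0.0")
assert rounded_ver_parse("1.0.0-rc1") != Version("1.0.0-rc1")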