Skip to content

Commit

Permalink
Merge branch 'develop' into pydantic-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
fangchenli authored Nov 8, 2023
2 parents e4b458c + 748bb6a commit a443deb
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 39 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ci:
repos:
# general
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
exclude: "^docs-sphinx/cli.html"
Expand Down Expand Up @@ -53,13 +53,13 @@ repos:

# python
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.1
hooks:
- id: black
args: ["--line-length=88", "--exclude=/src/_nebari/template/"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
rev: v0.1.4
hooks:
- id: ruff
args: ["--fix"]
Expand All @@ -75,7 +75,7 @@ repos:

# terraform
- repo: https://github.com/antonbabenko/pre-commit-terraform
rev: v1.83.4
rev: v1.83.5
hooks:
- id: terraform_fmt
args:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ dependencies = [
"rich==13.5.1",
"ruamel.yaml==0.17.32",
"typer==0.9.0",
"packaging==23.2",
]

[project.optional-dependencies]
Expand Down
52 changes: 44 additions & 8 deletions src/_nebari/provider/cloud/google_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
import subprocess
from typing import Dict, List
from typing import Dict, List, Set

from _nebari import constants
from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
Expand Down Expand Up @@ -33,14 +33,14 @@ def projects() -> Dict[str, str]:


@functools.lru_cache()
def regions(project: str) -> Dict[str, str]:
"""Return a dict of available regions."""
def regions() -> Set[str]:
"""Return a set of available regions."""
check_credentials()
output = subprocess.check_output(
["gcloud", "compute", "regions", "list", "--project", project, "--format=json"]
["gcloud", "compute", "regions", "list", "--format=json(name)"]
)
data = json.loads(output.decode("utf-8"))
return {_["description"]: _["name"] for _ in data}
data = json.loads(output)
return {_["name"] for _ in data}


@functools.lru_cache()
Expand Down Expand Up @@ -92,6 +92,22 @@ def instances(project: str) -> Dict[str, str]:
return {_["description"]: _["name"] for _ in data}


def activated_services() -> Set[str]:
"""Return a list of activated services."""
check_credentials()
output = subprocess.check_output(
[
"gcloud",
"services",
"list",
"--enabled",
"--format=json(config.title)",
]
)
data = json.loads(output)
return {service["config"]["title"] for service in data}


def cluster_exists(cluster_name: str, project_id: str, region: str) -> bool:
"""Check if a GKE cluster exists."""
try:
Expand Down Expand Up @@ -253,16 +269,36 @@ def gcp_cleanup(config: schema.Main):
delete_service_account(service_account_name, project_id)


def check_missing_service() -> None:
"""Check if all required services are activated."""
required = {
"Compute Engine API",
"Kubernetes Engine API",
"Cloud Monitoring API",
"Cloud Autoscaling API",
"Identity and Access Management (IAM) API",
"Cloud Resource Manager API",
}
activated = activated_services()
common = required.intersection(activated)
missing = required.difference(common)
if missing:
raise ValueError(
f"""Missing required services: {missing}\n
Please see the documentation for more information: {constants.GCP_ENV_DOCS}"""
)


# Getting pricing data could come from here
# https://cloudpricingcalculator.appspot.com/static/data/pricelist.json


### PYDANTIC VALIDATORS ###


def validate_region(project_id: str, region: str) -> str:
def validate_region(region: str) -> str:
"""Validate the GCP region is valid."""
available_regions = regions(project_id)
available_regions = regions()
if region not in available_regions:
raise ValueError(
f"Region {region} is not one of available regions {available_regions}"
Expand Down
2 changes: 1 addition & 1 deletion src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def check_cloud_provider_region(region: str, cloud_provider: str) -> str:
if not region:
region = GCP_DEFAULT_REGION
rich.print(DEFAULT_REGION_MSG.format(region=region))
if region not in google_cloud.regions(os.environ["PROJECT_ID"]):
if region not in google_cloud.regions():
raise ValueError(
f"Invalid region `{region}`. Please refer to the GCP docs for a list of valid regions: {GCP_REGIONS}"
)
Expand Down
29 changes: 20 additions & 9 deletions src/_nebari/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import string
from abc import ABC
from pathlib import Path
from typing import Any, ClassVar, Dict

import rich
from pydantic import ValidationError
Expand Down Expand Up @@ -80,17 +81,29 @@ def do_upgrade(config_filename, attempt_fixes=False):


class UpgradeStep(ABC):
_steps = {}

version = "" # Each subclass must have a version - these should be full release versions (not dev/prerelease)
_steps: ClassVar[Dict[str, Any]] = {}
version: ClassVar[str] = ""

def __init_subclass__(cls):
assert cls.version != ""
try:
parsed_version = Version(cls.version)
except ValueError as exc:
raise ValueError(f"Invalid version string {cls.version}") from exc

cls.parsed_version = parsed_version
assert (
rounded_ver_parse(cls.version) == parsed_version
), f"Invalid version {cls.version}: must be a full release version, not a dev/prerelease/postrelease version"
assert (
cls.version not in cls._steps
) # Would mean multiple upgrades for the same step
), f"Duplicate UpgradeStep version {cls.version}"
cls._steps[cls.version] = cls

@classmethod
def clear_steps_registry(cls):
"""Clears the steps registry. Useful for testing."""
cls._steps.clear()

@classmethod
def has_step(cls, version):
return version in cls._steps
Expand Down Expand Up @@ -157,9 +170,7 @@ def upgrade_step(self, config, start_version, config_filename, *args, **kwargs):
for any actions that are only required for the particular upgrade you are creating.
"""
finish_version = self.get_version()
__rounded_finish_version__ = ".".join(
[str(c) for c in rounded_ver_parse(finish_version)]
)
__rounded_finish_version__ = str(rounded_ver_parse(finish_version))
rich.print(
f"\n---> Starting upgrade from [green]{start_version or 'old version'}[/green] to [green]{finish_version}[/green]\n"
)
Expand Down Expand Up @@ -636,7 +647,7 @@ def _version_specific_upgrade(
return config


__rounded_version__ = ".".join([str(c) for c in rounded_ver_parse(__version__)])
__rounded_version__ = str(rounded_ver_parse(__version__))

# Manually-added upgrade steps must go above this line
if not UpgradeStep.has_step(__rounded_version__):
Expand Down
31 changes: 15 additions & 16 deletions src/_nebari/version.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
"""a backport for the nebari version references."""

import re
from importlib.metadata import distribution

from packaging.version import Version

__version__ = distribution("nebari").version


def rounded_ver_parse(versionstr):
def rounded_ver_parse(version: str) -> Version:
"""
Take a package version string and return an int tuple of only (major,minor,patch),
ignoring and post/dev etc.
Rounds a version string to the nearest patch version.
Parameters
----------
version : str
A version string.
So:
rounded_ver_parse("0.1.2") returns (0,1,2)
rounded_ver_parse("0.1.2.dev65+g2de53174") returns (0,1,2)
rounded_ver_parse("0.1") returns (0,1,0)
Returns
-------
packaging.version.Version
A version object.
"""
m = re.match(
"^(?P<major>[0-9]+)(\\.(?P<minor>[0-9]+)(\\.(?P<patch>[0-9]+))?)?", versionstr
)
assert m is not None
major = int(m.group("major") or 0)
minor = int(m.group("minor") or 0)
patch = int(m.group("patch") or 0)
return (major, minor, patch)
base_version = Version(version).base_version
return Version(base_version)
12 changes: 12 additions & 0 deletions tests/tests_unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ def nebari_render(nebari_config, nebari_stages, tmp_path):
return tmp_path, config_filename


@pytest.fixture
def new_upgrade_cls():
from _nebari.upgrade import UpgradeStep

assert UpgradeStep._steps
steps_cache = UpgradeStep._steps.copy()
UpgradeStep.clear_steps_registry()
assert not UpgradeStep._steps
yield UpgradeStep
UpgradeStep._steps = steps_cache


@pytest.fixture
def config_schema():
return nebari_plugin_manager.config_schema
54 changes: 54 additions & 0 deletions tests/tests_unit/test_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from contextlib import nullcontext

import pytest

from _nebari.provider.cloud.google_cloud import check_missing_service


@pytest.mark.parametrize(
"activated_services, exception",
[
(
{
"Compute Engine API",
"Kubernetes Engine API",
"Cloud Monitoring API",
"Cloud Autoscaling API",
"Identity and Access Management (IAM) API",
"Cloud Resource Manager API",
},
nullcontext(),
),
(
{
"Compute Engine API",
"Kubernetes Engine API",
"Cloud Monitoring API",
"Cloud Autoscaling API",
"Identity and Access Management (IAM) API",
"Cloud Resource Manager API",
"Cloud SQL Admin API",
},
nullcontext(),
),
(
{
"Compute Engine API",
"Kubernetes Engine API",
"Cloud Monitoring API",
"Cloud Autoscaling API",
"Cloud SQL Admin API",
},
pytest.raises(ValueError, match=r"Missing required services:.*"),
),
],
)
def test_gcp_missing_service(monkeypatch, activated_services, exception):
def mock_return():
return activated_services

monkeypatch.setattr(
"_nebari.provider.cloud.google_cloud.activated_services", mock_return
)
with exception:
check_missing_service()
49 changes: 48 additions & 1 deletion tests/tests_unit/test_upgrade.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from pathlib import Path

import pytest
Expand Down Expand Up @@ -76,7 +77,7 @@ def test_upgrade_4_0(
assert not hasattr(config.security, "users")
assert not hasattr(config.security, "groups")

__rounded_version__ = ".".join([str(c) for c in rounded_ver_parse(__version__)])
__rounded_version__ = rounded_ver_parse(__version__)

# Check image versions have been bumped up
assert (
Expand All @@ -99,3 +100,49 @@ def test_upgrade_4_0(
tmp_qhub_config_backup = Path(tmp_path, f"{old_qhub_config_path.name}.old.backup")

assert orig_contents == tmp_qhub_config_backup.read_text()


@pytest.mark.parametrize(
"version_str, exception",
[
("1.0.0", nullcontext()),
("1.cool.0", pytest.raises(ValueError, match=r"Invalid version string .*")),
("0,1.0", pytest.raises(ValueError, match=r"Invalid version string .*")),
("", pytest.raises(ValueError, match=r"Invalid version string .*")),
(
"1.0.0-rc1",
pytest.raises(
AssertionError,
match=r"Invalid version .*: must be a full release version, not a dev/prerelease/postrelease version",
),
),
(
"1.0.0dev1",
pytest.raises(
AssertionError,
match=r"Invalid version .*: must be a full release version, not a dev/prerelease/postrelease version",
),
),
],
)
def test_version_string(new_upgrade_cls, version_str, exception):
with exception:

class DummyUpgrade(new_upgrade_cls):
version = version_str


def test_duplicated_version(new_upgrade_cls):
duplicated_version = "1.2.3"
with pytest.raises(
AssertionError, match=rf"Duplicate UpgradeStep version {duplicated_version}"
):

class DummyUpgrade(new_upgrade_cls):
version = duplicated_version

class DummyUpgrade2(new_upgrade_cls):
version = duplicated_version

class DummyUpgrade3(new_upgrade_cls):
version = "1.2.4"

0 comments on commit a443deb

Please sign in to comment.