Skip to content

Commit

Permalink
Add more unit tests, add cleanup step for Digital Ocean integration t…
Browse files Browse the repository at this point in the history
…est (#1910)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
iameskild and pre-commit-ci[bot] authored Aug 21, 2023
1 parent 42323d8 commit 735a0ca
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 47 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ nebari-config.yaml
.ipynb_checkpoints
.DS_Store
/.ruff_cache
.coverage


# Integration tests deployments
_test_deploy
32 changes: 13 additions & 19 deletions src/_nebari/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import pathlib
import re
import sys
import typing

import pydantic
Expand All @@ -10,32 +12,23 @@
def set_nested_attribute(data: typing.Any, attrs: typing.List[str], value: typing.Any):
"""Takes an arbitrary set of attributes and accesses the deep
nested object config to set value
"""

def _get_attr(d: typing.Any, attr: str):
if hasattr(d, "__getitem__"):
if re.fullmatch(r"\d+", attr):
try:
return d[int(attr)]
except Exception:
return d[attr]
else:
return d[attr]
if isinstance(d, list) and re.fullmatch(r"\d+", attr):
return d[int(attr)]
elif hasattr(d, "__getitem__"):
return d[attr]
else:
return getattr(d, attr)

def _set_attr(d: typing.Any, attr: str, value: typing.Any):
if hasattr(d, "__getitem__"):
if re.fullmatch(r"\d+", attr):
try:
d[int(attr)] = value
except Exception:
d[attr] = value
else:
d[attr] = value
if isinstance(d, list) and re.fullmatch(r"\d+", attr):
d[int(attr)] = value
elif hasattr(d, "__getitem__"):
d[attr] = value
else:
return setattr(d, attr, value)
setattr(d, attr, value)

data_pos = data
for attr in attrs[:-1]:
Expand Down Expand Up @@ -68,7 +61,7 @@ def read_configuration(
config_schema: pydantic.BaseModel,
read_environment: bool = True,
):
"""Read configuration from multiple sources and apply validation"""
"""Read the nebari configuration from disk and apply validation"""
filename = pathlib.Path(config_filename)

if not filename.is_file():
Expand All @@ -90,6 +83,7 @@ def write_configuration(
config: typing.Union[pydantic.BaseModel, typing.Dict],
mode: str = "w",
):
"""Write the nebari configuration file to disk"""
with config_filename.open(mode) as f:
if isinstance(config, pydantic.BaseModel):
yaml.dump(config.dict(), f)
Expand Down
8 changes: 4 additions & 4 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def render_config(
ssl_cert_email: str = None,
):
config = {
"provider": cloud_provider.value,
"provider": cloud_provider,
"namespace": namespace,
"nebari_version": __version__,
}
Expand All @@ -51,8 +51,8 @@ def render_config(
if nebari_domain is not None:
config["domain"] = nebari_domain

config["ci_cd"] = {"type": ci_provider.value}
config["terraform_state"] = {"type": terraform_state.value}
config["ci_cd"] = {"type": ci_provider}
config["terraform_state"] = {"type": terraform_state}

# Save default password to file
default_password_filename = Path(tempfile.gettempdir()) / "NEBARI_DEFAULT_PASSWORD"
Expand All @@ -68,7 +68,7 @@ def render_config(
"welcome"
] = """Welcome! Learn about Nebari's features and configurations in <a href="https://www.nebari.dev/docs">the documentation</a>. If you have any questions or feedback, reach the team on <a href="https://www.nebari.dev/docs/community#getting-support">Nebari's support forums</a>."""

config["security"]["authentication"] = {"type": auth_provider.value}
config["security"]["authentication"] = {"type": auth_provider}
if auth_provider == AuthenticationEnum.github:
if not disable_prompt:
config["security"]["authentication"]["config"] = {
Expand Down
58 changes: 58 additions & 0 deletions src/_nebari/provider/cloud/amazon_web_services.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import functools
import os
import time

import boto3
from botocore.exceptions import ClientError

from _nebari import constants
from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
Expand Down Expand Up @@ -61,3 +63,59 @@ def instances():
[j["InstanceType"] for i in paginator.paginate() for j in i["InstanceTypes"]]
)
return {t: t for t in instance_types}


def aws_session(region: str, digitalocean: bool = False):
if digitalocean:
aws_access_key_id = os.environ["SPACES_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["SPACES_SECRET_ACCESS_KEY"]
else:
aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]

return boto3.session.Session(
region_name=region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)


def delete_aws_s3_bucket(
bucket_name: str,
region: str,
endpoint: str = None,
digitalocean: bool = False,
):
MAX_RETRIES = 5
DELAY = 5

session = aws_session(region=region, digitalocean=digitalocean)
s3 = session.resource("s3", endpoint_url=endpoint)
try:
bucket = s3.Bucket(bucket_name)

for obj in bucket.objects.all():
obj.delete()

for obj_version in bucket.object_versions.all():
obj_version.delete()

except ClientError as e:
if "NoSuchBucket" in str(e):
print(f"Bucket {bucket_name} does not exist. Skipping...")
return
else:
raise e

for i in range(MAX_RETRIES):
try:
bucket.delete()
print(f"Successfully deleted bucket {bucket_name}")
return
except ClientError as e:
if "BucketNotEmpty" in str(e):
print(f"Bucket is not yet empty. Retrying in {DELAY} seconds...")
time.sleep(DELAY)
else:
raise e
print(f"Failed to delete bucket {bucket_name} after {MAX_RETRIES} retries.")
62 changes: 62 additions & 0 deletions src/_nebari/provider/cloud/digital_ocean.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import functools
import os
import tempfile
import typing

import requests
from kubernetes import client, config

from _nebari import constants
from _nebari.provider.cloud.amazon_web_services import delete_aws_s3_bucket
from _nebari.provider.cloud.commons import filter_by_highest_supported_k8s_version
from _nebari.utils import set_do_environment


def check_credentials():
Expand Down Expand Up @@ -34,6 +38,7 @@ def digital_ocean_request(url, method="GET", json=None):

method_map = {
"GET": requests.get,
"DELETE": requests.delete,
}

response = method_map[method](
Expand Down Expand Up @@ -65,3 +70,60 @@ def kubernetes_versions(region) -> typing.List[str]:
supported_kubernetes_versions
)
return [f"{v}-do.0" for v in filtered_versions]


def digital_ocean_get_cluster_id(cluster_name):
clusters = digital_ocean_request("kubernetes/clusters").json()[
"kubernetes_clusters"
]

cluster_id = None
for cluster in clusters:
if cluster["name"] == cluster_name:
cluster_id = cluster["id"]
break

if not cluster_id:
raise ValueError(f"Cluster {cluster_name} not found")

return cluster_id


def digital_ocean_get_kubeconfig(cluster_name: str):
cluster_id = digital_ocean_get_cluster_id(cluster_name)

kubeconfig_content = digital_ocean_request(
f"kubernetes/clusters/{cluster_id}/kubeconfig"
).content

with tempfile.NamedTemporaryFile(delete=False) as temp_kubeconfig:
temp_kubeconfig.write(kubeconfig_content)

return temp_kubeconfig.name


def digital_ocean_delete_kubernetes_cluster(cluster_name: str):
cluster_id = digital_ocean_get_cluster_id(cluster_name)
digital_ocean_request(f"kubernetes/clusters/{cluster_id}", method="DELETE")


def digital_ocean_cleanup(name: str, namespace: str, region: str):
cluster_name = f"{name}-{namespace}"
tf_state_bucket = f"{cluster_name}-terraform-state"
do_spaces_endpoint = "https://nyc3.digitaloceanspaces.com"

config.load_kube_config(digital_ocean_get_kubeconfig(cluster_name))
api = client.CoreV1Api()

labels = {"component": "singleuser-server", "app": "jupyterhub"}

api.delete_collection_namespaced_pod(
namespace=namespace,
label_selector=",".join([f"{k}={v}" for k, v in labels.items()]),
)

set_do_environment()
delete_aws_s3_bucket(
tf_state_bucket, region=region, digitalocean=True, endpoint=do_spaces_endpoint
)
digital_ocean_delete_kubernetes_cluster(cluster_name)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ terraform {
required_providers {
digitalocean = {
source = "digitalocean/digitalocean"
version = "2.17.0"
version = "2.29.0"
}
}
required_version = ">= 1.0"
Expand Down
5 changes: 5 additions & 0 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,8 @@ def is_relative_to(self: Path, other: Path, /) -> bool:
return True
except ValueError:
return False


def set_do_environment():
os.environ["AWS_ACCESS_KEY_ID"] = os.environ["SPACES_ACCESS_KEY_ID"]
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ["SPACES_SECRET_ACCESS_KEY"]
Loading

0 comments on commit 735a0ca

Please sign in to comment.