Skip to content

Commit

Permalink
#120 Changed arg_list creation
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsimb committed Jun 27, 2024
1 parent 85a8860 commit bbcdd2e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 25 deletions.
20 changes: 5 additions & 15 deletions tests/ci_tests/utils/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,9 @@
)


def get_deploy_arg_list(deploy_params: dict[str, Any], schema: str) -> list[Any]:
"""
Creates a CLI parameter list to be used when calling the script deployment
command (see deployment/deploy_cli.py).
"""
args_list: list[Any] = []
for param_name, param_value in deploy_params.items():
args_list.append(f'--{param_name.replace("_", "-")}')
args_list.append(param_value)
args_list.extend(["--schema", schema])
# We validate the server certificate in SaaS, but not in the Docker DB
if "saas_url" in deploy_params:
args_list.append("--use-ssl-cert-validation")
else:
args_list.append("--no-use-ssl-cert-validation")
def get_arg_list(**kwargs) -> list[str]:
args_list: list[str] = []
for k, v in kwargs.items():
args_list.append(f'--{k.replace("_", "-")}')
args_list.append(str(v))
return args_list
11 changes: 8 additions & 3 deletions tests/deployment/test_deploy_cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from click.testing import CliRunner
import exasol.bucketfs as bfs

from exasol_sagemaker_extension.deployment import deploy_cli
from tests.ci_tests.utils.parameters import get_deploy_arg_list
from tests.ci_tests.utils.parameters import get_arg_list

DB_SCHEMA = "TEST_CLI_SCHEMA"
AUTOPILOT_TRAINING_LUA_SCRIPT_NAME = \
Expand Down Expand Up @@ -37,9 +38,13 @@ def get_all_scripts(db_conn):


@pytest.mark.slow
def test_deploy_cli_main(db_conn, deploy_params):
def test_deploy_cli_main(backend, db_conn, deploy_params):

args_list = get_deploy_arg_list(deploy_params, DB_SCHEMA)
args_list = get_arg_list(**deploy_params, schema=DB_SCHEMA)
if backend == bfs.path.StorageBackend.saas:
args_list.append("--use-ssl-cert-validation")
else:
args_list.append("--no-use-ssl-cert-validation")

runner = CliRunner()
result = runner.invoke(deploy_cli.main, args_list)
Expand Down
21 changes: 14 additions & 7 deletions tests/fixtures/prepare_environment_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,27 @@

import boto3
import pyexasol
import exasol.bucketfs as bfs
import pytest
from click.testing import CliRunner

from exasol_sagemaker_extension.deployment import deploy_cli
from tests.ci_tests.utils.parameters import (
get_deploy_arg_list, reg_model_setup_params, cls_model_setup_params)
get_arg_list, reg_model_setup_params, cls_model_setup_params)


def __open_schema(db_conn: pyexasol.ExaConnection, model_setup):
query = "CREATE SCHEMA IF NOT EXISTS {schema_name}"
db_conn.execute(query.format(schema_name=model_setup.schema_name))


def __deploy_scripts(deploy_params: dict[str, Any], schema: str):
def __deploy_scripts(backend: bfs.path.StorageBackend, deploy_params: dict[str, Any], schema: str):

args_list = get_deploy_arg_list(deploy_params, schema)
args_list = get_arg_list(deploy_params, schema=schema)
if backend == bfs.path.StorageBackend.saas:
args_list.append("--use-ssl-cert-validation")
else:
args_list.append("--no-use-ssl-cert-validation")

runner = CliRunner()
runner.invoke(deploy_cli.main, args_list)
Expand All @@ -45,10 +50,11 @@ def __insert_into_tables(db_conn, model_setup):
db_conn.execute(query)


def _setup_database(db_conn: pyexasol.ExaConnection, deploy_params: dict[str, Any]):
def _setup_database(backend: bfs.path.StorageBackend, db_conn: pyexasol.ExaConnection,
deploy_params: dict[str, Any]):
for model_setup in [reg_model_setup_params, cls_model_setup_params]:
__open_schema(db_conn, model_setup)
__deploy_scripts(deploy_params, model_setup.schema_name)
__deploy_scripts(backend, deploy_params, model_setup.schema_name)
__create_tables(db_conn, model_setup)
__insert_into_tables(db_conn, model_setup)

Expand Down Expand Up @@ -195,12 +201,13 @@ def aws_bucket_uri(self) -> str:


@pytest.fixture(scope="session")
def prepare_ci_test_environment(db_conn,
def prepare_ci_test_environment(backend,
db_conn,
deploy_params,
aws_s3_bucket,
connection_object_for_aws_credentials,
aws_sagemaker_role) -> CITestEnvironment:
_setup_database(db_conn, deploy_params)
_setup_database(backend, db_conn, deploy_params)
yield CITestEnvironment(db_conn=db_conn,
aws_s3_bucket=aws_s3_bucket,
connection_object_for_aws_credentials=connection_object_for_aws_credentials,
Expand Down

0 comments on commit bbcdd2e

Please sign in to comment.