From 9ab3654d1f8f8f1c8fee7ccd90d787a12005c559 Mon Sep 17 00:00:00 2001 From: tianj7 Date: Fri, 10 Nov 2023 16:19:39 -0600 Subject: [PATCH] address comments and add error handling --- src/argowrapper/engine/argo_engine.py | 4 +- src/argowrapper/routes/routes.py | 21 ++-- test/test_argo_engine.py | 8 +- test/test_routes.py | 162 +++++++++++++++++--------- 4 files changed, 127 insertions(+), 68 deletions(-) diff --git a/src/argowrapper/engine/argo_engine.py b/src/argowrapper/engine/argo_engine.py index fdb85af..ea87114 100644 --- a/src/argowrapper/engine/argo_engine.py +++ b/src/argowrapper/engine/argo_engine.py @@ -482,7 +482,7 @@ def get_workflow_logs(self, workflow_name: str, uid: str) -> List[Dict]: ) def workflow_submission( - self, request_body: Dict, auth_header: str, billing_id: str + self, request_body: Dict, auth_header: str, billing_id: str = None ): workflow = WorkflowFactory._get_workflow( ARGO_NAMESPACE, request_body, auth_header, WORKFLOW.GWAS @@ -491,7 +491,7 @@ def workflow_submission( # If billing_id exists for user, add it to workflow label and pod metadata # remove gen3-username from pod metadata - if billing_id and len(billing_id) > 0: + if billing_id: workflow_yaml["metadata"]["labels"]["gen3billing_id"] = billing_id pod_labels = workflow_yaml["spec"]["podMetadata"]["labels"] pod_labels["gen3billing_id"] = billing_id diff --git a/src/argowrapper/routes/routes.py b/src/argowrapper/routes/routes.py index 7872202..c094049 100644 --- a/src/argowrapper/routes/routes.py +++ b/src/argowrapper/routes/routes.py @@ -151,13 +151,18 @@ def check_user_billing_id(request): remove gen3 username from pod metadata """ - token = request.headers.get("Authorization") - username = argo_engine_helper.get_username_from_token(token) + header = {"Authorization", request.headers.get("Authorization")} url = request.base_url._url.rstrip("/") + "/user/user" - params = {"username": username} - user_info = requests.get(url, params, auth=token) - - if "billing_id" in user_info["tags"]: + try: + r = requests.get(url=url, headers=header) + r.raise_for_status() + user_info = r.json() + except Exception as e: + logger.error("Could not determine user info from fence") + logger.error(e) + raise + + if "tags" in user_info and "billing_id" in user_info["tags"]: billing_id = user_info["tags"]["billing_id"] return billing_id else: @@ -177,8 +182,10 @@ def submit_workflow( request_body: Dict[Any, Any], request: Request, # pylint: disable=unused-argument ) -> str: - """route to submit workflow""" + """check if user has a billing id tag""" billing_id = check_user_billing_id(request) + + """route to submit workflow""" try: return argo_engine.workflow_submission( request_body, request.headers.get("Authorization"), billing_id diff --git a/test/test_argo_engine.py b/test/test_argo_engine.py index 2da1dd9..cbdcfbd 100644 --- a/test/test_argo_engine.py +++ b/test/test_argo_engine.py @@ -46,7 +46,7 @@ def test_argo_engine_submit_succeeded(): "variables": variables, "team_project": "dummy-team-project", } - result = engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER, None) + result = engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER) assert "gwas" in result @@ -103,7 +103,7 @@ def test_argo_engine_submit_failed(): "n_pcs": 100, "template_version": "test", } - engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER, None) + engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER) def test_argo_engine_cancel_succeeded(): @@ -606,7 +606,7 @@ def test_argo_engine_submit_yaml_succeeded(): "argowrapper.engine.argo_engine.argo_engine_helper._get_argo_config_dict" ) as mock_config_dict: mock_config_dict.return_value = config - engine.workflow_submission(input_parameters, EXAMPLE_AUTH_HEADER, None) + engine.workflow_submission(input_parameters, EXAMPLE_AUTH_HEADER) args = engine.api_instance.create_workflow.call_args_list for parameter in args[0][1]["body"]["workflow"]["spec"]["arguments"][ "parameters" @@ -644,7 +644,7 @@ def test_argo_engine_new_submit_succeeded(): "argowrapper.engine.argo_engine.argo_engine_helper._get_argo_config_dict" ) as mock_config_dict: mock_config_dict.return_value = config - res = engine.workflow_submission(request_body, EXAMPLE_AUTH_HEADER, None) + res = engine.workflow_submission(request_body, EXAMPLE_AUTH_HEADER) assert len(res) > 0 diff --git a/test/test_routes.py b/test/test_routes.py index ed2d500..8e4407c 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -1,6 +1,7 @@ import json from typing import Any, Generator from unittest.mock import patch +from unittest import mock import pytest from fastapi import FastAPI @@ -9,6 +10,28 @@ from test.constants import EXAMPLE_AUTH_HEADER from argowrapper.routes.routes import router +variables = [ + {"variable_type": "concept", "concept_id": "2000000324"}, + {"variable_type": "concept", "concept_id": "2000000123"}, + {"variable_type": "custom_dichotomous", "cohort_ids": [1, 3]}, +] + +data = { + "n_pcs": 3, + "variables": variables, + "hare_population": "hare", + "out_prefix": "vadc_genesis", + "outcome": 1, + "maf_threshold": 0.01, + "imputation_score_cutoff": 0.3, + "template_version": "gwas-template-latest", + "source_id": 4, + "case_cohort_definition_id": 70, + "control_cohort_definition_id": -1, + "workflow_name": "wf_name", + TEAM_PROJECT_FIELD_NAME: "dummy-team-project", +} + def start_application(): app = FastAPI() @@ -32,68 +55,31 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]: def test_submit_workflow(client): - - variables = [ - {"variable_type": "concept", "concept_id": "2000000324"}, - {"variable_type": "concept", "concept_id": "2000000123"}, - {"variable_type": "custom_dichotomous", "cohort_ids": [1, 3]}, - ] - - data = { - "n_pcs": 3, - "variables": variables, - "hare_population": "hare", - "out_prefix": "vadc_genesis", - "outcome": 1, - "maf_threshold": 0.01, - "imputation_score_cutoff": 0.3, - "template_version": "gwas-template-latest", - "source_id": 4, - "case_cohort_definition_id": 70, - "control_cohort_definition_id": -1, - "workflow_name": "wf_name", - TEAM_PROJECT_FIELD_NAME: "dummy-team-project", - } - with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch( "argowrapper.routes.routes.argo_engine.workflow_submission" ) as mock_engine, patch( "argowrapper.routes.routes.log_auth_check_type" - ) as mock_log: + ) as mock_log, patch( + "argowrapper.routes.routes.check_user_billing_id" + ) as mock_check_billing_id: mock_auth.return_value = True mock_engine.return_value = "workflow_123" - with patch("requests.get") as mock_request: - url = "http://testserver/user/user" - mock_request.return_value.status_code = 200 - mock_request.return_value = {"tags": {}} - - response = client.post( - "/submit", - data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "Authorization": EXAMPLE_AUTH_HEADER, - }, - ) - assert response.status_code == 200 - assert response.content.decode("utf-8") == '"workflow_123"' - mock_auth.assert_called_with( - token=EXAMPLE_AUTH_HEADER, team_project="dummy-team-project" - ) - mock_log.assert_called_with("check_auth_and_team_project") - # No billing Id for this test call - assert mock_engine.call_args.args[2] == None + mock_check_billing_id.return_value = None - mock_request.return_value = {"tags": {"billing_id": "1234"}} - response = client.post( - "/submit", - data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "Authorization": EXAMPLE_AUTH_HEADER, - }, - ) - assert mock_engine.call_args.args[2] == "1234" + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 200 + assert response.content.decode("utf-8") == '"workflow_123"' + mock_auth.assert_called_with( + token=EXAMPLE_AUTH_HEADER, team_project="dummy-team-project" + ) + mock_log.assert_called_with("check_auth_and_team_project") def test_submit_workflow_missing_team_project(client): @@ -472,3 +458,69 @@ def test_if_endpoints_are_set_to_the_right_check_auth(client): client.get("/logs/workflow_123?uid=workflow_uid") mock_log.assert_called_with("check_auth") + + +def test_check_user_billing_id(client): + with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch( + "argowrapper.routes.routes.argo_engine.workflow_submission" + ) as mock_engine, patch( + "argowrapper.routes.routes.log_auth_check_type" + ) as mock_log: + mock_auth.return_value = True + mock_engine.return_value = "workflow_123" + with patch("requests.get") as mock_request: + mock_resp = mock.Mock() + mock_resp.status_code = 200 + mock_resp.raise_for_status = mock.Mock() + mock_resp.json = mock.Mock(return_value={"tags": {}}) + mock_request.return_value = mock_resp + + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 200 + assert mock_engine.call_args.args[2] == None + + mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}}) + + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 200 + assert mock_engine.call_args.args[2] == None + + mock_resp.json = mock.Mock( + return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}} + ) + + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert mock_engine.call_args.args[2] == "1234" + + mock_resp.status_code == 500 + mock_resp.raise_for_status.side_effect = Exception("fence is down") + with pytest.raises(Exception): + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + )