Skip to content

Commit

Permalink
address comments and add error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tianj7 committed Nov 10, 2023
1 parent 9a4d9f3 commit 9ab3654
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 68 deletions.
4 changes: 2 additions & 2 deletions src/argowrapper/engine/argo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def get_workflow_logs(self, workflow_name: str, uid: str) -> List[Dict]:
)

def workflow_submission(
self, request_body: Dict, auth_header: str, billing_id: str
self, request_body: Dict, auth_header: str, billing_id: str = None
):
workflow = WorkflowFactory._get_workflow(
ARGO_NAMESPACE, request_body, auth_header, WORKFLOW.GWAS
Expand All @@ -491,7 +491,7 @@ def workflow_submission(

# If billing_id exists for user, add it to workflow label and pod metadata
# remove gen3-username from pod metadata
if billing_id and len(billing_id) > 0:
if billing_id:
workflow_yaml["metadata"]["labels"]["gen3billing_id"] = billing_id
pod_labels = workflow_yaml["spec"]["podMetadata"]["labels"]
pod_labels["gen3billing_id"] = billing_id
Expand Down
21 changes: 14 additions & 7 deletions src/argowrapper/routes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,18 @@ def check_user_billing_id(request):
remove gen3 username from pod metadata
"""

token = request.headers.get("Authorization")
username = argo_engine_helper.get_username_from_token(token)
header = {"Authorization", request.headers.get("Authorization")}
url = request.base_url._url.rstrip("/") + "/user/user"
params = {"username": username}
user_info = requests.get(url, params, auth=token)

if "billing_id" in user_info["tags"]:
try:
r = requests.get(url=url, headers=header)
r.raise_for_status()
user_info = r.json()
except Exception as e:
logger.error("Could not determine user info from fence")
logger.error(e)
raise

if "tags" in user_info and "billing_id" in user_info["tags"]:
billing_id = user_info["tags"]["billing_id"]
return billing_id
else:
Expand All @@ -177,8 +182,10 @@ def submit_workflow(
request_body: Dict[Any, Any],
request: Request, # pylint: disable=unused-argument
) -> str:
"""route to submit workflow"""
"""check if user has a billing id tag"""
billing_id = check_user_billing_id(request)

"""route to submit workflow"""
try:
return argo_engine.workflow_submission(
request_body, request.headers.get("Authorization"), billing_id
Expand Down
8 changes: 4 additions & 4 deletions test/test_argo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_argo_engine_submit_succeeded():
"variables": variables,
"team_project": "dummy-team-project",
}
result = engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER, None)
result = engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER)
assert "gwas" in result


Expand Down Expand Up @@ -103,7 +103,7 @@ def test_argo_engine_submit_failed():
"n_pcs": 100,
"template_version": "test",
}
engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER, None)
engine.workflow_submission(parameters, EXAMPLE_AUTH_HEADER)


def test_argo_engine_cancel_succeeded():
Expand Down Expand Up @@ -606,7 +606,7 @@ def test_argo_engine_submit_yaml_succeeded():
"argowrapper.engine.argo_engine.argo_engine_helper._get_argo_config_dict"
) as mock_config_dict:
mock_config_dict.return_value = config
engine.workflow_submission(input_parameters, EXAMPLE_AUTH_HEADER, None)
engine.workflow_submission(input_parameters, EXAMPLE_AUTH_HEADER)
args = engine.api_instance.create_workflow.call_args_list
for parameter in args[0][1]["body"]["workflow"]["spec"]["arguments"][
"parameters"
Expand Down Expand Up @@ -644,7 +644,7 @@ def test_argo_engine_new_submit_succeeded():
"argowrapper.engine.argo_engine.argo_engine_helper._get_argo_config_dict"
) as mock_config_dict:
mock_config_dict.return_value = config
res = engine.workflow_submission(request_body, EXAMPLE_AUTH_HEADER, None)
res = engine.workflow_submission(request_body, EXAMPLE_AUTH_HEADER)
assert len(res) > 0


Expand Down
162 changes: 107 additions & 55 deletions test/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from typing import Any, Generator
from unittest.mock import patch
from unittest import mock

import pytest
from fastapi import FastAPI
Expand All @@ -9,6 +10,28 @@
from test.constants import EXAMPLE_AUTH_HEADER
from argowrapper.routes.routes import router

variables = [
{"variable_type": "concept", "concept_id": "2000000324"},
{"variable_type": "concept", "concept_id": "2000000123"},
{"variable_type": "custom_dichotomous", "cohort_ids": [1, 3]},
]

data = {
"n_pcs": 3,
"variables": variables,
"hare_population": "hare",
"out_prefix": "vadc_genesis",
"outcome": 1,
"maf_threshold": 0.01,
"imputation_score_cutoff": 0.3,
"template_version": "gwas-template-latest",
"source_id": 4,
"case_cohort_definition_id": 70,
"control_cohort_definition_id": -1,
"workflow_name": "wf_name",
TEAM_PROJECT_FIELD_NAME: "dummy-team-project",
}


def start_application():
app = FastAPI()
Expand All @@ -32,68 +55,31 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]:


def test_submit_workflow(client):

variables = [
{"variable_type": "concept", "concept_id": "2000000324"},
{"variable_type": "concept", "concept_id": "2000000123"},
{"variable_type": "custom_dichotomous", "cohort_ids": [1, 3]},
]

data = {
"n_pcs": 3,
"variables": variables,
"hare_population": "hare",
"out_prefix": "vadc_genesis",
"outcome": 1,
"maf_threshold": 0.01,
"imputation_score_cutoff": 0.3,
"template_version": "gwas-template-latest",
"source_id": 4,
"case_cohort_definition_id": 70,
"control_cohort_definition_id": -1,
"workflow_name": "wf_name",
TEAM_PROJECT_FIELD_NAME: "dummy-team-project",
}

with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log:
) as mock_log, patch(
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
with patch("requests.get") as mock_request:
url = "http://testserver/user/user"
mock_request.return_value.status_code = 200
mock_request.return_value = {"tags": {}}

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert response.content.decode("utf-8") == '"workflow_123"'
mock_auth.assert_called_with(
token=EXAMPLE_AUTH_HEADER, team_project="dummy-team-project"
)
mock_log.assert_called_with("check_auth_and_team_project")
# No billing Id for this test call
assert mock_engine.call_args.args[2] == None
mock_check_billing_id.return_value = None

mock_request.return_value = {"tags": {"billing_id": "1234"}}
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert mock_engine.call_args.args[2] == "1234"
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert response.content.decode("utf-8") == '"workflow_123"'
mock_auth.assert_called_with(
token=EXAMPLE_AUTH_HEADER, team_project="dummy-team-project"
)
mock_log.assert_called_with("check_auth_and_team_project")


def test_submit_workflow_missing_team_project(client):
Expand Down Expand Up @@ -472,3 +458,69 @@ def test_if_endpoints_are_set_to_the_right_check_auth(client):

client.get("/logs/workflow_123?uid=workflow_uid")
mock_log.assert_called_with("check_auth")


def test_check_user_billing_id(client):
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
with patch("requests.get") as mock_request:
mock_resp = mock.Mock()
mock_resp.status_code = 200
mock_resp.raise_for_status = mock.Mock()
mock_resp.json = mock.Mock(return_value={"tags": {}})
mock_request.return_value = mock_resp

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None

mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}})

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None

mock_resp.json = mock.Mock(
return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}}
)

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert mock_engine.call_args.args[2] == "1234"

mock_resp.status_code == 500
mock_resp.raise_for_status.side_effect = Exception("fence is down")
with pytest.raises(Exception):
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)

0 comments on commit 9ab3654

Please sign in to comment.