diff --git a/config.ini b/config.ini index 19f98ae4..5d450fad 100644 --- a/config.ini +++ b/config.ini @@ -2,3 +2,4 @@ ARGO_ACCESS_METHOD = access ARGO_HOST = http://argo-argo-workflows-server.argo.svc.cluster.local:2746 ARGO_NAMESPACE = argo +COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL = http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={} diff --git a/src/argowrapper/auth/utils.py b/src/argowrapper/auth/utils.py new file mode 100644 index 00000000..6ecd6202 --- /dev/null +++ b/src/argowrapper/auth/utils.py @@ -0,0 +1,26 @@ +import requests +from argowrapper import logger +from argowrapper.constants import ( + COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL, +) + + +def get_cohort_ids_for_team_project(token, source_id, team_project): + header = {"Authorization": token, "cookie": "fence={}".format(token)} + url = COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL.format( + source_id, team_project + ) + try: + r = requests.get(url=url, headers=header) + r.raise_for_status() + team_cohort_info = r.json() + team_cohort_id_set = set() + if "cohort_definitions_and_stats" in team_cohort_info: + for t in team_cohort_info["cohort_definitions_and_stats"]: + if "cohort_definition_id" in t: + team_cohort_id_set.add(t["cohort_definition_id"]) + return team_cohort_id_set + except Exception as e: + exception = Exception("Could not get team project cohort ids", e) + logger.error(exception) + raise exception diff --git a/src/argowrapper/constants.py b/src/argowrapper/constants.py index cdacc614..25f43c08 100644 --- a/src/argowrapper/constants.py +++ b/src/argowrapper/constants.py @@ -21,6 +21,9 @@ logger.info(f"Access method: {config['DEFAULT']['ARGO_ACCESS_METHOD']}") ARGO_HOST: Final = config["DEFAULT"]["ARGO_HOST"] +COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL: Final = config["DEFAULT"][ + "COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL" +] TEST_WF: Final = "test.yaml" WF_HEADER: Final = "header.yaml" ARGO_NAMESPACE: Final = config["DEFAULT"]["ARGO_NAMESPACE"] diff --git a/src/argowrapper/routes/routes.py b/src/argowrapper/routes/routes.py index dfc1a7a5..3940b87e 100644 --- a/src/argowrapper/routes/routes.py +++ b/src/argowrapper/routes/routes.py @@ -9,6 +9,7 @@ HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN, HTTP_500_INTERNAL_SERVER_ERROR, + HTTP_400_BAD_REQUEST, ) from argowrapper.constants import ( TEAM_PROJECT_FIELD_NAME, @@ -21,6 +22,8 @@ from argowrapper import logger from argowrapper.auth import Auth from argowrapper.engine.argo_engine import ArgoEngine +from argowrapper.auth.utils import get_cohort_ids_for_team_project + import argowrapper.engine.helpers.argo_engine_helper as argo_engine_helper import requests @@ -144,6 +147,68 @@ def wrapper(*args, **kwargs): return wrapper +def check_team_projects_and_cohorts(fn): + """custom annotation to make sure cohort in request belong to user's team project""" + + @wraps(fn) + def wrapper(*args, **kwargs): + + token = kwargs["request"].headers.get("Authorization") + request_body = kwargs["request_body"] + team_project = request_body[TEAM_PROJECT_FIELD_NAME] + source_id = request_body["source_id"] + + # Construct set with all cohort ids requested + cohort_ids = [] + if "cohort_ids" in request_body["outcome"]: + cohort_ids.extend(request_body["outcome"]["cohort_ids"]) + + variables = request_body["variables"] + for v in variables: + if "cohort_ids" in v: + cohort_ids.extend(v["cohort_ids"]) + + if "source_population_cohort" in request_body: + cohort_ids.append(request_body["source_population_cohort"]) + + cohort_id_set = set(cohort_ids) + + if team_project and source_id and len(team_project) > 0 and len(cohort_ids) > 0: + # Get team project cohort ids + team_cohort_id_set = get_cohort_ids_for_team_project( + token, source_id, team_project + ) + + logger.debug("cohort ids are " + " ".join(str(c) for c in cohort_ids)) + logger.debug( + "team cohort ids are " + " ".join(str(c) for c in team_cohort_id_set) + ) + + # Compare the two sets + if cohort_id_set.issubset(team_cohort_id_set): + logger.debug( + "cohort ids submitted all belong to the same team project. Continue.." + ) + return fn(*args, **kwargs) + else: + logger.error( + "Cohort ids submitted do NOT all belong to the same team project." + ) + return HTMLResponse( + content="Cohort ids submitted do NOT all belong to the same team project.", + status_code=HTTP_400_BAD_REQUEST, + ) + + else: + # some required parameters is missing, return bad request: + return HTMLResponse( + content="Missing required parameters", + status_code=HTTP_400_BAD_REQUEST, + ) + + return wrapper + + def check_user_billing_id(request): """ Check whether user is non-VA user @@ -217,6 +282,7 @@ def test(): # submit argo workflow @router.post("/submit", status_code=HTTP_200_OK) @check_auth_and_team_project +@check_team_projects_and_cohorts def submit_workflow( request_body: Dict[Any, Any], request: Request, # pylint: disable=unused-argument diff --git a/test/test_routes.py b/test/test_routes.py index 324fb6f9..88a95eba 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -22,15 +22,30 @@ "variables": variables, "hare_population": "hare", "out_prefix": "vadc_genesis", - "outcome": 1, + "outcome": { + "variable_type": "custom_dichotomous", + "cohort_ids": [2], + "provided_name": "test Pheno", + }, "maf_threshold": 0.01, "imputation_score_cutoff": 0.3, "template_version": "gwas-template-latest", "source_id": 4, "case_cohort_definition_id": 70, "control_cohort_definition_id": -1, + "source_population_cohort": 4, "workflow_name": "wf_name", TEAM_PROJECT_FIELD_NAME: "dummy-team-project", + "user_tags": None, # For testing purpose +} + +cohort_definition_data = { + "cohort_definitions_and_stats": [ + {"cohort_definition_id": 1, "cohort_name": "Cohort 1", "size": 1}, + {"cohort_definition_id": 2, "cohort_name": "Cohort 2", "size": 2}, + {"cohort_definition_id": 3, "cohort_name": "Cohort 3", "size": 3}, + {"cohort_definition_id": 4, "cohort_name": "Cohort 4", "size": 4}, + ] } @@ -55,6 +70,36 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]: yield client +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + def raise_for_status(self): + if self.status_code == 500: + raise Exception("fence is down") + if self.status_code != 200: + raise Exception() + + if ( + kwargs["url"] + == "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/4/by-team-project?team-project=dummy-team-project" + ): + return MockResponse(cohort_definition_data, 200) + + if kwargs["url"] == "http://fence-service/user": + if data["user_tags"] != 500: + return MockResponse(data["user_tags"], 200) + else: + return MockResponse({}, 500) + + return None + + def test_submit_workflow(client): with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch( "argowrapper.routes.routes.argo_engine.workflow_submission" @@ -62,10 +107,13 @@ def test_submit_workflow(client): "argowrapper.routes.routes.log_auth_check_type" ) as mock_log, patch( "argowrapper.routes.routes.check_user_billing_id" - ) as mock_check_billing_id: + ) as mock_check_billing_id, patch( + "requests.get" + ) as mock_requests: mock_auth.return_value = True mock_engine.return_value = "workflow_123" mock_check_billing_id.return_value = None + mock_requests.side_effect = mocked_requests_get response = client.post( "/submit", @@ -466,59 +514,44 @@ def test_submit_workflow_with_user_billing_id(client): "argowrapper.routes.routes.argo_engine.workflow_submission" ) as mock_engine, patch( "argowrapper.routes.routes.log_auth_check_type" - ) as mock_log: + ) as mock_log, patch( + "requests.get" + ) as mock_requests: mock_auth.return_value = True mock_engine.return_value = "workflow_123" - with patch("requests.get") as mock_request: - mock_resp = mock.Mock() - mock_resp.status_code = 200 - mock_resp.raise_for_status = mock.Mock() - mock_resp.json = mock.Mock(return_value={"tags": {}}) - mock_request.return_value = mock_resp + mock_requests.side_effect = mocked_requests_get - response = client.post( - "/submit", - data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "Authorization": EXAMPLE_AUTH_HEADER, - }, - ) - assert response.status_code == 200 - assert mock_engine.call_args.args[2] == None + data["user_tags"] = {"tags": {}} + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 200 + assert mock_engine.call_args.args[2] == None - mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}}) + data["user_tags"] = {"tags": {"othertag1": "tag1"}} - response = client.post( - "/submit", - data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "Authorization": EXAMPLE_AUTH_HEADER, - }, - ) - assert response.status_code == 200 - assert mock_engine.call_args.args[2] == None + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 200 + assert mock_engine.call_args.args[2] == None - mock_resp.json = mock.Mock( - return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}} - ) - with patch( - "argowrapper.routes.routes.check_user_reached_monthly_workflow_cap" - ) as mock_check_monthly_cap: - mock_check_monthly_cap.return_value = False - response = client.post( - "/submit", - data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "Authorization": EXAMPLE_AUTH_HEADER, - }, - ) - assert mock_engine.call_args.args[2] == "1234" - - mock_resp.status_code == 500 - mock_resp.raise_for_status.side_effect = Exception("fence is down") + data["user_tags"] = {"tags": {"othertag1": "tag1", "billing_id": "1234"}} + + with patch( + "argowrapper.routes.routes.check_user_reached_monthly_workflow_cap" + ) as mock_check_monthly_cap: + mock_check_monthly_cap.return_value = False response = client.post( "/submit", data=json.dumps(data), @@ -527,8 +560,19 @@ def test_submit_workflow_with_user_billing_id(client): "Authorization": EXAMPLE_AUTH_HEADER, }, ) - assert response.status_code == 500 - assert "fence is down" in str(response.content) + assert mock_engine.call_args.args[2] == "1234" + + data["user_tags"] == 500 + + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 500 def test_check_user_reached_monthly_workflow_cap(): @@ -566,11 +610,14 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client): "argowrapper.routes.routes.check_user_billing_id" ) as mock_check_billing_id, patch( "argowrapper.routes.routes.check_user_reached_monthly_workflow_cap" - ) as mock_check_monthly_cap: + ) as mock_check_monthly_cap, patch( + "requests.get" + ) as mock_requests: mock_auth.return_value = True mock_engine.return_value = "workflow_123" mock_check_billing_id.return_value = "1234" mock_check_monthly_cap.return_value = True + mock_requests.side_effect = mocked_requests_get response = client.post( "/submit", @@ -581,3 +628,31 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client): }, ) assert response.status_code == 403 + + +def test_submit_workflow_with_non_team_project_cohort(client): + with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch( + "argowrapper.routes.routes.argo_engine.workflow_submission" + ) as mock_engine, patch( + "argowrapper.routes.routes.log_auth_check_type" + ) as mock_log, patch( + "argowrapper.routes.routes.check_user_billing_id" + ) as mock_check_billing_id, patch( + "requests.get" + ) as mock_requests: + mock_auth.return_value = True + mock_engine.return_value = "workflow_123" + mock_check_billing_id.return_value = None + mock_requests.side_effect = mocked_requests_get + + data["outcome"]["cohort_ids"] = [400] + + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 400