Skip to content

Commit

Permalink
Merge branch 'master' into feat/argo_ns_no_debug_prints
Browse files Browse the repository at this point in the history
  • Loading branch information
m0nhawk authored Feb 21, 2024
2 parents 8f4f732 + 7842c2f commit 8dcba15
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 52 deletions.
1 change: 1 addition & 0 deletions config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
ARGO_ACCESS_METHOD = access
ARGO_HOST = http://argo-argo-workflows-server.argo.svc.cluster.local:2746
ARGO_NAMESPACE = argo
COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL = http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={}
26 changes: 26 additions & 0 deletions src/argowrapper/auth/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import requests
from argowrapper import logger
from argowrapper.constants import (
COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL,
)


def get_cohort_ids_for_team_project(token, source_id, team_project):
header = {"Authorization": token, "cookie": "fence={}".format(token)}
url = COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL.format(
source_id, team_project
)
try:
r = requests.get(url=url, headers=header)
r.raise_for_status()
team_cohort_info = r.json()
team_cohort_id_set = set()
if "cohort_definitions_and_stats" in team_cohort_info:
for t in team_cohort_info["cohort_definitions_and_stats"]:
if "cohort_definition_id" in t:
team_cohort_id_set.add(t["cohort_definition_id"])
return team_cohort_id_set
except Exception as e:
exception = Exception("Could not get team project cohort ids", e)
logger.error(exception)
raise exception
3 changes: 3 additions & 0 deletions src/argowrapper/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
logger.info(f"Access method: {config['DEFAULT']['ARGO_ACCESS_METHOD']}")

ARGO_HOST: Final = config["DEFAULT"]["ARGO_HOST"]
COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL: Final = config["DEFAULT"][
"COHORT_DEFINITION_BY_SOURCE_AND_TEAM_PROJECT_URL"
]
TEST_WF: Final = "test.yaml"
WF_HEADER: Final = "header.yaml"
ARGO_NAMESPACE: Final = config["DEFAULT"]["ARGO_NAMESPACE"]
Expand Down
66 changes: 66 additions & 0 deletions src/argowrapper/routes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
HTTP_401_UNAUTHORIZED,
HTTP_403_FORBIDDEN,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_400_BAD_REQUEST,
)
from argowrapper.constants import (
TEAM_PROJECT_FIELD_NAME,
Expand All @@ -21,6 +22,8 @@
from argowrapper import logger
from argowrapper.auth import Auth
from argowrapper.engine.argo_engine import ArgoEngine
from argowrapper.auth.utils import get_cohort_ids_for_team_project

import argowrapper.engine.helpers.argo_engine_helper as argo_engine_helper

import requests
Expand Down Expand Up @@ -144,6 +147,68 @@ def wrapper(*args, **kwargs):
return wrapper


def check_team_projects_and_cohorts(fn):
"""custom annotation to make sure cohort in request belong to user's team project"""

@wraps(fn)
def wrapper(*args, **kwargs):

token = kwargs["request"].headers.get("Authorization")
request_body = kwargs["request_body"]
team_project = request_body[TEAM_PROJECT_FIELD_NAME]
source_id = request_body["source_id"]

# Construct set with all cohort ids requested
cohort_ids = []
if "cohort_ids" in request_body["outcome"]:
cohort_ids.extend(request_body["outcome"]["cohort_ids"])

variables = request_body["variables"]
for v in variables:
if "cohort_ids" in v:
cohort_ids.extend(v["cohort_ids"])

if "source_population_cohort" in request_body:
cohort_ids.append(request_body["source_population_cohort"])

cohort_id_set = set(cohort_ids)

if team_project and source_id and len(team_project) > 0 and len(cohort_ids) > 0:
# Get team project cohort ids
team_cohort_id_set = get_cohort_ids_for_team_project(
token, source_id, team_project
)

logger.debug("cohort ids are " + " ".join(str(c) for c in cohort_ids))
logger.debug(
"team cohort ids are " + " ".join(str(c) for c in team_cohort_id_set)
)

# Compare the two sets
if cohort_id_set.issubset(team_cohort_id_set):
logger.debug(
"cohort ids submitted all belong to the same team project. Continue.."
)
return fn(*args, **kwargs)
else:
logger.error(
"Cohort ids submitted do NOT all belong to the same team project."
)
return HTMLResponse(
content="Cohort ids submitted do NOT all belong to the same team project.",
status_code=HTTP_400_BAD_REQUEST,
)

else:
# some required parameters is missing, return bad request:
return HTMLResponse(
content="Missing required parameters",
status_code=HTTP_400_BAD_REQUEST,
)

return wrapper


def check_user_billing_id(request):
"""
Check whether user is non-VA user
Expand Down Expand Up @@ -217,6 +282,7 @@ def test():
# submit argo workflow
@router.post("/submit", status_code=HTTP_200_OK)
@check_auth_and_team_project
@check_team_projects_and_cohorts
def submit_workflow(
request_body: Dict[Any, Any],
request: Request, # pylint: disable=unused-argument
Expand Down
179 changes: 127 additions & 52 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,30 @@
"variables": variables,
"hare_population": "hare",
"out_prefix": "vadc_genesis",
"outcome": 1,
"outcome": {
"variable_type": "custom_dichotomous",
"cohort_ids": [2],
"provided_name": "test Pheno",
},
"maf_threshold": 0.01,
"imputation_score_cutoff": 0.3,
"template_version": "gwas-template-latest",
"source_id": 4,
"case_cohort_definition_id": 70,
"control_cohort_definition_id": -1,
"source_population_cohort": 4,
"workflow_name": "wf_name",
TEAM_PROJECT_FIELD_NAME: "dummy-team-project",
"user_tags": None, # For testing purpose
}

cohort_definition_data = {
"cohort_definitions_and_stats": [
{"cohort_definition_id": 1, "cohort_name": "Cohort 1", "size": 1},
{"cohort_definition_id": 2, "cohort_name": "Cohort 2", "size": 2},
{"cohort_definition_id": 3, "cohort_name": "Cohort 3", "size": 3},
{"cohort_definition_id": 4, "cohort_name": "Cohort 4", "size": 4},
]
}


Expand All @@ -55,17 +70,50 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]:
yield client


def mocked_requests_get(*args, **kwargs):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def raise_for_status(self):
if self.status_code == 500:
raise Exception("fence is down")
if self.status_code != 200:
raise Exception()

if (
kwargs["url"]
== "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/4/by-team-project?team-project=dummy-team-project"
):
return MockResponse(cohort_definition_data, 200)

if kwargs["url"] == "http://fence-service/user":
if data["user_tags"] != 500:
return MockResponse(data["user_tags"], 200)
else:
return MockResponse({}, 500)

return None


def test_submit_workflow(client):
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log, patch(
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id:
) as mock_check_billing_id, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
mock_check_billing_id.return_value = None
mock_requests.side_effect = mocked_requests_get

response = client.post(
"/submit",
Expand Down Expand Up @@ -466,59 +514,44 @@ def test_submit_workflow_with_user_billing_id(client):
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log:
) as mock_log, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
with patch("requests.get") as mock_request:
mock_resp = mock.Mock()
mock_resp.status_code = 200
mock_resp.raise_for_status = mock.Mock()
mock_resp.json = mock.Mock(return_value={"tags": {}})
mock_request.return_value = mock_resp
mock_requests.side_effect = mocked_requests_get

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None
data["user_tags"] = {"tags": {}}
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None

mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}})
data["user_tags"] = {"tags": {"othertag1": "tag1"}}

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None

mock_resp.json = mock.Mock(
return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}}
)
with patch(
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
) as mock_check_monthly_cap:
mock_check_monthly_cap.return_value = False
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert mock_engine.call_args.args[2] == "1234"

mock_resp.status_code == 500
mock_resp.raise_for_status.side_effect = Exception("fence is down")
data["user_tags"] = {"tags": {"othertag1": "tag1", "billing_id": "1234"}}

with patch(
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
) as mock_check_monthly_cap:
mock_check_monthly_cap.return_value = False
response = client.post(
"/submit",
data=json.dumps(data),
Expand All @@ -527,8 +560,19 @@ def test_submit_workflow_with_user_billing_id(client):
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 500
assert "fence is down" in str(response.content)
assert mock_engine.call_args.args[2] == "1234"

data["user_tags"] == 500

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 500


def test_check_user_reached_monthly_workflow_cap():
Expand Down Expand Up @@ -566,11 +610,14 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id, patch(
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
) as mock_check_monthly_cap:
) as mock_check_monthly_cap, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
mock_check_billing_id.return_value = "1234"
mock_check_monthly_cap.return_value = True
mock_requests.side_effect = mocked_requests_get

response = client.post(
"/submit",
Expand All @@ -581,3 +628,31 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
},
)
assert response.status_code == 403


def test_submit_workflow_with_non_team_project_cohort(client):
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log, patch(
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
mock_check_billing_id.return_value = None
mock_requests.side_effect = mocked_requests_get

data["outcome"]["cohort_ids"] = [400]

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 400

0 comments on commit 8dcba15

Please sign in to comment.