Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VADC-900 #123

Merged
merged 15 commits into from
Feb 7, 2024
24 changes: 24 additions & 0 deletions src/argowrapper/auth/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import requests
from argowrapper import logger


def get_cohort_ids_for_team_project(token, source_id, team_project):
header = {"Authorization": token, "cookie": "fence={}".format(token)}
url = "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={}".format(
source_id, team_project
)

try:
r = requests.get(url=url, headers=header)
r.raise_for_status()
team_cohort_info = r.json()
team_cohort_id_set = set()
if "cohort_definitions_and_stats" in team_cohort_info:
for t in team_cohort_info["cohort_definitions_and_stats"]:
if "cohort_definition_id" in t:
team_cohort_id_set.add(t["cohort_definition_id"])
return team_cohort_id_set
except Exception as e:
exception = Exception("Could not get team project cohort ids", e)
logger.error(exception)
raise exception
66 changes: 66 additions & 0 deletions src/argowrapper/routes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
HTTP_401_UNAUTHORIZED,
HTTP_403_FORBIDDEN,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_400_BAD_REQUEST,
)
from argowrapper.constants import (
TEAM_PROJECT_FIELD_NAME,
Expand All @@ -21,6 +22,8 @@
from argowrapper import logger
from argowrapper.auth import Auth
from argowrapper.engine.argo_engine import ArgoEngine
from argowrapper.auth.utils import get_cohort_ids_for_team_project

import argowrapper.engine.helpers.argo_engine_helper as argo_engine_helper

import requests
Expand Down Expand Up @@ -144,6 +147,68 @@ def wrapper(*args, **kwargs):
return wrapper


def check_team_projects_and_cohorts(fn):
"""custom annotation to make sure cohort in request belong to user's team project"""

@wraps(fn)
def wrapper(*args, **kwargs):

token = kwargs["request"].headers.get("Authorization")
request_body = kwargs["request_body"]
team_project = request_body[TEAM_PROJECT_FIELD_NAME]
source_id = request_body["source_id"]

# Construct set with all cohort ids requested
cohort_ids = []
if "cohort_ids" in request_body["outcome"]:
cohort_ids.extend(request_body["outcome"]["cohort_ids"])

variables = request_body["variables"]
for v in variables:
if "cohort_ids" in v:
cohort_ids.extend(v["cohort_ids"])

if "source_population_cohort" in request_body:
cohort_ids.append(request_body["source_population_cohort"])

cohort_id_set = set(cohort_ids)

if team_project and source_id and len(team_project) > 0 and len(cohort_ids) > 0:
# Get team project cohort ids
team_cohort_id_set = get_cohort_ids_for_team_project(
token, source_id, team_project
)

logger.debug("cohort ids are " + " ".join(str(c) for c in cohort_ids))
logger.debug(
"team cohort ids are " + " ".join(str(c) for c in team_cohort_id_set)
)

# Compare the two sets
if cohort_id_set.issubset(team_cohort_id_set):
logger.debug(
"cohort ids submitted all belong to the same team project. Continue.."
)
return fn(*args, **kwargs)
else:
logger.error(
"Cohort ids submitted do NOT all belong to the same team project."
)
return HTMLResponse(
content="Cohort ids submitted do NOT all belong to the same team project.",
status_code=HTTP_400_BAD_REQUEST,
)

else:
# some required parameters is missing, return bad request:
return HTMLResponse(
content="Missing required parameters",
status_code=HTTP_400_BAD_REQUEST,
)

return wrapper


def check_user_billing_id(request):
"""
Check whether user is non-VA user
Expand Down Expand Up @@ -217,6 +282,7 @@ def test():
# submit argo workflow
@router.post("/submit", status_code=HTTP_200_OK)
@check_auth_and_team_project
@check_team_projects_and_cohorts
def submit_workflow(
request_body: Dict[Any, Any],
request: Request, # pylint: disable=unused-argument
Expand Down
179 changes: 127 additions & 52 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,30 @@
"variables": variables,
"hare_population": "hare",
"out_prefix": "vadc_genesis",
"outcome": 1,
"outcome": {
"variable_type": "custom_dichotomous",
"cohort_ids": [2],
"provided_name": "test Pheno",
},
"maf_threshold": 0.01,
"imputation_score_cutoff": 0.3,
"template_version": "gwas-template-latest",
"source_id": 4,
"case_cohort_definition_id": 70,
"control_cohort_definition_id": -1,
"source_population_cohort": 4,
"workflow_name": "wf_name",
TEAM_PROJECT_FIELD_NAME: "dummy-team-project",
"user_tags": None, # For testing purpose
}

cohort_definition_data = {
"cohort_definitions_and_stats": [
{"cohort_definition_id": 1, "cohort_name": "Cohort 1", "size": 1},
{"cohort_definition_id": 2, "cohort_name": "Cohort 2", "size": 2},
{"cohort_definition_id": 3, "cohort_name": "Cohort 3", "size": 3},
{"cohort_definition_id": 4, "cohort_name": "Cohort 4", "size": 4},
]
}


Expand All @@ -55,17 +70,50 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]:
yield client


def mocked_requests_get(*args, **kwargs):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def raise_for_status(self):
if self.status_code == 500:
raise Exception("fence is down")
if self.status_code != 200:
raise Exception()

if (
kwargs["url"]
== "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/4/by-team-project?team-project=dummy-team-project"
):
return MockResponse(cohort_definition_data, 200)

if kwargs["url"] == "http://fence-service/user":
if data["user_tags"] != 500:
return MockResponse(data["user_tags"], 200)
else:
return MockResponse({}, 500)

return None


def test_submit_workflow(client):
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log, patch(
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id:
) as mock_check_billing_id, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
mock_check_billing_id.return_value = None
mock_requests.side_effect = mocked_requests_get

response = client.post(
"/submit",
Expand Down Expand Up @@ -466,59 +514,44 @@ def test_submit_workflow_with_user_billing_id(client):
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log:
) as mock_log, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
with patch("requests.get") as mock_request:
mock_resp = mock.Mock()
mock_resp.status_code = 200
mock_resp.raise_for_status = mock.Mock()
mock_resp.json = mock.Mock(return_value={"tags": {}})
mock_request.return_value = mock_resp
mock_requests.side_effect = mocked_requests_get

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None
data["user_tags"] = {"tags": {}}
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None

mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}})
data["user_tags"] = {"tags": {"othertag1": "tag1"}}

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 200
assert mock_engine.call_args.args[2] == None

mock_resp.json = mock.Mock(
return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}}
)
with patch(
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
) as mock_check_monthly_cap:
mock_check_monthly_cap.return_value = False
response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert mock_engine.call_args.args[2] == "1234"

mock_resp.status_code == 500
mock_resp.raise_for_status.side_effect = Exception("fence is down")
data["user_tags"] = {"tags": {"othertag1": "tag1", "billing_id": "1234"}}

with patch(
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
) as mock_check_monthly_cap:
mock_check_monthly_cap.return_value = False
response = client.post(
"/submit",
data=json.dumps(data),
Expand All @@ -527,8 +560,19 @@ def test_submit_workflow_with_user_billing_id(client):
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 500
assert "fence is down" in str(response.content)
assert mock_engine.call_args.args[2] == "1234"

data["user_tags"] == 500

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 500


def test_check_user_reached_monthly_workflow_cap():
Expand Down Expand Up @@ -566,11 +610,14 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id, patch(
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
) as mock_check_monthly_cap:
) as mock_check_monthly_cap, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
mock_check_billing_id.return_value = "1234"
mock_check_monthly_cap.return_value = True
mock_requests.side_effect = mocked_requests_get

response = client.post(
"/submit",
Expand All @@ -581,3 +628,31 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
},
)
assert response.status_code == 403


def test_submit_workflow_with_non_team_project_cohort(client):
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
"argowrapper.routes.routes.argo_engine.workflow_submission"
) as mock_engine, patch(
"argowrapper.routes.routes.log_auth_check_type"
) as mock_log, patch(
"argowrapper.routes.routes.check_user_billing_id"
) as mock_check_billing_id, patch(
"requests.get"
) as mock_requests:
mock_auth.return_value = True
mock_engine.return_value = "workflow_123"
mock_check_billing_id.return_value = None
mock_requests.side_effect = mocked_requests_get

data["outcome"]["cohort_ids"] = [400]

response = client.post(
"/submit",
data=json.dumps(data),
headers={
"Content-Type": "application/json",
"Authorization": EXAMPLE_AUTH_HEADER,
},
)
assert response.status_code == 400
Loading