diff --git a/poetry.lock b/poetry.lock index e3a0873a..8e9330d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -129,14 +129,14 @@ files = [ [[package]] name = "certifi" -version = "2023.7.22" +version = "2023.11.17" description = "Python package for providing Mozilla's CA Bundle." category = "main" optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, - {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, + {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"}, + {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, ] [[package]] @@ -336,14 +336,14 @@ toml = ["tomli"] [[package]] name = "exceptiongroup" -version = "1.1.3" +version = "1.2.0" description = "Backport of PEP 654 (exception groups)" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, - {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, + {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, + {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, ] [package.extras] @@ -371,6 +371,21 @@ dev = ["autoflake (>=1.4.0,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "passlib[bcrypt] doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<6.0.0)", "typer-cli (>=0.0.12,<0.0.13)"] test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==21.9b0)", "databases[sqlite] (>=0.3.2,<0.6.0)", "email_validator (>=1.1.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.14.0,<0.19.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "orjson (>=3.2.1,<4.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=6.2.4,<7.0.0)", "pytest-cov (>=2.12.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "requests (>=2.24.0,<3.0.0)", "sqlalchemy (>=1.3.18,<1.5.0)", "types-dataclasses (==0.1.7)", "types-orjson (==3.6.0)", "types-ujson (==0.1.1)", "ujson (>=4.0.1,<5.0.0)"] +[[package]] +name = "freezegun" +version = "1.2.2" +description = "Let your Python tests travel through time" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "freezegun-1.2.2-py3-none-any.whl", hash = "sha256:ea1b963b993cb9ea195adbd893a48d573fda951b0da64f60883d7e988b606c9f"}, + {file = "freezegun-1.2.2.tar.gz", hash = "sha256:cd22d1ba06941384410cd967d8a99d5ae2442f57dfafeff2fda5de8dc5c05446"}, +] + +[package.dependencies] +python-dateutil = ">=2.7" + [[package]] name = "gen3authz" version = "1.5.1" @@ -424,14 +439,14 @@ files = [ [[package]] name = "httpcore" -version = "1.0.1" +version = "1.0.2" description = "A minimal low-level HTTP client." category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.1-py3-none-any.whl", hash = "sha256:c5e97ef177dca2023d0b9aad98e49507ef5423e9f1d94ffe2cfe250aa28e63b0"}, - {file = "httpcore-1.0.1.tar.gz", hash = "sha256:fce1ddf9b606cfb98132ab58865c3728c52c8e4c3c46e2aabb3674464a186e92"}, + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, ] [package.dependencies] @@ -804,18 +819,18 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "setuptools" -version = "68.2.2" +version = "69.0.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, - {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, + {file = "setuptools-69.0.1-py3-none-any.whl", hash = "sha256:6875bbd06382d857b1b90cd07cee6a2df701a164f241095706b5192bc56c5c62"}, + {file = "setuptools-69.0.1.tar.gz", hash = "sha256:f25195d54deb649832182d6455bffba7ac3d8fe71d35185e738d2198a4310044"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] @@ -899,19 +914,18 @@ files = [ [[package]] name = "urllib3" -version = "2.0.7" +version = "2.1.0" description = "HTTP library with thread-safe connection pooling, file post, and more." category = "main" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"}, - {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"}, + {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, + {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -954,4 +968,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "b4e2b4516d3304b53ec3b8a4870aab4cf3413f386bfd93b40009541575aed69b" +content-hash = "c7d58e155f14e5d67b2aeffe25509b52c176a02e2c1666869d923bdd7a3b6fe7" diff --git a/pyproject.toml b/pyproject.toml index 0220dfae..189e8ba6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "argowrapper" -version = "1.8.0" +version = "1.8.1" description = "argo wrapper for va workflow" authors = ["UchicagoZchen138 "] repository = "https://github.com/uc-cdis/argo-wrapper" @@ -21,10 +21,12 @@ importlib-resources = "^5.4.0" requests = "^2.27.1" PyJWT = "^2.4.0" pytest-cov = "^3.0.0" +freezegun = "^1.2.2" [tool.poetry.dev-dependencies] pytest = "^6.2.5" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/src/argowrapper/constants.py b/src/argowrapper/constants.py index f62536db..d15fa642 100644 --- a/src/argowrapper/constants.py +++ b/src/argowrapper/constants.py @@ -37,6 +37,9 @@ WORKFLOW_KIND: Final = "workflow" GEN3_USER_METADATA_LABEL: Final = "gen3username" GEN3_TEAM_PROJECT_METADATA_LABEL: Final = "gen3teamproject" +GEN3_WORKFLOW_PHASE_LABEL: Final = "phase" +GEN3_SUBMIT_TIMESTAMP_LABEL: Final = "submittedAt" +GEN3_NON_VA_WORKFLOW_MONTHLY_CAP: Final = 20 class POD_COMPLETION_STRATEGY(Enum): diff --git a/src/argowrapper/engine/argo_engine.py b/src/argowrapper/engine/argo_engine.py index 6c1366a9..aa4b13bc 100644 --- a/src/argowrapper/engine/argo_engine.py +++ b/src/argowrapper/engine/argo_engine.py @@ -29,11 +29,15 @@ WORKFLOW, GEN3_USER_METADATA_LABEL, GEN3_TEAM_PROJECT_METADATA_LABEL, + GEN3_WORKFLOW_PHASE_LABEL, + GEN3_SUBMIT_TIMESTAMP_LABEL, ) from argowrapper.engine.helpers import argo_engine_helper from argowrapper.engine.helpers.workflow_factory import WorkflowFactory from argowrapper.workflows.argo_workflows.gwas import GWAS +from datetime import datetime + class ArgoEngine: """ @@ -378,6 +382,40 @@ def get_workflows_for_user(self, auth_header: str) -> List[Dict]: user_only_workflows.append(workflow) return user_only_workflows + def get_user_workflows_for_current_month(self, auth_header: str) -> List[Dict]: + """ + Get the list of all succeeded and running workflows the current user owns in the current month. + Each item in the list contains the workflow name, its status, start and end time. + + Args: + auth_header: authorization header that contains the user's jwt token + + Returns: + List[Dict]: List of workflow dictionaries with details of workflows + that the user has ran. + + Raises: + raises Exception in case of any error. + """ + username = argo_engine_helper.get_username_from_token(auth_header) + user_label = argo_engine_helper.convert_gen3username_to_pod_label(username) + label_selector = f"{GEN3_USER_METADATA_LABEL}={user_label}" + all_user_workflows = self.get_workflows_for_label_selector( + label_selector=label_selector + ) + user_monthly_workflows = [] + for workflow in all_user_workflows: + if workflow[GEN3_WORKFLOW_PHASE_LABEL] in {"Running", "Succeeded"}: + submitted_time_str = workflow[GEN3_SUBMIT_TIMESTAMP_LABEL] + submitted_time = datetime.strptime( + submitted_time_str, "%Y-%m-%dT%H:%M:%SZ" + ) + first_day_of_month = datetime.today().replace(day=1) + if submitted_time.date() >= first_day_of_month.date(): + user_monthly_workflows.append(workflow) + + return user_monthly_workflows + def get_workflows_for_label_selector(self, label_selector: str) -> List[Dict]: try: workflow_list_return = self.api_instance.list_workflows( diff --git a/src/argowrapper/routes/routes.py b/src/argowrapper/routes/routes.py index 38019d44..dfc1a7a5 100644 --- a/src/argowrapper/routes/routes.py +++ b/src/argowrapper/routes/routes.py @@ -7,6 +7,7 @@ from starlette.status import ( HTTP_200_OK, HTTP_401_UNAUTHORIZED, + HTTP_403_FORBIDDEN, HTTP_500_INTERNAL_SERVER_ERROR, ) from argowrapper.constants import ( @@ -14,6 +15,7 @@ TEAM_PROJECT_LIST_FIELD_NAME, GEN3_TEAM_PROJECT_METADATA_LABEL, GEN3_USER_METADATA_LABEL, + GEN3_NON_VA_WORKFLOW_MONTHLY_CAP, ) from argowrapper import logger @@ -152,6 +154,7 @@ def check_user_billing_id(request): """ header = {"Authorization": request.headers.get("Authorization")} + # TODO: Make this configurable url = "http://fence-service/user" try: r = requests.get(url=url, headers=header) @@ -173,6 +176,38 @@ def check_user_billing_id(request): return None +def check_user_reached_monthly_workflow_cap(request_token): + """ + Query Argo service to see how many successful run user already + have in the current calendar month. If the number is greater than + the threshold, return error. + """ + + try: + current_month_workflows = argo_engine.get_user_workflows_for_current_month( + request_token + ) + + if len(current_month_workflows) >= GEN3_NON_VA_WORKFLOW_MONTHLY_CAP: + logger.info( + "User already executed {} workflows this month and cannot create new ones anymore.".format( + len(current_month_workflows) + ) + ) + logger.info( + "The currently monthly cap is {}.".format( + GEN3_NON_VA_WORKFLOW_MONTHLY_CAP + ) + ) + return True + + return False + except Exception as e: + logger.error(e) + traceback.print_exc() + raise e + + @router.get("/test") def test(): """route to test that the argo-workflow is correctly running""" @@ -188,13 +223,27 @@ def submit_workflow( ) -> str: """route to submit workflow""" try: + reached_monthly_cap = False + # check if user has a billing id tag: billing_id = check_user_billing_id(request) - # submit workflow: - return argo_engine.workflow_submission( - request_body, request.headers.get("Authorization"), billing_id - ) + # if user has billing_id (non-VA user), check if they already reached the monthly cap + if billing_id: + reached_monthly_cap = check_user_reached_monthly_workflow_cap( + request.headers.get("Authorization") + ) + + # submit workflow: + if not reached_monthly_cap: + return argo_engine.workflow_submission( + request_body, request.headers.get("Authorization"), billing_id + ) + else: + return HTMLResponse( + content="You have reached the workflow monthly cap.", + status_code=HTTP_403_FORBIDDEN, + ) except Exception as exception: return HTMLResponse( content=str(exception), diff --git a/test/test_argo_engine.py b/test/test_argo_engine.py index d3a5b18b..62cbdd39 100644 --- a/test/test_argo_engine.py +++ b/test/test_argo_engine.py @@ -11,6 +11,7 @@ from test.constants import EXAMPLE_AUTH_HEADER from argowrapper.workflows.argo_workflows.gwas import * from unittest.mock import patch +from freezegun import freeze_time class WorkFlow: @@ -777,3 +778,53 @@ def test_get_archived_workflow_wf_name_and_team_project(): ) = engine._get_archived_workflow_wf_name_and_team_project("dummy_uid") assert given_name == "dummy_wf_name" assert team_project == "dummy_team_project_label" + + +@freeze_time("Nov 16th, 2023") +def test_get_user_workflows_for_current_month(monkeypatch): + + engine = ArgoEngine() + workflows_mock_response = [ + { + "uid": "uid_1", + "phase": "Running", + "submittedAt": "2023-11-14T16:44:02Z", + }, + { + "uid": "uid_2", + "phase": "Succeeded", + "submittedAt": "2023-11-15T17:52:52Z", + }, + { + "uid": "uid_3", + "phase": "Failed", + "submittedAt": "2023-11-02T00:00:00Z", + }, + { + "uid": "uid_4", + "phase": "Succeeded", + "submittedAt": "2023-10-31T00:00:00Z", + }, + ] + + expected_workflow_reponse = [ + { + "uid": "uid_1", + "phase": "Running", + "submittedAt": "2023-11-14T16:44:02Z", + }, + { + "uid": "uid_2", + "phase": "Succeeded", + "submittedAt": "2023-11-15T17:52:52Z", + }, + ] + engine.get_workflows_for_label_selector = mock.MagicMock( + return_value=workflows_mock_response + ) + + user_monthly_workflow = engine.get_user_workflows_for_current_month( + EXAMPLE_AUTH_HEADER + ) + + assert user_monthly_workflow == expected_workflow_reponse diff --git a/test/test_routes.py b/test/test_routes.py index 365b3a57..324fb6f9 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -8,7 +8,8 @@ from fastapi.testclient import TestClient from argowrapper.constants import * from test.constants import EXAMPLE_AUTH_HEADER -from argowrapper.routes.routes import router +from argowrapper.routes.routes import router, check_user_reached_monthly_workflow_cap +from argowrapper.constants import GEN3_NON_VA_WORKFLOW_MONTHLY_CAP variables = [ {"variable_type": "concept", "concept_id": "2000000324"}, @@ -460,7 +461,7 @@ def test_if_endpoints_are_set_to_the_right_check_auth(client): mock_log.assert_called_with("check_auth") -def test_check_user_billing_id(client): +def test_submit_workflow_with_user_billing_id(client): with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch( "argowrapper.routes.routes.argo_engine.workflow_submission" ) as mock_engine, patch( @@ -502,16 +503,19 @@ def test_check_user_billing_id(client): mock_resp.json = mock.Mock( return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}} ) - - response = client.post( - "/submit", - data=json.dumps(data), - headers={ - "Content-Type": "application/json", - "Authorization": EXAMPLE_AUTH_HEADER, - }, - ) - assert mock_engine.call_args.args[2] == "1234" + with patch( + "argowrapper.routes.routes.check_user_reached_monthly_workflow_cap" + ) as mock_check_monthly_cap: + mock_check_monthly_cap.return_value = False + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert mock_engine.call_args.args[2] == "1234" mock_resp.status_code == 500 mock_resp.raise_for_status.side_effect = Exception("fence is down") @@ -525,3 +529,55 @@ def test_check_user_billing_id(client): ) assert response.status_code == 500 assert "fence is down" in str(response.content) + + +def test_check_user_reached_monthly_workflow_cap(): + headers = { + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + } + + with patch( + "argowrapper.engine.argo_engine.ArgoEngine.get_user_workflows_for_current_month" + ) as mock_get_workflow: + mock_get_workflow.return_value = [ + {"wf_name": "workflow1"}, + {"wf_name": "workflow2"}, + ] + assert ( + check_user_reached_monthly_workflow_cap(headers["Authorization"]) == False + ) + + workflows = [] + for index in range(GEN3_NON_VA_WORKFLOW_MONTHLY_CAP + 1): + workflows.append({"wf_name": "workflow" + str(index)}) + + mock_get_workflow.return_value = workflows + + assert check_user_reached_monthly_workflow_cap(headers["Authorization"]) == True + + +def test_submit_workflow_with_billing_id_and_over_monthly_cap(client): + with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch( + "argowrapper.routes.routes.argo_engine.workflow_submission" + ) as mock_engine, patch( + "argowrapper.routes.routes.log_auth_check_type" + ) as mock_log, patch( + "argowrapper.routes.routes.check_user_billing_id" + ) as mock_check_billing_id, patch( + "argowrapper.routes.routes.check_user_reached_monthly_workflow_cap" + ) as mock_check_monthly_cap: + mock_auth.return_value = True + mock_engine.return_value = "workflow_123" + mock_check_billing_id.return_value = "1234" + mock_check_monthly_cap.return_value = True + + response = client.post( + "/submit", + data=json.dumps(data), + headers={ + "Content-Type": "application/json", + "Authorization": EXAMPLE_AUTH_HEADER, + }, + ) + assert response.status_code == 403