From eaf556dbe94848c9d7781e08ff4b96585606b671 Mon Sep 17 00:00:00 2001 From: Nick Frasser <1693461+nfrasser@users.noreply.github.com> Date: Thu, 14 Mar 2024 10:51:18 -0400 Subject: [PATCH] Feat: add cluster_vars to job.queue (#79) --- cryosparc/job.py | 19 ++++++++++++++-- tests/conftest.py | 2 ++ tests/test_job.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/cryosparc/job.py b/cryosparc/job.py index 15ff6bc9..4c24243c 100644 --- a/cryosparc/job.py +++ b/cryosparc/job.py @@ -7,7 +7,7 @@ from io import BytesIO from pathlib import PurePath, PurePosixPath from time import sleep, time -from typing import IO, TYPE_CHECKING, Any, Iterable, List, Optional, Pattern, Union, overload +from typing import IO, TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Pattern, Union, overload from typing_extensions import Literal @@ -116,7 +116,13 @@ def dir(self) -> PurePosixPath: """ return PurePosixPath(self.cs.cli.get_job_dir_abs(self.project_uid, self.uid)) # type: ignore - def queue(self, lane: Optional[str] = None, hostname: Optional[str] = None, gpus: List[int] = []): + def queue( + self, + lane: Optional[str] = None, + hostname: Optional[str] = None, + gpus: List[int] = [], + cluster_vars: Dict[str, Any] = {}, + ): """ Queue a job to a target lane. Available lanes may be queried with `CryoSPARC.get_lanes`_. @@ -136,6 +142,9 @@ def queue(self, lane: Optional[str] = None, hostname: Optional[str] = None, gpus gpus (list[int], optional): GPUs to queue to. If specified, must have as many GPUs as required in job parameters. Leave unspecified to use first available GPU(s). Defaults to []. + cluster_vars (dict[str, Any], optional): Specify custom cluster + variables when queuing to a cluster. Keys are variable names. + Defaults to False. Examples: @@ -154,6 +163,12 @@ def queue(self, lane: Optional[str] = None, hostname: Optional[str] = None, gpus .. _CryoSPARC.get_targets: tools.html#cryosparc.tools.CryoSPARC.get_targets """ + if cluster_vars: + self.cs.cli.set_cluster_job_custom_vars( # type: ignore + project_uid=self.project_uid, + job_uid=self.uid, + cluster_job_custom_vars=cluster_vars, + ) self.cs.cli.enqueue_job( # type: ignore project_uid=self.project_uid, job_uid=self.uid, diff --git a/tests/conftest.py b/tests/conftest.py index 888718d3..83cce371 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -248,6 +248,8 @@ def request_callback_core(request, uri, response_headers): "get_project_dir_abs": "/projects/my-project", "get_project": {"uid": "P1", "title": "My Project"}, "make_job": "J1", + "set_cluster_job_custom_vars": None, + "enqueue_job": "queued", "job_send_streamlog": None, "job_connect_group": True, "job_set_param": True, diff --git a/tests/test_job.py b/tests/test_job.py index 34697be8..5ca001ce 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -13,6 +13,62 @@ def job(project: Project): return project.find_job("J1") +def test_queue(job: Job): + job.queue() + queue_request = httpretty.latest_requests()[-3] + refresh_request = httpretty.latest_requests()[-1] + assert queue_request.parsed_body["method"] == "enqueue_job" + assert queue_request.parsed_body["params"] == { + "project_uid": job.project_uid, + "job_uid": job.uid, + "lane": None, + "user_id": job.cs.user_id, + "hostname": None, + "gpus": False, + } + assert refresh_request.parsed_body["method"] == "get_job" + + +def test_queue_worker(job: Job): + job.queue(lane="workers", hostname="worker1", gpus=[1]) + queue_request = httpretty.latest_requests()[-3] + refresh_request = httpretty.latest_requests()[-1] + assert queue_request.parsed_body["method"] == "enqueue_job" + assert queue_request.parsed_body["params"] == { + "project_uid": job.project_uid, + "job_uid": job.uid, + "lane": "workers", + "user_id": job.cs.user_id, + "hostname": "worker1", + "gpus": [1], + } + assert refresh_request.parsed_body["method"] == "get_job" + + +def test_queue_cluster(job: Job): + vars = {"var1": 42, "var2": "test"} + job.queue(lane="cluster", cluster_vars=vars) + vars_request = httpretty.latest_requests()[-5] + queue_request = httpretty.latest_requests()[-3] + refresh_request = httpretty.latest_requests()[-1] + assert vars_request.parsed_body["method"] == "set_cluster_job_custom_vars" + assert vars_request.parsed_body["params"] == { + "project_uid": job.project_uid, + "job_uid": job.uid, + "cluster_job_custom_vars": vars, + } + assert queue_request.parsed_body["method"] == "enqueue_job" + assert queue_request.parsed_body["params"] == { + "project_uid": job.project_uid, + "job_uid": job.uid, + "lane": "cluster", + "user_id": job.cs.user_id, + "hostname": None, + "gpus": False, + } + assert refresh_request.parsed_body["method"] == "get_job" + + def test_load_output_all_slots(job: Job): output = job.load_output("particles_class_0") assert set(output.prefixes()) == {"location", "blob", "ctf"}