Skip to content

Commit

Permalink
Feat: add cluster_vars to job.queue (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfrasser authored Mar 14, 2024
1 parent 9644af4 commit eaf556d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
19 changes: 17 additions & 2 deletions cryosparc/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from io import BytesIO
from pathlib import PurePath, PurePosixPath
from time import sleep, time
from typing import IO, TYPE_CHECKING, Any, Iterable, List, Optional, Pattern, Union, overload
from typing import IO, TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Pattern, Union, overload

from typing_extensions import Literal

Expand Down Expand Up @@ -116,7 +116,13 @@ def dir(self) -> PurePosixPath:
"""
return PurePosixPath(self.cs.cli.get_job_dir_abs(self.project_uid, self.uid)) # type: ignore

def queue(self, lane: Optional[str] = None, hostname: Optional[str] = None, gpus: List[int] = []):
def queue(
self,
lane: Optional[str] = None,
hostname: Optional[str] = None,
gpus: List[int] = [],
cluster_vars: Dict[str, Any] = {},
):
"""
Queue a job to a target lane. Available lanes may be queried with
`CryoSPARC.get_lanes`_.
Expand All @@ -136,6 +142,9 @@ def queue(self, lane: Optional[str] = None, hostname: Optional[str] = None, gpus
gpus (list[int], optional): GPUs to queue to. If specified, must
have as many GPUs as required in job parameters. Leave
unspecified to use first available GPU(s). Defaults to [].
cluster_vars (dict[str, Any], optional): Specify custom cluster
variables when queuing to a cluster. Keys are variable names.
Defaults to False.
Examples:
Expand All @@ -154,6 +163,12 @@ def queue(self, lane: Optional[str] = None, hostname: Optional[str] = None, gpus
.. _CryoSPARC.get_targets:
tools.html#cryosparc.tools.CryoSPARC.get_targets
"""
if cluster_vars:
self.cs.cli.set_cluster_job_custom_vars( # type: ignore
project_uid=self.project_uid,
job_uid=self.uid,
cluster_job_custom_vars=cluster_vars,
)
self.cs.cli.enqueue_job( # type: ignore
project_uid=self.project_uid,
job_uid=self.uid,
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def request_callback_core(request, uri, response_headers):
"get_project_dir_abs": "/projects/my-project",
"get_project": {"uid": "P1", "title": "My Project"},
"make_job": "J1",
"set_cluster_job_custom_vars": None,
"enqueue_job": "queued",
"job_send_streamlog": None,
"job_connect_group": True,
"job_set_param": True,
Expand Down
56 changes: 56 additions & 0 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,62 @@ def job(project: Project):
return project.find_job("J1")


def test_queue(job: Job):
job.queue()
queue_request = httpretty.latest_requests()[-3]
refresh_request = httpretty.latest_requests()[-1]
assert queue_request.parsed_body["method"] == "enqueue_job"
assert queue_request.parsed_body["params"] == {
"project_uid": job.project_uid,
"job_uid": job.uid,
"lane": None,
"user_id": job.cs.user_id,
"hostname": None,
"gpus": False,
}
assert refresh_request.parsed_body["method"] == "get_job"


def test_queue_worker(job: Job):
job.queue(lane="workers", hostname="worker1", gpus=[1])
queue_request = httpretty.latest_requests()[-3]
refresh_request = httpretty.latest_requests()[-1]
assert queue_request.parsed_body["method"] == "enqueue_job"
assert queue_request.parsed_body["params"] == {
"project_uid": job.project_uid,
"job_uid": job.uid,
"lane": "workers",
"user_id": job.cs.user_id,
"hostname": "worker1",
"gpus": [1],
}
assert refresh_request.parsed_body["method"] == "get_job"


def test_queue_cluster(job: Job):
vars = {"var1": 42, "var2": "test"}
job.queue(lane="cluster", cluster_vars=vars)
vars_request = httpretty.latest_requests()[-5]
queue_request = httpretty.latest_requests()[-3]
refresh_request = httpretty.latest_requests()[-1]
assert vars_request.parsed_body["method"] == "set_cluster_job_custom_vars"
assert vars_request.parsed_body["params"] == {
"project_uid": job.project_uid,
"job_uid": job.uid,
"cluster_job_custom_vars": vars,
}
assert queue_request.parsed_body["method"] == "enqueue_job"
assert queue_request.parsed_body["params"] == {
"project_uid": job.project_uid,
"job_uid": job.uid,
"lane": "cluster",
"user_id": job.cs.user_id,
"hostname": None,
"gpus": False,
}
assert refresh_request.parsed_body["method"] == "get_job"


def test_load_output_all_slots(job: Job):
output = job.load_output("particles_class_0")
assert set(output.prefixes()) == {"location", "blob", "ctf"}
Expand Down

0 comments on commit eaf556d

Please sign in to comment.