diff --git a/src/jobflow_remote/cli/flow.py b/src/jobflow_remote/cli/flow.py index 54d6cec6..43f000da 100644 --- a/src/jobflow_remote/cli/flow.py +++ b/src/jobflow_remote/cli/flow.py @@ -68,7 +68,7 @@ def flows_list( flows_info = jc.get_flows_info( job_ids=job_id, db_ids=db_id, - flow_id=flow_id, + flow_ids=flow_id, state=state, start_date=start_date, end_date=end_date, @@ -112,7 +112,7 @@ def delete( flows_info = jc.get_flows_info( job_ids=job_id, db_ids=db_id, - flow_id=flow_id, + flow_ids=flow_id, state=state, start_date=start_date, end_date=end_date, diff --git a/src/jobflow_remote/cli/utils.py b/src/jobflow_remote/cli/utils.py index 398c8082..5360eb08 100644 --- a/src/jobflow_remote/cli/utils.py +++ b/src/jobflow_remote/cli/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import uuid from contextlib import contextmanager from enum import Enum @@ -121,6 +122,7 @@ def get_job_db_ids(job_db_id: str, job_index: int | None): except ValueError: db_id = None job_id = job_db_id + check_valid_uuid(job_id) if job_index and db_id is not None: out_console.print( @@ -143,6 +145,7 @@ def get_job_ids_indexes(job_ids: list[str] | None) -> list[tuple[str, int]] | No "(e.g. e1d66c4f-81db-4fff-bda2-2bf1d79d5961:2). " f"Wrong format for {j}" ) + check_valid_uuid(split[0]) job_ids_indexes.append((split[0], int(split[1]))) return job_ids_indexes @@ -170,3 +173,14 @@ def wrapper(*args, **kwargs): ) return wrapper + + +def check_valid_uuid(uuid_str): + try: + uuid_obj = uuid.UUID(uuid_str) + if str(uuid_obj) == uuid_str: + return + except ValueError: + pass + + raise typer.BadParameter(f"UUID {uuid_str} is in the wrong format.") diff --git a/src/jobflow_remote/fireworks/launchpad.py b/src/jobflow_remote/fireworks/launchpad.py index 787a6b06..d94557c4 100644 --- a/src/jobflow_remote/fireworks/launchpad.py +++ b/src/jobflow_remote/fireworks/launchpad.py @@ -345,6 +345,12 @@ def generate_id_query( ) -> tuple[dict, list | None]: query: dict = {} sort: list | None = None + + if (job_id is None) == (fw_id is None): + raise ValueError( + "One and only one among job_id and db_id should be defined" + ) + if fw_id: query["fw_id"] = fw_id if job_id: @@ -363,8 +369,10 @@ def _check_ids( job_id: str | None = None, job_index: int | None = None, ): - if job_id is None and fw_id is None: - raise ValueError("At least one among fw_id and job_id should be defined") + if (job_id is None) == (fw_id is None): + raise ValueError( + "One and only one among fw_id and job_id should be defined" + ) if job_id: fw_id = self.get_fw_id_from_job_id(job_id, job_index) return fw_id, job_id diff --git a/src/jobflow_remote/jobs/jobcontroller.py b/src/jobflow_remote/jobs/jobcontroller.py index ffc6642a..a36a1cff 100644 --- a/src/jobflow_remote/jobs/jobcontroller.py +++ b/src/jobflow_remote/jobs/jobcontroller.py @@ -121,7 +121,7 @@ def _build_query_wf( self, job_ids: str | list[str] | None = None, db_ids: int | list[int] | None = None, - flow_id: str | None = None, + flow_ids: str | None = None, state: FlowState | None = None, start_date: datetime | None = None, end_date: datetime | None = None, @@ -139,8 +139,8 @@ def _build_query_wf( if job_ids: query[f"fws.{FW_UUID_PATH}"] = {"$in": job_ids} - if flow_id: - query["metadata.flow_id"] = flow_id + if flow_ids: + query["metadata.flow_id"] = {"$in": flow_ids} if state: if state == FlowState.WAITING: @@ -389,7 +389,7 @@ def get_flows_info( self, job_ids: str | list[str] | None = None, db_ids: int | list[int] | None = None, - flow_id: str | None = None, + flow_ids: str | None = None, state: FlowState | None = None, start_date: datetime | None = None, end_date: datetime | None = None, @@ -399,7 +399,7 @@ def get_flows_info( query = self._build_query_wf( job_ids=job_ids, db_ids=db_ids, - flow_id=flow_id, + flow_ids=flow_ids, state=state, start_date=start_date, end_date=end_date,