Skip to content

Commit

Permalink
more fix for id based searches
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetretto committed Aug 22, 2023
1 parent ae9d83e commit b29c55a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/jobflow_remote/cli/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def flows_list(
flows_info = jc.get_flows_info(
job_ids=job_id,
db_ids=db_id,
flow_id=flow_id,
flow_ids=flow_id,
state=state,
start_date=start_date,
end_date=end_date,
Expand Down Expand Up @@ -112,7 +112,7 @@ def delete(
flows_info = jc.get_flows_info(
job_ids=job_id,
db_ids=db_id,
flow_id=flow_id,
flow_ids=flow_id,
state=state,
start_date=start_date,
end_date=end_date,
Expand Down
14 changes: 14 additions & 0 deletions src/jobflow_remote/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import uuid
from contextlib import contextmanager
from enum import Enum

Expand Down Expand Up @@ -121,6 +122,7 @@ def get_job_db_ids(job_db_id: str, job_index: int | None):
except ValueError:
db_id = None
job_id = job_db_id
check_valid_uuid(job_id)

if job_index and db_id is not None:
out_console.print(
Expand All @@ -143,6 +145,7 @@ def get_job_ids_indexes(job_ids: list[str] | None) -> list[tuple[str, int]] | No
"(e.g. e1d66c4f-81db-4fff-bda2-2bf1d79d5961:2). "
f"Wrong format for {j}"
)
check_valid_uuid(split[0])
job_ids_indexes.append((split[0], int(split[1])))

return job_ids_indexes
Expand Down Expand Up @@ -170,3 +173,14 @@ def wrapper(*args, **kwargs):
)

return wrapper


def check_valid_uuid(uuid_str):
try:
uuid_obj = uuid.UUID(uuid_str)
if str(uuid_obj) == uuid_str:
return
except ValueError:
pass

raise typer.BadParameter(f"UUID {uuid_str} is in the wrong format.")
12 changes: 10 additions & 2 deletions src/jobflow_remote/fireworks/launchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ def generate_id_query(
) -> tuple[dict, list | None]:
query: dict = {}
sort: list | None = None

if (job_id is None) == (fw_id is None):
raise ValueError(
"One and only one among job_id and db_id should be defined"
)

if fw_id:
query["fw_id"] = fw_id
if job_id:
Expand All @@ -363,8 +369,10 @@ def _check_ids(
job_id: str | None = None,
job_index: int | None = None,
):
if job_id is None and fw_id is None:
raise ValueError("At least one among fw_id and job_id should be defined")
if (job_id is None) == (fw_id is None):
raise ValueError(
"One and only one among fw_id and job_id should be defined"
)
if job_id:
fw_id = self.get_fw_id_from_job_id(job_id, job_index)
return fw_id, job_id
Expand Down
10 changes: 5 additions & 5 deletions src/jobflow_remote/jobs/jobcontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _build_query_wf(
self,
job_ids: str | list[str] | None = None,
db_ids: int | list[int] | None = None,
flow_id: str | None = None,
flow_ids: str | None = None,
state: FlowState | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
Expand All @@ -139,8 +139,8 @@ def _build_query_wf(
if job_ids:
query[f"fws.{FW_UUID_PATH}"] = {"$in": job_ids}

if flow_id:
query["metadata.flow_id"] = flow_id
if flow_ids:
query["metadata.flow_id"] = {"$in": flow_ids}

if state:
if state == FlowState.WAITING:
Expand Down Expand Up @@ -389,7 +389,7 @@ def get_flows_info(
self,
job_ids: str | list[str] | None = None,
db_ids: int | list[int] | None = None,
flow_id: str | None = None,
flow_ids: str | None = None,
state: FlowState | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
Expand All @@ -399,7 +399,7 @@ def get_flows_info(
query = self._build_query_wf(
job_ids=job_ids,
db_ids=db_ids,
flow_id=flow_id,
flow_ids=flow_ids,
state=state,
start_date=start_date,
end_date=end_date,
Expand Down

0 comments on commit b29c55a

Please sign in to comment.