Skip to content

Commit

Permalink
Simplify and tidy.
Browse files Browse the repository at this point in the history
  • Loading branch information
hjoliver committed Dec 4, 2023
1 parent e462512 commit f928d6e
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 115 deletions.
2 changes: 2 additions & 0 deletions cylc/flow/scripts/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ async def run(

@cli_function(get_option_parser)
def main(parser: COP, options: 'Values', *ids) -> None:
if options.outputs and options.prerequisites:
raise InputError("Use --prerequisite or --output, not both.")
validate_flow_opts(options)
call_multi(
partial(run, options),
Expand Down
194 changes: 80 additions & 114 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@
Pool = Dict['PointBase', Dict[str, TaskProxy]]


def prereqs_str_to_tokens(prereqs: List[str]) -> List[Tuple[str, str, str]]:
    """Parse prerequisite ID strings into (cycle, task, selector) tuples.

    Each string is parsed as a relative ID with ``Tokens(p, relative=True)``:

        ["<cycle>/<task>:<sel>", ...] --> [(<cycle>, <task>, <sel>), ...]

    If the parsed selector is missing (falsy), TASK_OUTPUT_SUCCEEDED is
    used in its place.

    Examples:
        >>> prereqs_str_to_tokens(['1/b', '3/c:failed'])
        [('1', 'b', 'succeeded'), ('3', 'c', 'failed')]

    """
    # Parse every string first, then build the tuples.
    parsed = [Tokens(prereq, relative=True) for prereq in prereqs]
    result = []
    for tokens in parsed:
        selector = tokens['task_sel'] or TASK_OUTPUT_SUCCEEDED
        result.append((tokens['cycle'], tokens['task'], selector))
    return result


class TaskPool:
"""Task pool of a workflow."""

Expand Down Expand Up @@ -1565,11 +1587,40 @@ def spawn_task(
self.db_add_new_flow_rows(itask)
return itask

def _spawn_transient(
    self,
    point: 'PointBase',
    taskdef: 'TaskDef',
    flow_nums: 'FlowNums',
    flow_wait: bool
) -> Optional['TaskProxy']:
    """Spawn a transient task proxy and restore its completed outputs.

    The proxy is created via ``self.spawn_task`` with ``force=True`` and
    ``transient=True``. If that returns None, None is returned unchanged.

    Otherwise the output records returned by the primary DB DAO's
    ``select_task_outputs`` for this task name and point are scanned; the
    first record whose flow numbers intersect ``flow_nums`` has each of
    its (JSON-encoded) messages marked as completed on the proxy.

    Args:
        point: cycle point of the task to spawn.
        taskdef: definition of the task to spawn (its name is used).
        flow_nums: flow numbers for the spawned task, also used to pick
            the matching DB record.
        flow_wait: passed through to ``spawn_task``.

    Returns:
        The spawned task proxy, or None if nothing was spawned.

    """
    itask = self.spawn_task(
        taskdef.name,
        point,
        flow_nums,
        flow_wait=flow_wait,
        force=True,
        transient=True
    )
    if itask is None:
        return None

    # Update outputs that were already completed.
    db_outputs = self.workflow_db_mgr.pri_dao.select_task_outputs(
        itask.tdef.name, str(itask.point))
    for outputs_str, fnums in db_outputs.items():
        if not flow_nums.intersection(fnums):
            continue
        for msg in json.loads(outputs_str):
            itask.state.outputs.set_completed_by_msg(msg)
        # Only the first record with overlapping flows is applied.
        break
    return itask

def set( # noqa: A003
self,
items: Iterable[str],
outputs: Optional[List[str]],
prerequisites: Optional[List[str]],
prereqs: Optional[List[str]],
flow: List[str],
flow_wait: bool = False,
flow_descr: Optional[str] = None
Expand All @@ -1591,9 +1642,9 @@ def set( # noqa: A003
Args:
items: task ID match patterns
prerequisites: prerequisites to set
prereqs: prerequisites to set
outputs: outputs to set and spawn children of
flow: Flow numbers for spawned or merged tasks
flow: flow numbers for spawned or merged tasks
flow_wait: wait for flows to catch up before continuing
flow_descr: description of new flow
Expand All @@ -1603,77 +1654,40 @@ def set( # noqa: A003
# Illegal flow command opts
return

# Get matching pool tasks and future task definitions.
itasks, future_tasks, unmatched = self.filter_task_proxies(
items,
future=True,
warn=False,
)

# pool tasks
for itask in itasks:
self.merge_flows(itask, flow_nums)
if not outputs and not prerequisites:
# Default: set required outputs.
outputs = itask.tdef.get_required_outputs()
if outputs:
self._set_outputs_itask(itask, outputs)
if prerequisites:
if prereqs:
self._set_prereqs_itask(
itask, prerequisites, flow_nums, flow_wait)
itask, prereqs, flow_nums, flow_wait)
else:
self._set_outputs_itask(
itask, outputs or itask.tdef.get_required_outputs())

# future task definitions
for name, point in future_tasks:
taskdef = self.config.get_taskdef(name)
if not outputs and not prerequisites:
# Default: set required outputs.
outputs = taskdef.get_required_outputs()
if outputs:
trans = self._spawn_transient_task(
point, taskdef, outputs, flow_nums, flow_wait
)
if trans is not None:
self._set_outputs_itask(trans, outputs)
if prerequisites:
tdef = self.config.get_taskdef(name)
if prereqs:
self._set_prereqs_tdef(
point, taskdef, prerequisites, flow_nums, flow_wait)

def _spawn_transient_task(
    self,
    point: 'PointBase',
    taskdef: 'TaskDef',
    outputs: List[str],
    flow_nums: 'FlowNums',
    flow_wait: bool
) -> Optional['TaskProxy']:
    """Spawn a transient task proxy and update its outputs from the DB.

    Args:
        point: cycle point of the task to spawn.
        taskdef: definition of the task to spawn (its name is used).
        outputs: accepted for the caller's convenience; not read here.
        flow_nums: flow numbers for the spawned task; also used to select
            which stored output record applies.
        flow_wait: forwarded to ``spawn_task``.

    Returns:
        The spawned proxy, or None if ``spawn_task`` returned None.

    """
    itask = self.spawn_task(
        taskdef.name, point, flow_nums,
        flow_wait=flow_wait, force=True, transient=True
    )
    if itask is None:
        return itask

    # Records are keyed by a JSON string of output messages, with the
    # associated flow numbers as values.
    recorded = self.workflow_db_mgr.pri_dao.select_task_outputs(
        itask.tdef.name, str(itask.point)
    )
    for outputs_str, fnums in recorded.items():
        if flow_nums.intersection(fnums):
            # Mark outputs that were already completed in this flow.
            for message in json.loads(outputs_str):
                itask.state.outputs.set_completed_by_msg(message)
            break

    return itask
point, tdef, prereqs, flow_nums, flow_wait)
else:
# Spawn a transient proxy for the future task so its outputs can be set.
trans = self._spawn_transient(
    point, tdef, flow_nums, flow_wait)
if trans is not None:
    # Default to the required outputs of this future task's own
    # definition. "itask" must not be used here: it is the loop variable
    # of the earlier pool-task loop, so it is either unbound (no pool
    # tasks matched) or refers to an unrelated task.
    self._set_outputs_itask(
        trans, outputs or tdef.get_required_outputs())

def _set_outputs_itask(
self,
itask: 'TaskProxy',
req_outputs: List[str],
) -> None:
"""Set requested outputs on a task and spawn their children."""
"""Set requested outputs on a task and spawn its children."""

# TODO TIDIER "set:" LOG MESSAGES

Expand Down Expand Up @@ -1704,87 +1718,39 @@ def _set_outputs_itask(

if itask.transient:
# (note tasks states table gets updated from the task pool)
LOG.warning(f"TWAT {itask}")
self.workflow_db_mgr.put_update_task_state(itask)

def _get_valid_prereqs(self, prereqs, taskdef, point):
    """Return the requested prerequisites that the task actually has.

    A throwaway transient TaskProxy is constructed directly (rather than
    via ``spawn_task``) to discover the task's prerequisites; the keys of
    each prerequisite's ``satisfied`` mapping form the available set.

    Each requested string is parsed as a relative ID and reduced to a
    (cycle, task, selector) tuple, with a missing selector replaced by
    TASK_OUTPUT_SUCCEEDED. A warning is logged for every requested tuple
    that is not available.

    Args:
        prereqs: prerequisite strings, e.g. "<cycle>/<task>:<sel>".
        taskdef: definition of the task whose prerequisites are checked.
        point: cycle point of that task.

    Returns:
        Set of (cycle, task, selector) tuples that were both requested
        and available.

    """
    probe = TaskProxy(self.tokens, taskdef, point, transient=True)
    available = set()
    for prereq in probe.state.prerequisites:
        available.update(prereq.satisfied.keys())

    requested = set()
    for prereq_str in prereqs:
        tokens = Tokens(prereq_str, relative=True)
        # Default to :succeeded
        selector = tokens['task_sel'] or TASK_OUTPUT_SUCCEEDED
        requested.add((tokens['cycle'], tokens['task'], selector))

    good = available & requested
    bad = requested - available
    for cycle, task, selector in bad:
        LOG.warning(
            f"{point}/{taskdef.name} does not depend on"
            f" {cycle}/{task}:{selector}"
        )

    return good

def _set_prereqs_itask(self, itask, prereqs, flow_nums, flow_wait):
"""Set prerequisites of a task in the pool."""
"""Set prerequisites on a task in the pool.
"""
if prereqs == ["all"]:
itask.state.set_all_satisfied()
else:
itask.satisfy_me(
# TODO: IS THIS NEEDED? (JUST LOG BAD ONES FROM SATISFY_ME?)
self._get_valid_prereqs(prereqs, itask.tdef, itask.point)
)

itask.satisfy_me(prereqs_str_to_tokens(prereqs))
self.data_store_mgr.delta_task_prerequisite(itask)

# if (
# self.runahead_limit_point is not None
# and itask.point <= self.runahead_limit_point
# ):
# self.rh_release_and_queue(itask)

def _set_prereqs_tdef(self, point, taskdef, prereqs, flow_nums, flow_wait):
"""Set given prerequisites of a future task."""
"""Spawn a future task and set specified prerequisites on it.
"""
itask = self.spawn_task(taskdef.name, point, flow_nums, flow_wait)
if itask is None:
# E.g. already spawned in flow.
return
if prereqs == ["all"]:
itask.state.set_all_satisfied()
else:
itask.satisfy_me(
self._get_valid_prereqs(prereqs, taskdef, point)
)

self.data_store_mgr.delta_task_prerequisite(itask)
self.add_to_pool(itask)
if (
self.runahead_limit_point is not None
and itask.point <= self.runahead_limit_point
):
self.rh_release_and_queue(itask)
self._set_prereqs_itask(itask, prereqs, flow_nums, flow_wait)

def _get_active_flow_nums(self) -> Set[int]:
"""Return all active, or most recent previous, flow numbers.
"""Return active flow numbers.
If there are no active flows (e.g. on restarting a completed workflow)
return the most recent active flows.
If there are any active flows, return all active flow numbers.
Otherwise (e.g. on restarting a completed workflow) return
the flow numbers of the most recent previous active task.
"""
fnums = set()
for itask in self.get_tasks():
Expand Down
5 changes: 4 additions & 1 deletion cylc/flow/task_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,10 @@ def satisfy_me(self, prereqs) -> bool:
"""
bad = self.state.satisfy_me(prereqs)
for err in bad:
LOG.warning(f"{self.identity} has no prerequisites {err}")
LOG.warning(
f"{self.identity} does not depend on"
f" {err[0]}/{err[1]}:{err[2]}"
)
return len(bad) == 0

def clock_expire(self) -> bool:
Expand Down

0 comments on commit f928d6e

Please sign in to comment.