Skip to content

Commit

Permalink
Set prerequisites by output message or label.
Browse files Browse the repository at this point in the history
  • Loading branch information
hjoliver committed Feb 19, 2024
1 parent a5e196a commit 379bba0
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 73 deletions.
10 changes: 10 additions & 0 deletions cylc/flow/scripts/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
# set multiple prerequisites at once:
$ cylc set --pre=3/foo:x --pre=3/foo:y,3/foo:z my_workflow//3/bar
"""

from functools import partial
Expand Down Expand Up @@ -155,6 +156,9 @@ def validate_prereq(prereq: str) -> bool:
>>> validate_prereq('1/foo:succeeded')
True
>>> validate_prereq('1/foo') # succeeded
True
>>> validate_prereq('1/foo::succeeded')
False
Expand All @@ -164,6 +168,12 @@ def validate_prereq(prereq: str) -> bool:
>>> validate_prereq('fish')
False
>>> validate_prereq('file1 ready')
False
>>> validate_prereq('1/foo:file1 ready')
True
"""
try:
tokens = Tokens(prereq, relative=True)
Expand Down
12 changes: 0 additions & 12 deletions cylc/flow/task_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,6 @@ def get_all(self):
"""Return an iterator for all output messages."""
return sorted(self._by_message.values(), key=self.msg_sort_key)

def get_msg(self, out):
    """Resolve an output given as a message or trigger label to a message.

    Args:
        out: An output message or a trigger label.

    Returns:
        ``out`` itself if it is a known message; the message stored for
        ``out`` if it is a known trigger label; otherwise None.
    """
    if out in self._by_message:
        # Already a known message: nothing to translate.
        return out
    if out not in self._by_trigger:
        # Neither a known message nor a known trigger label.
        return None
    # Known trigger label: the message is the second item of its entry.
    trigger_entry = self._by_trigger[out]
    return trigger_entry[1]

def get_completed(self):
"""Return all completed output messages."""
ret = []
Expand Down
115 changes: 78 additions & 37 deletions cylc/flow/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
from cylc.flow.task_events_mgr import TaskEventsManager
from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager
from cylc.flow.flow_mgr import FlowMgr, FlowNums
from typing_extensions import Literal


Pool = Dict['PointBase', Dict[str, TaskProxy]]
Expand Down Expand Up @@ -1687,6 +1686,41 @@ def _get_task_proxy(

return itask

def _standardise_prereqs(self, prereqs: 'List[str]') -> 'List[Tokens]':
"""Return prerequisites as output messages (not trigger labels).
"""
_prereqs = []
for pre in [
Tokens(prereq, relative=True)
for prereq in (prereqs or [])
]:
# Convert trigger labels to output messages
msg = self.config.get_taskdef(
pre['task']
).get_output_msg(pre['task_sel'])
if msg is None:
LOG.warning(
f"output {pre.relative_id_with_selectors} not found")
continue
_prereqs.append(pre.duplicate(task_sel=msg))
return _prereqs

def _standardise_outputs(
self, point: 'PointBase', tdef: 'TaskDef', outputs: List[str]
) -> List[str]:
"""Return outputs as output messages (not trigger labels).
"""
_outputs = []
for output in outputs:
msg = tdef.get_output_msg(output)
if msg is None:
LOG.warning(f"output {point}/{tdef.name}:{output} not found")
continue
_outputs.append(msg)
return _outputs

def set_prereqs_and_outputs(
self,
items: Iterable[str],
Expand Down Expand Up @@ -1720,8 +1754,8 @@ def set_prereqs_and_outputs(
Args:
items: task ID match patterns
prereqs: prerequisites to set
outputs: outputs to set and spawn children of
prereqs: prerequisites (as output message or trigger label) to set
outputs: outputs (as output message or trigger label) to set
flow: flow numbers for spawned or merged tasks
flow_wait: wait for flows to catch up before continuing
flow_descr: description of new flow
Expand All @@ -1732,14 +1766,15 @@ def set_prereqs_and_outputs(
# Illegal flow command opts
return

_prereqs: 'Union[List[Tokens], Literal["all"]]'
if prereqs == ['all']:
_prereqs = 'all'
if prereqs == ["all"]:
prereqs_all = True
prereqs2 = []
elif prereqs is not None:
prereqs_all = False
prereqs2 = self._standardise_prereqs(prereqs)
else:
_prereqs = [
Tokens(prereq, relative=True)
for prereq in (prereqs or [])
]
prereqs_all = False
prereqs2 = []

# Get matching pool tasks and future task definitions.
itasks, future_tasks, unmatched = self.filter_task_proxies(
Expand All @@ -1749,52 +1784,56 @@ def set_prereqs_and_outputs(
)

for itask in itasks:
# Tasks already in the pool.
self.merge_flows(itask, flow_nums)
if _prereqs:
if prereqs2:
self._set_prereqs_itask(
itask, _prereqs, flow_nums, flow_wait)
itask, prereqs_all, prereqs2, flow_nums, flow_wait)
else:
self._set_outputs_itask(itask, outputs)
if not outputs:
outputs = itask.tdef.get_required_outputs()
else:
outputs = self._standardise_outputs(
itask.point, itask.tdef, outputs)
if outputs:
self._set_outputs_itask(itask, outputs)

for name, point in future_tasks:
# Future tasks.
tdef = self.config.get_taskdef(name)
if _prereqs:
if prereqs2:
self._set_prereqs_tdef(
point, tdef, _prereqs, flow_nums, flow_wait)
point, tdef, prereqs2, flow_nums, flow_wait)
else:
trans = self._get_task_proxy(
point, tdef, flow_nums, flow_wait, transient=True)
if trans is not None:
self._set_outputs_itask(trans, outputs)
if not outputs:
outputs = tdef.get_required_outputs()
else:
outputs = self._standardise_outputs(point, tdef, outputs)
if outputs:
trans = self._get_task_proxy(
point, tdef, flow_nums, flow_wait, transient=True)
if trans is not None:
self._set_outputs_itask(trans, outputs)

if self.compute_runahead():
self.release_runahead_tasks()

def _set_outputs_itask(
self,
itask: 'TaskProxy',
outputs: Optional[Iterable[str]],
outputs: Iterable[str],
) -> None:
"""Set requested outputs on a task proxy and spawn children."""

# Default to required outputs.
outputs = outputs or itask.tdef.get_required_outputs()

changed = False
for output in outputs:
# convert trigger label to output message
msg = itask.state.outputs.get_msg(output)
info = f'set: output {itask.identity}:{output}'
if msg is None:
LOG.warning(f"{info} not found")
continue
if itask.state.outputs.is_completed(msg):
LOG.info(f"{info} completed already")
if itask.state.outputs.is_completed(output):
LOG.info(f"output {itask.identity}:{output} completed already")
continue
changed = True
self.task_events_mgr.process_message(
itask, logging.INFO, msg, forced=True)
LOG.info(f"{info} completed")
itask, logging.INFO, output, forced=True)
changed = True
LOG.info(f"output {itask.identity}:{output} completed")

if changed and itask.transient:
self.workflow_db_mgr.put_update_task_state(itask)
Expand All @@ -1803,7 +1842,8 @@ def _set_outputs_itask(
def _set_prereqs_itask(
self,
itask: 'TaskProxy',
prereqs: 'Union[List[Tokens], Literal["all"]]',
prereqs_all: bool,
prereqs: 'List[Tokens]',
flow_nums: Set[int],
flow_wait: bool
) -> None:
Expand All @@ -1812,10 +1852,11 @@ def _set_prereqs_itask(
Prerequisite format: "cycle/task:message" or "all".
"""
if prereqs == "all":
if prereqs_all:
itask.state.set_all_satisfied()
else:
itask.satisfy_me(prereqs)

if (
self.runahead_limit_point is not None
and itask.point <= self.runahead_limit_point
Expand All @@ -1831,7 +1872,7 @@ def _set_prereqs_tdef(self, point, taskdef, prereqs, flow_nums, flow_wait):
if itask is None:
return
self.add_to_pool(itask)
self._set_prereqs_itask(itask, prereqs, flow_nums, flow_wait)
self._set_prereqs_itask(itask, False, prereqs, flow_nums, flow_wait)

def _get_active_flow_nums(self) -> Set[int]:
"""Return active flow numbers.
Expand Down
23 changes: 16 additions & 7 deletions cylc/flow/taskdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Task definition."""

from collections import deque
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import cylc.flow.flags
from cylc.flow.exceptions import TaskDefError
Expand Down Expand Up @@ -185,19 +185,28 @@ def _add_std_outputs(self):
for output in SORT_ORDERS:
self.outputs[output] = (output, None)

def get_output_msg(self, label: str) -> Optional[str]:
    """Return the output message for a trigger label or message.

    Args:
        label: A trigger label (a key of ``self.outputs``) or an output
            message (the first item of one of its values).

    Returns:
        The message for ``label`` if it is a trigger label; ``label``
        itself if it is already a message; otherwise None.
    """
    if label in self.outputs:
        # Trigger label: the message is the first item of its entry.
        return self.outputs[label][0]
    if label in (entry[0] for entry in self.outputs.values()):
        # Already an output message.
        return label
    return None

def set_required_output(self, output, required):
    """Mark an already-defined output as required or optional.

    Args:
        output: Key of an existing entry in ``self.outputs``.
        required: New value for the entry's required flag.
    """
    # The output and its message must already be registered; only the
    # flag in the (message, required) pair is replaced.
    current_message, _previous_flag = self.outputs[output]
    updated_entry = (current_message, required)
    self.outputs[output] = updated_entry

def get_required_outputs(self):
    """Return list of required outputs (messages, not trigger labels).

    Returns:
        The message of every entry in ``self.outputs`` whose required flag
        is truthy, in the dict's iteration order.
    """
    # Values are (message, required) pairs keyed by trigger label; return
    # the messages, since the labels are only the dict keys.
    return [msg for _lab, (msg, req) in self.outputs.items() if req]

def tweak_outputs(self):
"""Output consistency checking and tweaking."""
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/scripts/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ async def test_data_store(
]

# set the 1/a:succeeded prereq of 1/z
schd.pool.set_prereqs_and_outputs(['1/z'], None, ['1/a:succeeded'], ['1'])
schd.pool.set_prereqs_and_outputs(
['1/z'], None, ['1/a:succeeded'], ['1'])
task_z = data[TASK_PROXIES][
schd.pool.get_task(IntegerPoint('1'), 'z').tokens.id
]
Expand Down Expand Up @@ -159,5 +160,5 @@ async def test_pre_all(flow, scheduler, run):
schd = scheduler(id_, paused_start=False)
async with run(schd) as log:
schd.pool.set_prereqs_and_outputs(['1/z'], [], ['all'], ['all'])
warn_or_higher = [i for i in log.records if i.levelno > 20]
warn_or_higher = [i for i in log.records if i.levelno > 30]
assert warn_or_higher == []
46 changes: 31 additions & 15 deletions tests/integration/test_task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,13 +1431,13 @@ async def test_set_outputs_live(
assert log_filter(
log, contains="setting missed output: started")

# set foo (default: all required outputs): complete y.
# set foo (default: all required outputs) to complete y.
schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all'])
assert log_filter(
log, contains="output 1/foo:succeeded completed")
assert (
pool_get_task_ids(schd.pool) == ["1/bar", "1/baz"]
)
assert log_filter(
log, contains="[1/foo/00:succeeded] completed")


async def test_set_outputs_future(
Expand All @@ -1456,7 +1456,15 @@ async def test_set_outputs_future(
},
'scheduling': {
'graph': {
'R1': "a => b => c"
'R1': "a:x & a:y => b => c"
}
},
'runtime': {
'a': {
'outputs': {
'x': 'xylophone',
'y': 'yacht'
}
}
}
}
Expand All @@ -1475,9 +1483,15 @@ async def test_set_outputs_future(
pool_get_task_ids(schd.pool) == ["1/a", "1/c"]
)

# try to set an invalid output
schd.pool.set_prereqs_and_outputs(["1/b"], ["shrub"], None, ['all'])
assert log_filter(log, contains="output 1/b:shrub not found")
schd.pool.set_prereqs_and_outputs(
items=["1/a"],
outputs=["xylophone", "yacht", "cheese"],
prereqs=None,
flow=['all']
)
assert log_filter(log, contains="output 1/a:cheese not found")
assert log_filter(log, contains="output 1/a:xylophone completed")
assert log_filter(log, contains="output 1/a:yacht completed")


async def test_prereq_satisfaction(
Expand All @@ -1486,7 +1500,7 @@ async def test_prereq_satisfaction(
start,
log_filter,
):
"""Check manual setting of future task outputs.
"""Check manual setting of task prerequisites.
"""
id_ = flow(
Expand Down Expand Up @@ -1523,14 +1537,16 @@ async def test_prereq_satisfaction(

assert not b.is_waiting_prereqs_done()

# set valid and invalid prerequisites, check log.
b.satisfy_me([
Tokens(id_, relative=True)
for id_ in ["1/a:x", "1/a:y", "1/a:z", "1/a:w"]
])
assert log_filter(log, contains="1/b does not depend on 1/a:z")
assert log_filter(log, contains="1/b does not depend on 1/a:w")
# set valid and invalid prerequisites, by label and message.
schd.pool.set_prereqs_and_outputs(
prereqs=["1/a:xylophone", "1/a:y", "1/a:w", "1/a:z"],
items=["1/b"], outputs=None, flow=['all']
)
assert log_filter(log, contains="1/a:z not found")
assert log_filter(log, contains="1/a:w not found")
assert not log_filter(log, contains="1/b does not depend on 1/a:x")
assert not log_filter(
log, contains="1/b does not depend on 1/a:xylophone")
assert not log_filter(log, contains="1/b does not depend on 1/a:y")

assert b.is_waiting_prereqs_done()
Expand Down

0 comments on commit 379bba0

Please sign in to comment.