diff --git a/daisy/ready_surface.py b/daisy/ready_surface.py index 6fe39289..ff949c3e 100644 --- a/daisy/ready_surface.py +++ b/daisy/ready_surface.py @@ -16,7 +16,7 @@ class ReadySurface: 2) it has downstream dependencies marked as OTHER A node is a BOUNDARY node iff 1) it has been marked as failed - 2) It has upstream dependencies marked as SURFACE + 2) it has upstream dependencies marked as SURFACE """ def __init__(self, get_downstream_nodes, get_upstream_nodes): @@ -103,7 +103,7 @@ def mark_failure(self, node, count_all_orphans=False): # recurse through downstream nodes, adding them to boundary if # necessary down_nodes = set(self.downstream(node)) - orphans = set(down_nodes) + orphans = set() while len(down_nodes) > 0: down_node = down_nodes.pop() if self.__add_to_boundary(down_node): @@ -111,11 +111,7 @@ def mark_failure(self, node, count_all_orphans=False): # nodes. new_nodes = set(self.downstream(down_node)) - orphans down_nodes = down_nodes.union(new_nodes) - orphans = orphans.union(new_nodes) - elif count_all_orphans: - new_nodes = set(self.downstream(down_node)) - orphans - down_nodes = down_nodes.union(new_nodes) - orphans - orphans.union(new_nodes) + orphans.add(down_node) # check if any of the upstream nodes can be removed from surface for up_node in up_nodes: diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c6b4f555..ddc67728 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -513,3 +513,31 @@ def test_zero_levels_failure(task_zero_levels): task_state.failed_count + task_state.orphaned_count == task_state.total_block_count ), task_state + +def test_orphan_double_counting(): + def process_block(block): + pass + + task = Task( + task_id="test_orphans", + total_roi=Roi((0, 0), (25, 25)), + read_roi=Roi((0, 0), (7, 7)), + write_roi=Roi((3, 3), (1, 1)), + process_function=process_block, + check_function=None, + read_write_conflict=True, + ) + scheduler = Scheduler([task]) + + while True: + block = scheduler.acquire_block(task.task_id) + if block is None: + break + block.status = BlockStatus.FAILED + scheduler.release_block(block) + + task_state = scheduler.task_states[task.task_id] + assert ( + task_state.failed_count + task_state.orphaned_count + == task_state.total_block_count + ), task_state