diff --git a/doc/reference/algorithms/index.rst b/doc/reference/algorithms/index.rst index 270638c59..5c65472a6 100644 --- a/doc/reference/algorithms/index.rst +++ b/doc/reference/algorithms/index.rst @@ -63,3 +63,15 @@ Algorithms for handling acyclicity :toctree: ../../generated/ acyclification + + +*************************************** +Semi-directed (possibly-directed) Paths +*************************************** + +.. automodule:: pywhy_graphs.algorithms.semi_directed_paths +.. autosummary:: + :toctree: ../../generated/ + + all_semi_directed_paths + is_semi_directed_path diff --git a/pywhy_graphs/algorithms/semi_directed_paths.py b/pywhy_graphs/algorithms/semi_directed_paths.py index 7e17743ee..327506683 100644 --- a/pywhy_graphs/algorithms/semi_directed_paths.py +++ b/pywhy_graphs/algorithms/semi_directed_paths.py @@ -148,39 +148,47 @@ def _all_semi_directed_paths_graph( # iterate over neighbors of source stack = [iter(G.neighbors(source))] - # XXX: figure out how to update prev_node for efficient DFS - prev_node = source + prev_nodes = [source] + # if source has no neighbors, then prev_nodes should be None + if not prev_nodes: + prev_nodes = [None] while stack: - # get the iterator through children for the current node - children = stack[-1] - child = next(children, None) + # get the iterator through nbrs for the current node + nbrs = stack[-1] + prev_node = prev_nodes[-1] + nbr = next(nbrs, None) # The first condition guarantees that there is not a directed endpoint # along the path from source to target that points towards source. - if child is None or ( - G.has_edge(child, prev_node, directed_edge_name) - or G.has_edge(child, prev_node, bidirected_edge_name) - ): - # If we've found a directed edge from child to prev_node + if ( + G.has_edge(nbr, prev_node, directed_edge_name) + or G.has_edge(nbr, prev_node, bidirected_edge_name) + ) and nbr not in visited: + # If we've found a directed edge from child to prev_node, + # that we haven't visited, then we don't need to continue down this path + continue + elif nbr is None: # once all children are visited, pop the stack # and remove the child from the visited set stack.pop() visited.popitem() + prev_nodes.pop() elif len(visited) < cutoff: - if child in visited: + if nbr in visited: continue - if child in targets: + if nbr in targets: # we've found a path to a target - yield list(visited) + [child] - visited[child] = True + yield list(visited) + [nbr] + visited[nbr] = True if targets - set(visited.keys()): # expand stack until find all targets - - stack.append(iter(G.neighbors(child))) + stack.append(iter(G.neighbors(nbr))) + prev_nodes.append(nbr) else: visited.popitem() # maybe other ways to child else: # len(visited) == cutoff: - for target in (targets & (set(children) | {child})) - set(visited.keys()): + for target in (targets & (set(nbrs) | {nbr})) - set(visited.keys()): yield list(visited) + [target] stack.pop() visited.popitem() + prev_nodes.pop() diff --git a/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py b/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py index 2006082a0..c84eadf6a 100644 --- a/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py +++ b/pywhy_graphs/algorithms/tests/test_semi_directed_paths.py @@ -15,6 +15,13 @@ def sample_mixed_edge_graph(): G = pywhy_nx.MixedEdgeGraph( graphs=[directed_G, bidirected_G], edge_types=["directed", "bidirected"], name="IV Graph" ) + + G.add_edge_type(nx.DiGraph(), "circle") + G.add_edge("A", "Z", edge_type="directed") + G.add_edge("Z", "A", edge_type="circle") + G.add_edge("A", "B", edge_type="circle") + G.add_edge("B", "A", edge_type="circle") + G.add_edge("B", "Z", edge_type="circle") return G @@ -30,7 +37,7 @@ def test_single_node_path(sample_mixed_edge_graph): def test_nonexistent_node_path(sample_mixed_edge_graph): G = sample_mixed_edge_graph - assert not is_semi_directed_path(G, ["A", "B"]) + assert not is_semi_directed_path(G, ["1", "2"]) def test_repeated_nodes_path(sample_mixed_edge_graph): @@ -40,9 +47,6 @@ def test_repeated_nodes_path(sample_mixed_edge_graph): def test_valid_semi_directed_path(sample_mixed_edge_graph): G = sample_mixed_edge_graph - G.add_edge("A", "Z", edge_type="directed") - G.add_edge_type(nx.DiGraph(), "circle") - G.add_edge("Z", "A", edge_type="circle") assert is_semi_directed_path(G, ["Z", "X"]) assert is_semi_directed_path(G, ["A", "Z", "X"]) @@ -58,9 +62,9 @@ def test_invalid_semi_directed_path(sample_mixed_edge_graph): def test_empty_paths(sample_mixed_edge_graph): G = sample_mixed_edge_graph - source = "A" + source = "1" target = "B" - with pytest.raises(nx.NodeNotFound, match="source node A not in graph"): + with pytest.raises(nx.NodeNotFound, match=f"source node {source} not in graph"): all_semi_directed_paths(G, source, target) G.add_node(source) @@ -80,25 +84,36 @@ def test_no_paths(sample_mixed_edge_graph): def test_multiple_paths(sample_mixed_edge_graph): G = sample_mixed_edge_graph - G.add_edge_type(nx.DiGraph(), "circle") - G.add_edge("A", "Z", edge_type="directed") - G.add_edge("A", "B", edge_type="circle") - G.add_edge("B", "A", edge_type="circle") - G.add_edge("B", "Z", edge_type="circle") source = "A" target = "X" cutoff = 3 paths = all_semi_directed_paths(G, source, target, cutoff) paths = list(paths) + + dig = nx.path_graph(5, create_using=nx.DiGraph()) + G.add_edges_from(dig.edges(), edge_type="directed") + G.add_edge("A", 0, edge_type="circle") + assert len(paths) == 2 assert all(path in paths for path in [["A", "Z", "X"], ["A", "B", "Z", "X"]]) + # for a short cutoff, there is only one path + cutoff = 2 + paths = all_semi_directed_paths(G, source, target, cutoff) + assert all(path in paths for path in [["A", "Z", "X"]]) + + # for an even shorter cutoff, there are no paths now + cutoff = 1 + paths = all_semi_directed_paths(G, source, target, cutoff) + assert list(paths) == [] + def test_long_cutoff(sample_mixed_edge_graph): G = sample_mixed_edge_graph source = "Z" target = "X" cutoff = 10 # Cutoff longer than the actual path length + print(G.edges()) paths = all_semi_directed_paths(G, source, target, cutoff) assert list(paths) == [[source, target]]