Skip to content

Commit

Permalink
Adding semi directed path test
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Nov 21, 2023
1 parent 1afacc0 commit 3aba875
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 28 deletions.
12 changes: 12 additions & 0 deletions doc/reference/algorithms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,15 @@ Algorithms for handling acyclicity
:toctree: ../../generated/

acyclification


***************************************
Semi-directed (possibly-directed) Paths
***************************************

.. automodule:: pywhy_graphs.algorithms.semi_directed_paths
.. autosummary::
:toctree: ../../generated/

all_semi_directed_paths
is_semi_directed_path
42 changes: 25 additions & 17 deletions pywhy_graphs/algorithms/semi_directed_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,39 +148,47 @@ def _all_semi_directed_paths_graph(
# iterate over neighbors of source
stack = [iter(G.neighbors(source))]

# XXX: figure out how to update prev_node for efficient DFS
prev_node = source
prev_nodes = [source]
# if source has no neighbors, then prev_nodes should be None
if not prev_nodes:
prev_nodes = [None]

while stack:
# get the iterator through children for the current node
children = stack[-1]
child = next(children, None)
# get the iterator through nbrs for the current node
nbrs = stack[-1]
prev_node = prev_nodes[-1]
nbr = next(nbrs, None)

# The first condition guarantees that there is not a directed endpoint
# along the path from source to target that points towards source.
if child is None or (
G.has_edge(child, prev_node, directed_edge_name)
or G.has_edge(child, prev_node, bidirected_edge_name)
):
# If we've found a directed edge from child to prev_node
if (
G.has_edge(nbr, prev_node, directed_edge_name)
or G.has_edge(nbr, prev_node, bidirected_edge_name)
) and nbr not in visited:
# If we've found a directed edge from child to prev_node,
# that we haven't visited, then we don't need to continue down this path
continue
elif nbr is None:
# once all children are visited, pop the stack
# and remove the child from the visited set
stack.pop()
visited.popitem()
prev_nodes.pop()
elif len(visited) < cutoff:
if child in visited:
if nbr in visited:
continue
if child in targets:
if nbr in targets:
# we've found a path to a target
yield list(visited) + [child]
visited[child] = True
yield list(visited) + [nbr]
visited[nbr] = True
if targets - set(visited.keys()): # expand stack until find all targets

stack.append(iter(G.neighbors(child)))
stack.append(iter(G.neighbors(nbr)))
prev_nodes.append(nbr)
else:
visited.popitem() # maybe other ways to child
else: # len(visited) == cutoff:
for target in (targets & (set(children) | {child})) - set(visited.keys()):
for target in (targets & (set(nbrs) | {nbr})) - set(visited.keys()):
yield list(visited) + [target]
stack.pop()
visited.popitem()
prev_nodes.pop()
37 changes: 26 additions & 11 deletions pywhy_graphs/algorithms/tests/test_semi_directed_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ def sample_mixed_edge_graph():
G = pywhy_nx.MixedEdgeGraph(
graphs=[directed_G, bidirected_G], edge_types=["directed", "bidirected"], name="IV Graph"
)

G.add_edge_type(nx.DiGraph(), "circle")
G.add_edge("A", "Z", edge_type="directed")
G.add_edge("Z", "A", edge_type="circle")
G.add_edge("A", "B", edge_type="circle")
G.add_edge("B", "A", edge_type="circle")
G.add_edge("B", "Z", edge_type="circle")
return G


Expand All @@ -30,7 +37,7 @@ def test_single_node_path(sample_mixed_edge_graph):

def test_nonexistent_node_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
assert not is_semi_directed_path(G, ["A", "B"])
assert not is_semi_directed_path(G, ["1", "2"])


def test_repeated_nodes_path(sample_mixed_edge_graph):
Expand All @@ -40,9 +47,6 @@ def test_repeated_nodes_path(sample_mixed_edge_graph):

def test_valid_semi_directed_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
G.add_edge("A", "Z", edge_type="directed")
G.add_edge_type(nx.DiGraph(), "circle")
G.add_edge("Z", "A", edge_type="circle")
assert is_semi_directed_path(G, ["Z", "X"])
assert is_semi_directed_path(G, ["A", "Z", "X"])

Expand All @@ -58,9 +62,9 @@ def test_invalid_semi_directed_path(sample_mixed_edge_graph):

def test_empty_paths(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
source = "A"
source = "1"
target = "B"
with pytest.raises(nx.NodeNotFound, match="source node A not in graph"):
with pytest.raises(nx.NodeNotFound, match=f"source node {source} not in graph"):
all_semi_directed_paths(G, source, target)

G.add_node(source)
Expand All @@ -80,25 +84,36 @@ def test_no_paths(sample_mixed_edge_graph):

def test_multiple_paths(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
G.add_edge_type(nx.DiGraph(), "circle")
G.add_edge("A", "Z", edge_type="directed")
G.add_edge("A", "B", edge_type="circle")
G.add_edge("B", "A", edge_type="circle")
G.add_edge("B", "Z", edge_type="circle")

source = "A"
target = "X"
cutoff = 3
paths = all_semi_directed_paths(G, source, target, cutoff)
paths = list(paths)

dig = nx.path_graph(5, create_using=nx.DiGraph())
G.add_edges_from(dig.edges(), edge_type="directed")
G.add_edge("A", 0, edge_type="circle")

assert len(paths) == 2
assert all(path in paths for path in [["A", "Z", "X"], ["A", "B", "Z", "X"]])

# for a short cutoff, there is only one path
cutoff = 2
paths = all_semi_directed_paths(G, source, target, cutoff)
assert all(path in paths for path in [["A", "Z", "X"]])

# for an even shorter cutoff, there are no paths now
cutoff = 1
paths = all_semi_directed_paths(G, source, target, cutoff)
assert list(paths) == []


def test_long_cutoff(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
source = "Z"
target = "X"
cutoff = 10 # Cutoff longer than the actual path length
print(G.edges())
paths = all_semi_directed_paths(G, source, target, cutoff)
assert list(paths) == [[source, target]]

0 comments on commit 3aba875

Please sign in to comment.