Skip to content

Commit

Permalink
Update wip for semi directed paths
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Nov 21, 2023
1 parent 7151bed commit 1afacc0
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 11 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ causal graph operations.
find_connected_pairs
add_all_snode_combinations
compute_invariant_domains_per_node
is_semi_directed_path
all_semi_directed_paths

Conversions between other package's causal graphs
=================================================
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v0.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Changelog
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)
- |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`)
- |API| Remove support for Python 3.8 by `Adam Li`_ (:pr:`99`)
- |Feature| Implement a suite of functions for finding and checking semi-directed paths on a mixed-edge graph, by `Adam Li`_ (:pr:``)
Code and Documentation Contributors
-----------------------------------
Expand Down
9 changes: 2 additions & 7 deletions examples/mixededge/plot_mixed_edge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@
# %%
# Construct a MixedEdgeGraph
# --------------------------
# Using the ``MixedEdgeGraph``, we can represent a causal graph
# with two different kinds of edges. To create the graph, we
# use networkx ``nx.DiGraph`` class to represent directed edges,
# and ``nx.Graph`` class to represent edges without directions (i.e.
# bidirected edges). The edge types are then specified, so the mixed edge
# graph object knows which graphs are associated with which types of edges.
# Here we demonstrate how to construct a mixed edge graph
# by composing networkx graphs.

directed_G = nx.DiGraph(
[
Expand All @@ -60,7 +56,6 @@
name="IV Graph",
)

# Compute the multipartite_layout using the "layer" node attribute
pos = nx.spring_layout(G)

# we can then visualize the mixed-edge graph
Expand Down
1 change: 1 addition & 0 deletions pywhy_graphs/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .generic import * # noqa: F403
from .multidomain import * # noqa: F403
from .pag import * # noqa: F403
from .semi_directed_paths import * # noqa: F403
6 changes: 3 additions & 3 deletions pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@


def _possibly_directed(G: PAG, i: Node, j: Node, reverse: bool = False):
"""Check that path is possibly directed.
"""Check that edge is possibly directed.
A possibly directed path is one of the form:
A possibly directed edge is one of the form:
- ``i -> j``
- ``i o-> j``
- ``i o-o j``
Expand Down Expand Up @@ -64,7 +64,7 @@ def _possibly_directed(G: PAG, i: Node, j: Node, reverse: bool = False):

# the direct check checks for i *-> j or i <-* j
# i <-> j is also checked
# everything else is valid
# everything else is valid; i.e. i -- j, or i o-o j
if direct_check or G.has_edge(i, j, G.bidirected_edge_name):
return False
return True
Expand Down
186 changes: 186 additions & 0 deletions pywhy_graphs/algorithms/semi_directed_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import networkx as nx

from ..config import EdgeType
from ..typing import Node

__all__ = [
"is_semi_directed_path",
"all_semi_directed_paths",
]


def _empty_generator():
yield from ()


def is_semi_directed_path(G, nodes):
"""Returns True if and only if `nodes` form a semi-directed path in `G`.
A *semi-directed path* in a graph is a nonempty sequence of nodes in which
no node appears more than once in the sequence, each adjacent
pair of nodes in the sequence is adjacent in the graph and where each
pair of adjacent nodes does not contain a directed endpoint in the direction
towards the start of the sequence.
That is ``(a -> b o-> c <-> d -> e)`` is not a semi-directed path from ``a`` to ``e``
because ``d *-> c`` is a directed endpoint in the direction towards ``a``.
Parameters
----------
G : graph
A mixed-edge graph.
nodes : list
A list of one or more nodes in the graph `G`.
Returns
-------
bool
Whether the given list of nodes represents a simple path in `G`.
Notes
-----
This function is very similar to networkx's
:func:`networkx.algorithms.is_simple_path` function.
"""
# The empty list is not a valid path. Could also return
# NetworkXPointlessConcept here.
if len(nodes) == 0:
return False

# If the list is a single node, just check that the node is actually
# in the graph.
if len(nodes) == 1:
return nodes[0] in G

# check that all nodes in the list are in the graph, if at least one
# is not in the graph, then this is not a simple path
if not all(n in G for n in nodes):
return False

# If the list contains repeated nodes, then it's not a simple path
if len(set(nodes)) != len(nodes):
return False

# Test that each adjacent pair of nodes is adjacent and that there
# is no directed endpoint towards the beginning of the sequence.
for idx in range(len(nodes) - 1):
u, v = nodes[idx], nodes[idx + 1]
if G.has_edge(v, u, EdgeType.DIRECTED.value) or G.has_edge(v, u, EdgeType.BIDIRECTED.value):
return False
elif not G.has_edge(u, v):
return False
return True


def all_semi_directed_paths(G, source: Node, target: Node, cutoff: int = None):
"""Generate all semi-directed paths from source to target in G.
A semi-directed path is a path from ``source`` to ``target`` in that
no end-point is directed from ``target`` to ``source``. I.e.
``target *-> source`` does not exist.
Parameters
----------
G : Graph
The graph.
source : Node
The source node.
target : Node
The target node.
cutoff : integer, optional
Depth to stop the search. Only paths of length <= cutoff are returned.
Notes
-----
This algorithm is very similar to networkx's
:func:`networkx.algorithms.all_simple_paths` function.
This algorithm uses a modified depth-first search to generate the
paths [1]_. A single path can be found in $O(V+E)$ time but the
number of simple paths in a graph can be very large, e.g. $O(n!)$ in
the complete graph of order $n$.
This function does not check that a path exists between `source` and
`target`. For large graphs, this may result in very long runtimes.
Consider using `has_path` to check that a path exists between `source` and
`target` before calling this function on large graphs.
References
----------
.. [1] R. Sedgewick, "Algorithms in C, Part 5: Graph Algorithms",
Addison Wesley Professional, 3rd ed., 2001.
"""
if source not in G:
raise nx.NodeNotFound("source node %s not in graph" % source)
if target in G:
targets = {target}
else:
try:
targets = set(target) # type: ignore
except TypeError:
raise nx.NodeNotFound("target node %s not in graph" % target)
if source in targets:
return []
if cutoff is None:
cutoff = len(G) - 1
if cutoff < 1:
return []
if source in targets:
return _empty_generator()
if cutoff is None:
cutoff = len(G) - 1
if cutoff < 1:
return _empty_generator()

return _all_semi_directed_paths_graph(G, source, targets, cutoff)


def _all_semi_directed_paths_graph(
G, source, targets, cutoff, directed_edge_name="directed", bidirected_edge_name="bidirected"
):
"""See networkx's all_simple_paths function.
This performs a depth-first search for all semi-directed paths from source to target.
"""
# memoize each node that was already visited
visited = {source: True}

# iterate over neighbors of source
stack = [iter(G.neighbors(source))]

# XXX: figure out how to update prev_node for efficient DFS
prev_node = source

while stack:
# get the iterator through children for the current node
children = stack[-1]
child = next(children, None)

# The first condition guarantees that there is not a directed endpoint
# along the path from source to target that points towards source.
if child is None or (
G.has_edge(child, prev_node, directed_edge_name)
or G.has_edge(child, prev_node, bidirected_edge_name)
):
# If we've found a directed edge from child to prev_node
# once all children are visited, pop the stack
# and remove the child from the visited set
stack.pop()
visited.popitem()
elif len(visited) < cutoff:
if child in visited:
continue
if child in targets:
# we've found a path to a target
yield list(visited) + [child]
visited[child] = True
if targets - set(visited.keys()): # expand stack until find all targets

stack.append(iter(G.neighbors(child)))
else:
visited.popitem() # maybe other ways to child
else: # len(visited) == cutoff:
for target in (targets & (set(children) | {child})) - set(visited.keys()):
yield list(visited) + [target]
stack.pop()
visited.popitem()
104 changes: 104 additions & 0 deletions pywhy_graphs/algorithms/tests/test_semi_directed_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import networkx as nx
import pytest

import pywhy_graphs.networkx as pywhy_nx
from pywhy_graphs.algorithms import all_semi_directed_paths, is_semi_directed_path


# Fixture to create a sample mixed-edge graph for testing
@pytest.fixture
def sample_mixed_edge_graph():
directed_G = nx.DiGraph([("X", "Y"), ("Z", "X")])
bidirected_G = nx.Graph([("X", "Y")])
directed_G.add_nodes_from(bidirected_G.nodes)
bidirected_G.add_nodes_from(directed_G.nodes)
G = pywhy_nx.MixedEdgeGraph(
graphs=[directed_G, bidirected_G], edge_types=["directed", "bidirected"], name="IV Graph"
)
return G


def test_empty_path_not_semi_directed(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
assert not is_semi_directed_path(G, [])


def test_single_node_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
assert is_semi_directed_path(G, ["X"])


def test_nonexistent_node_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
assert not is_semi_directed_path(G, ["A", "B"])


def test_repeated_nodes_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
assert not is_semi_directed_path(G, ["X", "Y", "X"])


def test_valid_semi_directed_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
G.add_edge("A", "Z", edge_type="directed")
G.add_edge_type(nx.DiGraph(), "circle")
G.add_edge("Z", "A", edge_type="circle")
assert is_semi_directed_path(G, ["Z", "X"])
assert is_semi_directed_path(G, ["A", "Z", "X"])


def test_invalid_semi_directed_path(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
assert not is_semi_directed_path(G, ["Y", "X"])

# there is a bidirected edge between X and Y
assert not is_semi_directed_path(G, ["X", "Y"])
assert not is_semi_directed_path(G, ["Z", "X", "Y"])


def test_empty_paths(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
source = "A"
target = "B"
with pytest.raises(nx.NodeNotFound, match="source node A not in graph"):
all_semi_directed_paths(G, source, target)

G.add_node(source)
G.add_node(target)
paths = all_semi_directed_paths(G, source, target)
assert list(paths) == []


def test_no_paths(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
source = "Y"
target = "X"
cutoff = 3
paths = all_semi_directed_paths(G, source, target, cutoff)
assert list(paths) == []


def test_multiple_paths(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
G.add_edge_type(nx.DiGraph(), "circle")
G.add_edge("A", "Z", edge_type="directed")
G.add_edge("A", "B", edge_type="circle")
G.add_edge("B", "A", edge_type="circle")
G.add_edge("B", "Z", edge_type="circle")

source = "A"
target = "X"
cutoff = 3
paths = all_semi_directed_paths(G, source, target, cutoff)
paths = list(paths)
assert len(paths) == 2
assert all(path in paths for path in [["A", "Z", "X"], ["A", "B", "Z", "X"]])


def test_long_cutoff(sample_mixed_edge_graph):
G = sample_mixed_edge_graph
source = "Z"
target = "X"
cutoff = 10 # Cutoff longer than the actual path length
paths = all_semi_directed_paths(G, source, target, cutoff)
assert list(paths) == [[source, target]]
10 changes: 9 additions & 1 deletion pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ def mixed_edge_moral_graph(
undirected_edge_name="undirected",
bidirected_edge_name="bidirected",
):
"""Return the moral graph from an ancestral graph in :math:`O(|V|^2)`.
"""Return the moral graph from an ancestral graph.
A moral graph is a graph where all edges are undirected and an edge
between two nodes, ``u`` and ``v``, exists if there is a v-structure
``u -> w <- v``, where ``u`` and ``v`` are not adjacent. An ancestral
graph is a mixed edge graph with directed, bidirected, and undirected
edges.
The algorithm runs in :math:`O(|V|^2)`.
Parameters
----------
Expand Down

0 comments on commit 1afacc0

Please sign in to comment.