diff --git a/doc/api.rst b/doc/api.rst index f8a15d3f7..e8c7efc16 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -64,6 +64,26 @@ causal graph operations. is_semi_directed_path all_semi_directed_paths +:mod:`pywhy_graphs.algorithms.cpdag`: Algorithms for dealing with CPDAGs +======================================================================== +With Markov equivalence classes of DAGs in a Markovian SCM setting, we obtain +a partially directed acyclic graph (PDAG), which may be completed (CPDAG). +If we want to generate a consistent DAG extension (i.e. Markov equivalent) of a CPDAG, +then we may use some of the algorithms described here. Or perhaps one may want to +convert a DAG to its corresponding CPDAG. + +.. currentmodule:: pywhy_graphs.algorithms.cpdag + +.. autosummary:: + :toctree: generated/ + + pdag_to_dag + dag_to_cpdag + pdag_to_cpdag + order_edges + label_edges + + Conversions between other package's causal graphs ================================================= Other packages, such as `causal-learn `_, diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index ab7101ec6..5d912ddaa 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -30,7 +30,7 @@ Changelog - |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`) - |API| Remove support for Python 3.8 by `Adam Li`_ (:pr:`99`) - |Feature| Implement a suite of functions for finding and checking semi-directed paths on a mixed-edge graph, by `Adam Li`_ (:pr:`101`) -- |Feature| Implement :func:`pywhy_graphs.algorithms.cpdag_to_dag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`) +- |Feature| Implement functions for converting between a DAG, a PDAG and a CPDAG, for example to generate consistent extensions of a CPDAG. 
These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`) Code and Documentation Contributors ----------------------------------- diff --git a/pywhy_graphs/algorithms/cpdag.py b/pywhy_graphs/algorithms/cpdag.py index 921e6a381..d33a0a2bd 100644 --- a/pywhy_graphs/algorithms/cpdag.py +++ b/pywhy_graphs/algorithms/cpdag.py @@ -1,5 +1,19 @@ +from enum import Enum + import networkx as nx +import pywhy_graphs as pg + +__all__ = ["pdag_to_dag", "dag_to_cpdag", "pdag_to_cpdag", "order_edges", "label_edges"] + + +class EDGELABELS(Enum): + """Edge labels for a CPDAG.""" + + COMPELLED = "compelled" + REVERSIBLE = "reversible" + UNKNOWN = "unknown" + def is_clique(G, nodelist): H = G.subgraph(nodelist) @@ -7,23 +21,128 @@ def is_clique(G, nodelist): return H.size() == n * (n - 1) / 2 -def order_edges(G): - pass +def order_edges(G: nx.DiGraph): + """Find total ordering of the edges of DAG G. + A total ordering is a topological sorting of the nodes, and then + ordering all possible edges according to Algorithm 4 in + :footcite:`chickering2002learning`. -def label_edges(G): - pass + Parameters + ---------- + G : DAG + A directed acyclic graph. + Returns + ------- + G : DAG + The input graph, with the edge attribute ``"order"`` set on every edge. -def cpdag_to_pdag(G): - """Convert a CPDAG to + References + ---------- + .. 
footbibliography:: + """ + if not nx.is_directed_acyclic_graph(G): + raise ValueError("G must be a directed acyclic graph") + nx.set_edge_attributes(G, None, "order") + ordered_nodes = list(nx.topological_sort(G)) + + idx = 0 + + while any([G[u][v]["order"] is None for u, v in G.edges]): + # get all edges that are still not ordered + unordered_edges = [(u, v) for u, v in G.edges if G[u][v]["order"] is None] + + # get the lowest order unlabeled edge's destination node + y = sorted(unordered_edges, key=lambda x: ordered_nodes.index(x[1]))[-1][-1] + + # find the highest order node such that x -> y is not ordered + unlabeled_y_parent_edges = [u for u in G.predecessors(y) if G[u][y]["order"] is None] + x = sorted(unlabeled_y_parent_edges, key=lambda x: ordered_nodes.index(x))[0] + + # label the edge order + G[x][y]["order"] = idx + idx += 1 + + return G + + +def label_edges(G: nx.DiGraph): + """Label compelled and reversible edges of a DAG G. + + Label the edges of a DAG G as either compelled or reversible. Compelled + edges are edges that are compelled to be directed in a consistent + extension of G. Reversible edges are edges that are not required + to be directed in a consistent extension of G. For full details, + see Algorithm 5 in :footcite:`chickering2002learning`. Parameters ---------- - G : _type_ - _description_ + G : DAG + The directed acyclic graph to label. + + Returns + ------- + DAG + The labelled DAG with edge attribute ``"label"`` as either + ``"compelled"`` or ``"reversible"``. + + References + ---------- + .. 
footbibliography:: """ - pass + if not nx.is_directed_acyclic_graph(G): + raise ValueError("G must be a directed acyclic graph") + if not all([G[u][v].get("order") is not None for u, v in G.edges]): + raise ValueError("G must have all edges ordered via the `order` attribute") + + nx.set_edge_attributes(G, EDGELABELS.UNKNOWN, "label") + + while any([edge[-1] == EDGELABELS.UNKNOWN for edge in G.edges.data("label")]): + # find the lowest order edge with an unknown label + unknown_edges = [ + (src, target) + for src, target, label in G.edges.data("label") + if label == EDGELABELS.UNKNOWN + ] + unknown_edges.sort(key=lambda edge: G.edges[edge]["order"]) + x, y = unknown_edges[-1] + + # now find every edge w -> x that is labeled as compelled + w_nodes = [w for w in G.predecessors(x) if G[w][x]["label"] == EDGELABELS.COMPELLED] + continue_while_loop = False + for node in w_nodes: + # For all compelled edges w -> x, if there is no edge w -> y, + # we can label the edge x -> y as compelled + if not G.has_edge(node, y): + for src, target in G.in_edges(y): + G[src][target]["label"] = EDGELABELS.COMPELLED + + # now, we start over at the beginning of the while loop + continue_while_loop = True + break + else: + # w -> y is compelled, since there is an edge w -> x that is compelled + # so w is a confounder + G[node][y]["label"] = EDGELABELS.COMPELLED + + if continue_while_loop: + continue + + # now, we check if there an edge z -> y such that: + # 1. z != x + # 2. 
z is not a parent of x + # If so, then label all unknown edges into y (including x -> y) + # as compelled + # otherwise, label all unknown edges with reversible label + z_exists = len([z for z in G.predecessors(y) if z != x and not G.has_edge(z, x)]) + for src, target in G.in_edges(y): + if G[src][target]["label"] == EDGELABELS.UNKNOWN: + if z_exists: + G[src][target]["label"] = EDGELABELS.COMPELLED + else: + G[src][target]["label"] = EDGELABELS.REVERSIBLE + return G def pdag_to_dag(G): @@ -48,16 +167,23 @@ def pdag_to_dag(G): if set(["directed", "undirected"]) != set(G.edge_types): raise ValueError("Only directed and undirected edges are allowed in a CPDAG") + G = G.copy() dir_G: nx.DiGraph = G.get_graphs(edge_type="directed") undir_G: nx.Graph = G.get_graphs(edge_type="undirected") full_undir_G: nx.Graph = G.to_undirected() - nodes = set(dir_G.nodes) + + # initialize a DAG for the consistent extension + dag = nx.DiGraph(dir_G) + + nodes_memo = {node: None for node in G.nodes} found = False - while nodes: + while len(nodes_memo) > 0: found = False idx = 0 + nodes = list(nodes_memo.keys()) + # select a node, x, which: # 1. has no outgoing edges # 2. 
all undirected neighbors are adjacent to all its adjacent nodes @@ -71,33 +197,81 @@ def pdag_to_dag(G): # since there are no outgoing edges, all directed adjacencies are parent nodes # now check that all undirected neighbors are adjacent to all its adjacent nodes - undir_nbrs = undir_G.neighbors(nodes[idx]) - parents = dir_G.predecessors(nodes[idx]) - undir_nbrs_and_parents = set(undir_nbrs).union(set(parents)) - nearby_is_clique = is_clique(full_undir_G, undir_nbrs_and_parents) - idx += 1 + undir_nbrs = list(undir_G.neighbors(nodes[idx])) + nearby_is_clique = False + if len(undir_nbrs) != 0: + parents = dir_G.predecessors(nodes[idx]) + # adj = full_undir_G.neighbors(nodes[idx]) + undir_nbrs_and_parents = set(undir_nbrs).union(set(parents)) + nearby_is_clique = is_clique(full_undir_G, undir_nbrs_and_parents) - if nearby_is_clique: + if len(undir_nbrs) == 0 or nearby_is_clique: found = True # now, we orient all undirected edges between x and its neighbors # such that ``nbr -> x`` for nbr in undir_nbrs: - dir_G.add_edge(nbr, nodes[idx], edge_type="directed") + dag.add_edge(nbr, nodes[idx], edge_type="directed") - # remove x from the "graph" - nodes.remove(nodes[idx]) - if not found: - raise ValueError("No consistent extension found") - return dir_G + # remove x from the "graph" and memoization + del nodes_memo[nodes[idx]] + dir_G.remove_node(nodes[idx]) + undir_G.remove_node(nodes[idx]) + full_undir_G.remove_node(nodes[idx]) + else: + idx += 1 + + # if no node satisfies condition 1 and 2, then the PDAG does not + # admit a consistent extension + if not found: + raise ValueError("No consistent extension found") + return dag def dag_to_cpdag(G): """Convert a DAG to a CPDAG. + Creates a CPDAG from a DAG. + Parameters ---------- - G : _type_ - _description_ + G : DAG + Directed acyclic graph. 
""" - pass + G = order_edges(G) + G = label_edges(G) + + # now construct CPDAG + cpdag = pg.CPDAG() + + # for all compelled edges, add a directed edge + compelled_edges = [ + (u, v) for u, v, label in G.edges.data("label") if label == EDGELABELS.COMPELLED + ] + cpdag.add_edges_from(compelled_edges, edge_type="directed") + + # for all reversible edges, add an undirected edge + reversible_edges = [ + (u, v) for u, v, label in G.edges.data("label") if label == EDGELABELS.REVERSIBLE + ] + cpdag.add_edges_from(reversible_edges, edge_type="undirected") + + return cpdag + + +def pdag_to_cpdag(G): + """Convert a PDAG to a CPDAG. + + Parameters + ---------- + G : PDAG + A partially directed acyclic graph that is not completed. + + Returns + ------- + CPDAG + A completed partially directed acyclic graph. + """ + dag = pdag_to_dag(G) + + return dag_to_cpdag(dag) diff --git a/pywhy_graphs/algorithms/tests/test_cpdag.py b/pywhy_graphs/algorithms/tests/test_cpdag.py index 48cbf0413..006ab9e2e 100644 --- a/pywhy_graphs/algorithms/tests/test_cpdag.py +++ b/pywhy_graphs/algorithms/tests/test_cpdag.py @@ -2,7 +2,151 @@ import pytest import pywhy_graphs.networkx as pywhy_nx -from pywhy_graphs.algorithms.cpdag import cpdag_to_pdag, pdag_to_dag +from pywhy_graphs.algorithms.cpdag import EDGELABELS, label_edges, order_edges, pdag_to_dag + + +class TestOrderEdges: + def test_order_edges_errors(self): + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + # so topological sort is: (1, 2, 3, 4, 5) + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + # now test when there is a cycle + G.add_edge(5, 1) + with pytest.raises(ValueError, match="G must be a directed acyclic graph"): + order_edges(G) + + def test_order_edges(self): + # Example usage: + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + G = order_edges(G) + + print("Edge labels:", G.edges.data()) + expected_order = [ + (1, 2, {"order": 4}), + (1, 3, 
{"order": 3}), + (2, 4, {"order": 1}), + (3, 4, {"order": 2}), + (4, 5, {"order": 0}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + # Add a string as a node + # 5 -> 3 -> 1 -> 2 -> 'a'; 1 -> 'b' + G = nx.DiGraph() + G.add_edges_from([(5, 3), (3, 1), (1, 2), (2, "a"), (1, "b")]) + G = order_edges(G) + print("Edge labels:", G.edges.data()) + expected_order = [ + (5, 3, {"order": 4}), + (3, 1, {"order": 3}), + (1, 2, {"order": 2}), + (1, "b", {"order": 1}), + (2, "a", {"order": 0}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + def test_order_edges_ex1(self): + G = nx.DiGraph() + + # 1 -> 3; 1 -> 4; 1 -> 5; + # 2 -> 3; 2 -> 4; 2 -> 5; + # 3 -> 4; 3 -> 5; + # 4 -> 5; + G.add_edges_from([(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]) + G = order_edges(G) + expected_order = [ + (1, 3, {"order": 7}), + (1, 4, {"order": 4}), + (1, 5, {"order": 0}), + (2, 3, {"order": 8}), + (2, 4, {"order": 5}), + (2, 5, {"order": 1}), + (3, 4, {"order": 6}), + (3, 5, {"order": 2}), + (4, 5, {"order": 3}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + +class TestLabelEdges: + def test_label_edges_raises_error_for_non_dag(self): + # Test that label_edges raises a ValueError for a non-DAG + G = nx.DiGraph([(1, 2), (2, 3), (3, 1)]) # A cyclic graph + with pytest.raises(ValueError, match="G must be a directed acyclic graph"): + label_edges(G) + + def test_label_edges_raises_error_for_unordered_edges(self): + # Test that label_edges raises a ValueError for unordered edges + G = nx.DiGraph([(1, 2), (2, 3)]) + with pytest.raises( + ValueError, match="G must have all edges ordered via the `order` attribute" + ): + label_edges(G) + + @pytest.mark.skip() + def test_label_edges_output(self): + # Create an 
example DAG for testing + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + nx.set_edge_attributes(G, None, "order") + G[1][2]["order"] = 0 + G[1][3]["order"] = 1 + G[2][4]["order"] = 2 + G[3][4]["order"] = 3 + G[4][5]["order"] = 4 + + # Test the output of label_edges function for a specific example DAG + labeled_graph = label_edges(G) + expected_labels = { + (1, 2): EDGELABELS.REVERSIBLE, + (1, 3): EDGELABELS.REVERSIBLE, + (2, 4): EDGELABELS.REVERSIBLE, + (3, 4): EDGELABELS.REVERSIBLE, + (4, 5): EDGELABELS.REVERSIBLE, + } + for edge, expected_label in expected_labels.items(): + assert labeled_graph.edges[edge]["label"] == expected_label + + def test_label_edges_all_compelled(self): + # Create an example DAG for testing + G = nx.DiGraph() + + # 1 -> 3; 3 -> 4; 3 -> 5 + # 2 -> 3; + # 4 -> 5 + G.add_edges_from([(1, 3), (2, 3), (3, 4), (3, 5), (4, 5)]) + nx.set_edge_attributes(G, None, "order") + G = order_edges(G) + labeled_graph = label_edges(G) + + expected_labels = { + (1, 3): EDGELABELS.COMPELLED, + (2, 3): EDGELABELS.COMPELLED, + (3, 4): EDGELABELS.COMPELLED, + (3, 5): EDGELABELS.COMPELLED, + (4, 5): EDGELABELS.REVERSIBLE, + } + for edge, expected_label in expected_labels.items(): + assert labeled_graph[edge[0]][edge[1]]["label"] == expected_label, ( + f"Edge {edge} has label {labeled_graph[edge[0]][edge[1]]['label']}, " + f"but expected {expected_label}" + ) class TestPDAGtoDAG: @@ -11,29 +155,74 @@ def test_pdag_to_dag_errors(self): G.add_edges_from([("X", "Y"), ("Z", "X")]) G.add_edge("A", "Z") G.add_edges_from([("A", "B"), ("B", "A"), ("B", "Z")]) - G = pywhy_nx.MixedEdgeGraph(graphs=[G], edge_types=["directed"], name="IV Graph") + + # add non-CPDAG supported edges + G = pywhy_nx.MixedEdgeGraph(graphs=[G], edge_types=["directed"]) G.add_edge_type(nx.DiGraph(), "circle") G.add_edge("Z", "A", edge_type="circle") G.add_edge("A", "B", edge_type="circle") G.add_edge("B", "A", edge_type="circle") 
G.add_edge("B", "Z", edge_type="circle") - G = cpdag_to_pdag(G) with pytest.raises( ValueError, match="Only directed and undirected edges are allowed in a CPDAG" ): pdag_to_dag(G) + def test_pdag_to_dag_inconsistent(self): + # 1 -- 3; 1 -> 4; + # 2 -> 3; + # 4 -> 3 + # Note: this PDAG is inconsistent because it would create a v-structure, or a cycle + # by orienting the undirected edge 1 -- 3 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(1, 4), (2, 3), (4, 3)], edge_type="directed") + with pytest.raises(ValueError, match="No consistent extension found"): + pdag_to_dag(pdag) + + def test_pdag_to_dag_already_dag(self): + # 1 -> 2; 1 -> 3 + # 2 -> 3 + # 4 -> 3 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + pdag.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 3)], edge_type="directed") + G = pdag_to_dag(pdag) + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_dag_0(self): + # 1 -- 3; + # 2 -> 3; 2 -> 4 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(2, 3), (2, 4)], edge_type="directed") + + G = pdag_to_dag(pdag) + + # add a directed edge from 3 to 1 + pdag.remove_edge(1, 3, edge_type="undirected") + pdag.add_edge(3, 1, edge_type="directed") + + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + def test_pdag_to_dag_1(self): - G = nx.DiGraph() - G.add_edges_from([("X", "Y"), ("Z", "X")]) - G.add_edge("A", "Z") - G.add_edges_from([("A", "B"), ("B", "A"), ("B", "Z")]) - G = pywhy_nx.MixedEdgeGraph(graphs=[G], edge_types=["directed"], name="IV Graph") - G.add_edge_type(nx.DiGraph(), "circle") - G.add_edge("Z", "A", edge_type="circle") - G.add_edge("A", "B", edge_type="circle") - G.add_edge("B", "A", 
edge_type="circle") - G.add_edge("B", "Z", edge_type="circle") - G = cpdag_to_pdag(G) - G = pdag_to_dag(G) - assert G.edges == {("X", "Y"), ("Z", "X"), ("A", "Z"), ("A", "B"), ("B", "Z")} + # 1 -- 3; + # 2 -> 1; 2 -> 4 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(2, 1), (2, 4)], edge_type="directed") + + G = pdag_to_dag(pdag) + pdag.remove_edge(1, 3, edge_type="undirected") + pdag.add_edge(1, 3, edge_type="directed") + + assert nx.is_isomorphic(G, pdag.get_graphs("directed"))