From ed4cb7f5a3160dfc901f58f6408881a0be8a9143 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 6 Dec 2023 12:16:27 -0500 Subject: [PATCH] [ENH] Consistent extension sampling of CPDAG, PDAG and DAG (#102) * Add sampling a consistent extension of a CPDAG/PDAG * Also enable conversion from DAG to CPDAG --------- Signed-off-by: Adam Li --- .github/workflows/main.yml | 1 + README.md | 2 +- doc/api.rst | 20 ++ doc/installation.md | 8 +- doc/references.bib | 10 + doc/whats_new/v0.2.rst | 1 + pyproject.toml | 6 +- pywhy_graphs/algorithms/__init__.py | 1 + pywhy_graphs/algorithms/cpdag.py | 281 ++++++++++++++++++ pywhy_graphs/algorithms/generic.py | 32 ++ pywhy_graphs/algorithms/tests/test_cpdag.py | 241 +++++++++++++++ pywhy_graphs/algorithms/tests/test_generic.py | 28 ++ pywhy_graphs/networkx/classes/mixededge.py | 6 + 13 files changed, 630 insertions(+), 7 deletions(-) create mode 100644 pywhy_graphs/algorithms/cpdag.py create mode 100644 pywhy_graphs/algorithms/tests/test_cpdag.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1b78278d1..0bc17675c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -158,6 +158,7 @@ jobs: if: "matrix.os == 'ubuntu'" shell: bash run: | + sudo apt-get update sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc - name: Run pytest # headless via Xvfb on linux diff --git a/README.md b/README.md index c2cdd7a3f..1cce990db 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ To install the package from github, clone the repository and then `cd` into the Pywhy-Graphs is always looking for new contributors to help make the package better, whether it is algorithms, documentation, examples of graph usage, and more! Contributing to Pywhy-Graphs will be rewarding because you will contribute to a much needed package for causal inference. -See our [contributing guide](https://github.com/py-why/pywhy-graphs/CONTRIBUTING.md) for more details. 
+See our [contributing guide](https://github.com/py-why/pywhy-graphs/blob/main/CONTRIBUTING.md) for more details. # Citing diff --git a/doc/api.rst b/doc/api.rst index f8a15d3f7..9297ebb33 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -64,6 +64,26 @@ causal graph operations. is_semi_directed_path all_semi_directed_paths +:mod:`pywhy_graphs.algorithms`: Algorithms for dealing with CPDAGs +================================================================== +With Markov equivalence classes of DAGs in a Markovian SCM setting, we obtain +a partially directed acyclic graph (PDAG), which may be completed (CPDAG). +If we want to generate a consistent DAG extension (i.e. a Markov equivalent DAG) of a CPDAG, +then we may use some of the algorithms described here. Or perhaps one may want to +convert a DAG to its corresponding CPDAG. + +.. currentmodule:: pywhy_graphs.algorithms + +.. autosummary:: + :toctree: generated/ + + pdag_to_dag + dag_to_cpdag + pdag_to_cpdag + order_edges + label_edges + + Conversions between other package's causal graphs ================================================= Other packages, such as `causal-learn `_, diff --git a/doc/installation.md b/doc/installation.md index 5e208c6ab..10b312bdb 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -1,9 +1,10 @@ Installation ============ -**pywhy-graphs** supports Python >= 3.8. +**pywhy-graphs** closely follows the NetworkX dependencies and thus supports Python >= 3.9. -## Installing with ``pip`` +Installing with ``pip`` +----------------------- **pywhy-graphs** is available [on PyPI](https://pypi.org/project/pywhy-graphs/). 
Just run @@ -12,7 +13,8 @@ Installation # or if you use poetry which is recommended poetry add pywhy-graphs -## Installing from source +Installing from source +---------------------- To install **pywhy-graphs** from source, first clone [the repository](https://github.com/py-why/pywhy-graphs): diff --git a/doc/references.bib b/doc/references.bib index 2ba181a6f..e9a8833c1 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -17,6 +17,16 @@ @article{bareinboim_causal_2016 pages = {7345--7352} } +@article{chickering2002learning, + title = {Learning equivalence classes of Bayesian-network structures}, + author = {Chickering, David Maxwell}, + journal = {The Journal of Machine Learning Research}, + volume = {2}, + pages = {445--498}, + year = {2002}, + publisher = {JMLR} +} + @article{Colombo2012, author = {Diego Colombo and Marloes H. Maathuis and Markus Kalisch and Thomas S. Richardson}, title = {{Learning high-dimensional directed acyclic graphs with latent and selection variables}}, diff --git a/doc/whats_new/v0.2.rst b/doc/whats_new/v0.2.rst index 0fb5875fa..5a746cbb8 100644 --- a/doc/whats_new/v0.2.rst +++ b/doc/whats_new/v0.2.rst @@ -30,6 +30,7 @@ Changelog - |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`) - |API| Remove support for Python 3.8 by `Adam Li`_ (:pr:`99`) - |Feature| Implement a suite of functions for finding and checking semi-directed paths on a mixed-edge graph, by `Adam Li`_ (:pr:`101`) +- |Feature| Implement functions for converting between DAGs, PDAGs and CPDAGs, for example to generate a consistent extension of a CPDAG. 
These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`) Code and Documentation Contributors ----------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 5acf30e6d..66524b946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ exclude_dirs = ["tests"] [tool.black] line-length = 100 -target-version = ['py38'] +target-version = ['py39'] include = '\.pyi?$' extend-exclude = ''' ( @@ -102,10 +102,10 @@ readme = "README.md" classifiers = [ 'Development Status :: 4 - Beta', 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11' + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12' ] keywords = ['causality', 'graphs', 'causal-inference', 'graphical-model'] diff --git a/pywhy_graphs/algorithms/__init__.py b/pywhy_graphs/algorithms/__init__.py index b8feba4d4..ddc2edd53 100644 --- a/pywhy_graphs/algorithms/__init__.py +++ b/pywhy_graphs/algorithms/__init__.py @@ -1,3 +1,4 @@ +from .cpdag import * # noqa: F403 from .cyclic import * # noqa: F403 from .generic import * # noqa: F403 from .multidomain import * # noqa: F403 diff --git a/pywhy_graphs/algorithms/cpdag.py b/pywhy_graphs/algorithms/cpdag.py new file mode 100644 index 000000000..558dd710a --- /dev/null +++ b/pywhy_graphs/algorithms/cpdag.py @@ -0,0 +1,281 @@ +from enum import Enum + +import networkx as nx + +import pywhy_graphs as pg + +__all__ = ["pdag_to_dag", "dag_to_cpdag", "pdag_to_cpdag", "order_edges", "label_edges"] + + +class EDGELABELS(Enum): + """Edge labels for a CPDAG.""" + + COMPELLED = "compelled" + REVERSIBLE = "reversible" + UNKNOWN = "unknown" + + +def is_clique(G, nodelist): + H = G.subgraph(nodelist) + n = len(nodelist) + return H.size() == n * (n 
- 1) / 2 + + +def order_edges(G: nx.DiGraph): + """Find total ordering of the edges of DAG G. + + A total ordering is a topological sorting of the nodes, and then + ordering all possible edges according to Algorithm 4 in + :footcite:`chickering2002learning`. The edges are sorted such that + the edges obey the topological sorting of the nodes, but are also + sorted such that the source node of the edge is ordered based + on the topological sort as well. + + Parameters + ---------- + G : DAG + A directed acyclic graph. + + Returns + ------- + DAG + The same graph ``G``, with an integer ``"order"`` attribute set on every edge. + + References + ---------- + .. footbibliography:: + """ + if not nx.is_directed_acyclic_graph(G): + raise ValueError("G must be a directed acyclic graph") + nx.set_edge_attributes(G, None, "order") + ordered_nodes = list(nx.topological_sort(G)) + + idx = 0 + + while any([G[u][v]["order"] is None for u, v in G.edges]): + # get all edges that are still not ordered + unordered_edges = [(u, v) for u, v in G.edges if G[u][v]["order"] is None] + + # get the lowest order unlabeled edge's destination node + y = sorted(unordered_edges, key=lambda x: ordered_nodes.index(x[1]))[-1][-1] + + # find the highest order node such that x -> y is not ordered + unlabeled_y_parent_edges = [u for u in G.predecessors(y) if G[u][y]["order"] is None] + x = sorted(unlabeled_y_parent_edges, key=lambda x: ordered_nodes.index(x))[0] + + # label the edge order + G[x][y]["order"] = idx + idx += 1 + + return G + + +def label_edges(G: nx.DiGraph): + """Label compelled and reversible edges of a DAG G. + + Label the edges of a DAG G as either compelled or reversible. Compelled + edges are edges that are compelled to be directed in a consistent + extension of G. Reversible edges are edges that are not required + to be directed in a consistent extension of G. For full details, + see Algorithm 5 in :footcite:`chickering2002learning`. + + Parameters + ---------- + G : DAG + The directed acyclic graph to label. 
+ + Returns + ------- + DAG + The labelled DAG with edge attribute ``"label"`` as either + ``"compelled"`` or ``"reversible"``. + + References + ---------- + .. footbibliography:: + """ + if not nx.is_directed_acyclic_graph(G): + raise ValueError("G must be a directed acyclic graph") + if not all([G[u][v].get("order") is not None for u, v in G.edges]): + raise ValueError("G must have all edges ordered via the `order` attribute") + + nx.set_edge_attributes(G, EDGELABELS.UNKNOWN, "label") + + while any([edge[-1] == EDGELABELS.UNKNOWN for edge in G.edges.data("label")]): + # find the lowest order edge with an unknown label + unknown_edges = [ + (src, target) + for src, target, label in G.edges.data("label") + if label == EDGELABELS.UNKNOWN + ] + unknown_edges.sort(key=lambda edge: G.edges[edge]["order"]) + x, y = unknown_edges[-1] + + # now find every edge w -> x that is labeled as compelled + w_nodes = [w for w in G.predecessors(x) if G[w][x]["label"] == EDGELABELS.COMPELLED] + continue_while_loop = False + for node in w_nodes: + # For all compelled edges w -> x, if there is no edge w -> y, + # we can label the edge x -> y as compelled + if not G.has_edge(node, y): + for src, target in G.in_edges(y): + G[src][target]["label"] = EDGELABELS.COMPELLED + + # now, we start over at the beginning of the while loop + continue_while_loop = True + break + else: + # w -> y is compelled, since there is an edge w -> x that is compelled + # so w is a confounder + G[node][y]["label"] = EDGELABELS.COMPELLED + + if continue_while_loop: + continue + + # now, we check if there an edge z -> y such that: + # 1. z != x + # 2. 
z is not a parent of x + # If so, then label all unknown edges into y (including x -> y) + # as compelled + # otherwise, label all unknown edges with reversible label + z_exists = len([z for z in G.predecessors(y) if z != x and not G.has_edge(z, x)]) + for src, target in G.in_edges(y): + if G[src][target]["label"] == EDGELABELS.UNKNOWN: + if z_exists: + G[src][target]["label"] = EDGELABELS.COMPELLED + else: + G[src][target]["label"] = EDGELABELS.REVERSIBLE + return G + + +def pdag_to_dag(G): + """Compute consistent extension of given PDAG resulting in a DAG. + + Implements the algorithm described in Figure 11 of :footcite:`chickering2002learning`. + + Parameters + ---------- + G : CPDAG + A partially directed acyclic graph. + + Returns + ------- + DAG + A directed acyclic graph. + + References + ---------- + .. footbibliography:: + """ + if set(["directed", "undirected"]) != set(G.edge_types): + raise ValueError("Only directed and undirected edges are allowed in a CPDAG") + + G = G.copy() + dir_G: nx.DiGraph = G.get_graphs(edge_type="directed") + undir_G: nx.Graph = G.get_graphs(edge_type="undirected") + full_undir_G: nx.Graph = G.to_undirected() + + # initialize a DAG for the consistent extension + dag = nx.DiGraph(dir_G) + + nodes_memo = {node: None for node in G.nodes} + found = False + + while len(nodes_memo) > 0: + found = False + idx = 0 + + nodes = list(nodes_memo.keys()) + + # select a node, x, which: + # 1. has no outgoing edges + # 2. 
all undirected neighbors are adjacent to all its adjacent nodes + while not found and idx < len(nodes): + # check that there are no outgoing edges for said node + node_is_sink = dir_G.out_degree(nodes[idx]) == 0 + + if not node_is_sink: + idx += 1 + continue + + # since there are no outgoing edges, all directed adjacencies are parent nodes + # now check that all undirected neighbors are adjacent to all its adjacent nodes + undir_nbrs = list(undir_G.neighbors(nodes[idx])) + nearby_is_clique = False + if len(undir_nbrs) != 0: + parents = dir_G.predecessors(nodes[idx]) + # adj = full_undir_G.neighbors(nodes[idx]) + undir_nbrs_and_parents = set(undir_nbrs).union(set(parents)) + nearby_is_clique = is_clique(full_undir_G, undir_nbrs_and_parents) + + if len(undir_nbrs) == 0 or nearby_is_clique: + found = True + + # now, we orient all undirected edges between x and its neighbors + # such that ``nbr -> x`` + for nbr in undir_nbrs: + dag.add_edge(nbr, nodes[idx], edge_type="directed") + + # remove x from the "graph" and memoization + del nodes_memo[nodes[idx]] + dir_G.remove_node(nodes[idx]) + undir_G.remove_node(nodes[idx]) + full_undir_G.remove_node(nodes[idx]) + else: + idx += 1 + + # if no node satisfies condition 1 and 2, then the PDAG does not + # admit a consistent extension + if not found: + print(nodes_memo) + raise ValueError(f"No consistent extension found for PDAG: {G}, {G.edges()}") + return dag + + +def dag_to_cpdag(G): + """Convert a DAG to a CPDAG. + + Creates a CPDAG from a DAG. + + Parameters + ---------- + G : DAG + Directed acyclic graph. 
+ """ + G = order_edges(G) + G = label_edges(G) + + # now construct CPDAG + cpdag = pg.CPDAG() + + # for all compelled edges, add a directed edge + compelled_edges = [ + (u, v) for u, v, label in G.edges.data("label") if label == EDGELABELS.COMPELLED + ] + cpdag.add_edges_from(compelled_edges, edge_type="directed") + + # for all reversible edges, add an undirected edge + reversible_edges = [ + (u, v) for u, v, label in G.edges.data("label") if label == EDGELABELS.REVERSIBLE + ] + cpdag.add_edges_from(reversible_edges, edge_type="undirected") + + return cpdag + + +def pdag_to_cpdag(G): + """Convert a PDAG to a CPDAG. + + Parameters + ---------- + G : PDAG + A partially directed acyclic graph that is not completed. + + Returns + ------- + CPDAG + A completed partially directed acyclic graph. + """ + dag = pdag_to_dag(G) + + return dag_to_cpdag(dag) diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 2c336df06..f0011c655 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -1,3 +1,4 @@ +from itertools import combinations from typing import List, Optional, Set, Union import networkx as nx @@ -17,6 +18,7 @@ "valid_mag", "dag_to_mag", "is_maximal", + "all_vstructures", ] @@ -823,3 +825,33 @@ def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None): else: continue return True + + +def all_vstructures(G: nx.DiGraph, as_edges: bool = False): + """Generate all v-structures in the graph. + + Parameters + ---------- + G : DiGraph + A directed graph. + as_edges : bool + Whether to return the v-structures as edges or as a set of tuples. + + Returns + ------- + vstructs : set + If ``as_edges`` is True, a set of tuples of the form (parent, child), i.e. the + edges that are part of v-structures in the graph. Otherwise, a set of v-structures + encoded as (parent_1, child, parent_2) tuples, with child being an unshielded + collider. 
+ """ + vstructs = set() + for node in G.nodes: + for p1, p2 in combinations(G.predecessors(node), 2): + if p1 not in G.predecessors(p2) and p2 not in G.predecessors(p1): + if as_edges: + vstructs.add((p1, node)) + vstructs.add((p2, node)) + else: + vstructs.add((p1, node, p2)) # type: ignore + return vstructs diff --git a/pywhy_graphs/algorithms/tests/test_cpdag.py b/pywhy_graphs/algorithms/tests/test_cpdag.py new file mode 100644 index 000000000..472628cab --- /dev/null +++ b/pywhy_graphs/algorithms/tests/test_cpdag.py @@ -0,0 +1,241 @@ +import networkx as nx +import numpy as np +import pytest + +import pywhy_graphs.networkx as pywhy_nx +from pywhy_graphs.algorithms import all_vstructures +from pywhy_graphs.algorithms.cpdag import ( + EDGELABELS, + dag_to_cpdag, + label_edges, + order_edges, + pdag_to_cpdag, + pdag_to_dag, +) +from pywhy_graphs.testing import assert_mixed_edge_graphs_isomorphic + +seed = 12345 +rng = np.random.default_rng(seed) + + +class TestOrderEdges: + def test_order_edges_errors(self): + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + # so topological sort is: (1, 2, 3, 4, 5) + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + # now test when there is a cycle + G.add_edge(5, 1) + with pytest.raises(ValueError, match="G must be a directed acyclic graph"): + order_edges(G) + + def test_order_edges(self): + # Example usage: + G = nx.DiGraph() + + # 1 -> 2 -> 4 -> 5 + # 1 -> 3 -> 4 + G.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]) + G = order_edges(G) + + expected_order = [ + (1, 2, {"order": 4}), + (1, 3, {"order": 3}), + (2, 4, {"order": 1}), + (3, 4, {"order": 2}), + (4, 5, {"order": 0}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + # Add a string as a node + # 5 -> 3 -> 1 -> 2 -> 'a'; 1 -> 'b' + G = nx.DiGraph() + G.add_edges_from([(5, 3), (3, 1), (1, 2), (2, "a"), (1, "b")]) + G = order_edges(G) + + expected_order = [ + 
(5, 3, {"order": 4}), + (3, 1, {"order": 3}), + (1, 2, {"order": 2}), + (1, "b", {"order": 1}), + (2, "a", {"order": 0}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + def test_order_edges_ex1(self): + G = nx.DiGraph() + + # 1 -> 3; 1 -> 4; 1 -> 5; + # 2 -> 3; 2 -> 4; 2 -> 5; + # 3 -> 4; 3 -> 5; + # 4 -> 5; + G.add_edges_from([(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)]) + G = order_edges(G) + expected_order = [ + (1, 3, {"order": 7}), + (1, 4, {"order": 4}), + (1, 5, {"order": 0}), + (2, 3, {"order": 8}), + (2, 4, {"order": 5}), + (2, 5, {"order": 1}), + (3, 4, {"order": 6}), + (3, 5, {"order": 2}), + (4, 5, {"order": 3}), + ] + assert set(G.edges.data(data="order")) == set( + [(src, target, order["order"]) for src, target, order in expected_order] + ) + + +class TestLabelEdges: + def test_label_edges_raises_error_for_non_dag(self): + # Test that label_edges raises a ValueError for a non-DAG + G = nx.DiGraph([(1, 2), (2, 3), (3, 1)]) # A cyclic graph + with pytest.raises(ValueError, match="G must be a directed acyclic graph"): + label_edges(G) + + def test_label_edges_raises_error_for_unordered_edges(self): + # Test that label_edges raises a ValueError for unordered edges + G = nx.DiGraph([(1, 2), (2, 3)]) + with pytest.raises( + ValueError, match="G must have all edges ordered via the `order` attribute" + ): + label_edges(G) + + def test_label_edges_all_compelled(self): + # Create an example DAG for testing + G = nx.DiGraph() + + # 1 -> 3; 3 -> 4; 3 -> 5 + # 2 -> 3; + # 4 -> 5 + G.add_edges_from([(1, 3), (2, 3), (3, 4), (3, 5), (4, 5)]) + nx.set_edge_attributes(G, None, "order") + G = order_edges(G) + labeled_graph = label_edges(G) + + expected_labels = { + (1, 3): EDGELABELS.COMPELLED, + (2, 3): EDGELABELS.COMPELLED, + (3, 4): EDGELABELS.COMPELLED, + (3, 5): EDGELABELS.COMPELLED, + (4, 5): EDGELABELS.REVERSIBLE, + } + for edge, expected_label in 
expected_labels.items(): + assert labeled_graph[edge[0]][edge[1]]["label"] == expected_label, ( + f"Edge {edge} has label {labeled_graph[edge[0]][edge[1]]['label']}, " + f"but expected {expected_label}" + ) + + +class TestPDAGtoDAG: + def test_pdag_to_dag_errors(self): + G = nx.DiGraph() + G.add_edge("A", "Z") + G.add_edges_from([("A", "B"), ("B", "A"), ("B", "Z"), ("X", "Y"), ("Z", "X")]) + + # add non-CPDAG supported edges + G = pywhy_nx.MixedEdgeGraph(graphs=[G], edge_types=["directed"]) + G.add_edge_type(nx.DiGraph(), "circle") + G.add_edge("Z", "A", edge_type="circle") + G.add_edge("A", "B", edge_type="circle") + G.add_edge("B", "A", edge_type="circle") + G.add_edge("B", "Z", edge_type="circle") + with pytest.raises( + ValueError, match="Only directed and undirected edges are allowed in a CPDAG" + ): + pdag_to_dag(G) + + def test_pdag_to_dag_inconsistent(self): + # 1 -- 3; 1 -> 4; + # 2 -> 3; + # 4 -> 3 + # Note: this PDAG is inconsistent because it would create a v-structure, or a cycle + # by orienting the undirected edge 1 -- 3 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(1, 4), (2, 3), (4, 3)], edge_type="directed") + with pytest.raises(ValueError, match="No consistent extension found"): + pdag_to_dag(pdag) + + def test_pdag_to_dag_already_dag(self): + # 1 -> 2; 1 -> 3 + # 2 -> 3 + # 4 -> 3 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + pdag.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 3)], edge_type="directed") + G = pdag_to_dag(pdag) + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_dag_0(self): + # 1 -- 3; + # 2 -> 3; 2 -> 4 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(2, 3), (2, 4)], 
edge_type="directed") + + G = pdag_to_dag(pdag) + + # add a directed edge from 3 to 1 + pdag.remove_edge(1, 3, edge_type="undirected") + pdag.add_edge(3, 1, edge_type="directed") + + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_dag_1(self): + # 1 -- 3; + # 2 -> 1; 2 -> 4 + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[nx.DiGraph(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + pdag.add_edge(1, 3, edge_type="undirected") + pdag.add_edges_from([(2, 1), (2, 4)], edge_type="directed") + + G = pdag_to_dag(pdag) + pdag.remove_edge(1, 3, edge_type="undirected") + pdag.add_edge(1, 3, edge_type="directed") + + assert nx.is_isomorphic(G, pdag.get_graphs("directed")) + + def test_pdag_to_cpdag(self): + # construct a random DAG + n = 10 + p = 0.4 + random_graph = nx.fast_gnp_random_graph(n, p, directed=True, seed=seed) + dag = nx.DiGraph([(u, v) for (u, v) in random_graph.edges() if u < v]) + + pdag = pywhy_nx.MixedEdgeGraph( + graphs=[dag.copy(), nx.Graph()], edge_types=["directed", "undirected"] + ) + + # now we construct the set of undirected edges that to not belong + # to any unshielded collider (i.e. 
v-structure) + vstructs = all_vstructures(dag, as_edges=True) + + # we apply a random orientation for a subset of the undirected edges + for edge in dag.edges: + if edge not in vstructs: + if rng.binomial(1, 0.3): + pdag.remove_edge(*edge) + pdag.add_edge(*edge, edge_type="undirected") + + # now, we can convert the DAG to CPDAG and also convert the PDAG to a CPDAG + # they should be equivalent + cpdag = dag_to_cpdag(dag) + cpdag_from_pdag = pdag_to_cpdag(pdag) + + assert_mixed_edge_graphs_isomorphic(cpdag, cpdag_from_pdag) diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 8b7e6c375..09218a334 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -3,6 +3,7 @@ import pywhy_graphs from pywhy_graphs import ADMG +from pywhy_graphs.algorithms import all_vstructures def test_convert_to_latent_confounder_errors(): @@ -468,3 +469,30 @@ def test_is_maximal(): S = {} L = {"Y"} assert not pywhy_graphs.is_maximal(admg, L, S) + + +def test_all_vstructures(): + # Create a directed graph + G = nx.DiGraph() + G.add_edges_from([(1, 2), (3, 2), (4, 2)]) + + # Generate the v-structures + v_structs_edges = all_vstructures(G, as_edges=True) + v_structs_tuples = all_vstructures(G, as_edges=False) + + # Assert that the returned values are as expected + assert len(v_structs_edges) == 3 + assert len(v_structs_tuples) == 3 + assert (1, 2) in v_structs_edges or (2, 1) in v_structs_edges + assert (3, 2) in v_structs_edges or (2, 3) in v_structs_edges + assert (1, 2, 3) in v_structs_tuples or (3, 2, 1) in v_structs_tuples + assert (4, 2, 3) in v_structs_tuples or (3, 2, 4) in v_structs_tuples + + G.remove_node(2) + # Generate the v-structures + v_structs_edges = all_vstructures(G, as_edges=True) + v_structs_tuples = all_vstructures(G, as_edges=False) + + # Assert that the returned values are as expected + assert len(v_structs_edges) == 0 + assert len(v_structs_tuples) == 0 
diff --git a/pywhy_graphs/networkx/classes/mixededge.py b/pywhy_graphs/networkx/classes/mixededge.py index f90723973..a9fc41cc4 100644 --- a/pywhy_graphs/networkx/classes/mixededge.py +++ b/pywhy_graphs/networkx/classes/mixededge.py @@ -112,6 +112,12 @@ def __init__(self, graphs=None, edge_types=None, **attr): # load graph attributes (must be after convert) self.graph.update(attr) + # XXX: experimental. Fix this in doc string once finalized. + # make dynamic property names for the edges, (i.e. circle_edges, + # directed_edges, undirected_edges) + for edge_type_name in self.edge_types: + setattr(self, f"{edge_type_name}_edges", self.get_graphs(edge_type_name).edges) + def __str__(self): """Returns a short summary of the graph.