diff --git a/pywhy_graphs/algorithms/cpdag.py b/pywhy_graphs/algorithms/cpdag.py index d33a0a2bd..d3bd96090 100644 --- a/pywhy_graphs/algorithms/cpdag.py +++ b/pywhy_graphs/algorithms/cpdag.py @@ -224,7 +224,8 @@ def pdag_to_dag(G): # if no node satisfies condition 1 and 2, then the PDAG does not # admit a consistent extension if not found: - raise ValueError("No consistent extension found") + print(nodes_memo) + raise ValueError(f"No consistent extension found for PDAG: {G}, {G.edges()}") return dag diff --git a/pywhy_graphs/algorithms/tests/test_cpdag.py b/pywhy_graphs/algorithms/tests/test_cpdag.py index d200b11ca..1251507ac 100644 --- a/pywhy_graphs/algorithms/tests/test_cpdag.py +++ b/pywhy_graphs/algorithms/tests/test_cpdag.py @@ -14,7 +14,7 @@ ) from pywhy_graphs.testing import assert_mixed_edge_graphs_isomorphic -seed = 1234 +seed = 12345 rng = np.random.default_rng(seed) @@ -216,8 +216,8 @@ def test_pdag_to_dag_1(self): def test_pdag_to_cpdag(self): # construct a random DAG n = 10 - p = 0.5 - random_graph = nx.fast_gnp_random_graph(n, p, directed=True) + p = 0.4 + random_graph = nx.fast_gnp_random_graph(n, p, directed=True, seed=seed) dag = nx.DiGraph([(u, v) for (u, v) in random_graph.edges() if u < v]) pdag = pywhy_nx.MixedEdgeGraph( @@ -231,7 +231,7 @@ def test_pdag_to_cpdag(self): # we apply a random orientation for a subset of the undirected edges for edge in dag.edges: if edge not in vstructs: - if rng.binomial(1, 1.0 / 3): + if rng.binomial(1, 0.3): pdag.remove_edge(*edge) pdag.add_edge(*edge, edge_type="undirected") diff --git a/pywhy_graphs/networkx/classes/mixededge.py b/pywhy_graphs/networkx/classes/mixededge.py index f90723973..a9fc41cc4 100644 --- a/pywhy_graphs/networkx/classes/mixededge.py +++ b/pywhy_graphs/networkx/classes/mixededge.py @@ -112,6 +112,12 @@ def __init__(self, graphs=None, edge_types=None, **attr): # load graph attributes (must be after convert) self.graph.update(attr) + # XXX: experimental. Fix this in doc string once finalized. + # make dynamic property names for the edges, (i.e. circle_edges, + # directed_edges, undirected_edges) + for edge_type_name in self.edge_types: + setattr(self, f"{edge_type_name}_edges", self.get_graphs(edge_type_name).edges) + def __str__(self): """Returns a short summary of the graph.