diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index 4f8f2a28b..5d8dcfbe1 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -6,7 +6,7 @@ import networkx as nx import numpy as np -from pywhy_graphs import PAG, StationaryTimeSeriesPAG +from pywhy_graphs import CPDAG, PAG, StationaryTimeSeriesPAG from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path from pywhy_graphs.typing import Node, TsNode @@ -908,3 +908,49 @@ def _check_ts_node(node): ) if node[1] > 0: raise ValueError(f"All lag points should be 0, or less. You passed in {node}.") + + +def pag_to_mag(graph): + """Convert an PAG to an MAG. + + Parameters + ---------- + G : Graph + The PAG. + + Returns + ------- + mag : Graph + The MAG constructed from the PAG. + """ + copy_graph = graph.copy() + + cedges = set(copy_graph.circle_edges) + dedges = set(copy_graph.directed_edges) + + temp_cpdag = CPDAG() + + to_remove = [] + to_reorient = [] + to_add = [] + + for u, v in cedges: + if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge + to_remove.append((u, v)) + elif (v, u) not in cedges: # reorient a '--o' edge to '-->' + to_reorient.append((u, v)) + elif (v, u) in cedges and ( + v, + u, + ) not in to_add: # add all 'o--o' edges to the cpdag + to_add.append((u, v)) + for u, v in to_remove: + copy_graph.remove_edge(u, v, graph.circle_edge_name) + for u, v in to_reorient: + copy_graph.orient_uncertain_edge(u, v) + for u, v in to_add: + temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name) + + # flag = True + + return copy_graph