Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Function for converting PAG to MAG #86

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import networkx as nx
import numpy as np

from pywhy_graphs import PAG, StationaryTimeSeriesPAG
from pywhy_graphs import CPDAG, PAG, StationaryTimeSeriesPAG
from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path
from pywhy_graphs.typing import Node, TsNode

Expand Down Expand Up @@ -908,3 +908,49 @@ def _check_ts_node(node):
)
if node[1] > 0:
raise ValueError(f"All lag points should be 0, or less. You passed in {node}.")


def pag_to_mag(graph):
"""Convert an PAG to an MAG.

Parameters
----------
G : Graph
The PAG.

Returns
-------
mag : Graph
The MAG constructed from the PAG.
"""
copy_graph = graph.copy()

cedges = set(copy_graph.circle_edges)
dedges = set(copy_graph.directed_edges)

temp_cpdag = CPDAG()

to_remove = []
to_reorient = []
to_add = []

for u, v in cedges:
if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge
to_remove.append((u, v))
elif (v, u) not in cedges: # reorient a '--o' edge to '-->'
to_reorient.append((u, v))
elif (v, u) in cedges and (
v,
u,
) not in to_add: # add all 'o--o' edges to the cpdag
to_add.append((u, v))
for u, v in to_remove:
copy_graph.remove_edge(u, v, graph.circle_edge_name)
for u, v in to_reorient:
copy_graph.orient_uncertain_edge(u, v)
for u, v in to_add:
temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name)

# flag = True

return copy_graph