diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index 3737339a..d35cc523 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -873,7 +873,7 @@ def check_back_arrow(G: ADMG, X, Y: set): def get_X_neighbors(G, X: set): - out = [] + out = set() for elem in X: elem_neighbors = set(G.neighbors(elem)) @@ -883,53 +883,52 @@ def get_X_neighbors(G, X: set): if len(elem_neighbors) != 0: for nbh in elem_neighbors: - temp = dict() - temp[0] = elem - temp[1] = nbh - out.append(temp) + temp = (elem,) + temp = temp + (nbh,) + out.add(temp) return out def recursively_find_pd_paths(G, X, paths, Y): counter = 0 - new_paths = [] + new_paths = set() - for i in range(len(paths)): - cur_elem = paths[i][list(paths[i].keys())[-1]] + for elem in paths: + cur_elem = elem[-1] if cur_elem in Y: - new_paths.append(paths[i]) + new_paths.add(elem) continue nbr_temp = G.neighbors(cur_elem) nbr_possible = check_back_arrow(G, cur_elem, nbr_temp) if len(nbr_possible) == 0: - new_paths.append(paths[i].copy()) + new_paths = new_paths + (elem,) possible_end = nbr_possible.intersection(Y) if len(possible_end) != 0: - for elem in possible_end: - temp_path = paths[i].copy() - temp_path[len(temp_path)] = elem - new_paths.append(temp_path) + for nbr in possible_end: + temp_path = elem + temp_path = temp_path + (nbr,) + new_paths.add(temp_path) remaining_nodes = nbr_possible - possible_end remaining_nodes = ( remaining_nodes - - remaining_nodes.intersection(paths[i].values()) + - remaining_nodes.intersection(set(elem)) - remaining_nodes.intersection(X) ) - temp_arr = [] - for elem in remaining_nodes: - temp_paths = paths[i].copy() - temp_paths[len(temp_paths)] = elem - temp_arr.append(temp_paths) + temp_set = set() + for nbr in remaining_nodes: + temp_paths = elem + temp_paths = temp_paths + (nbr,) + temp_set.add(temp_paths) - new_paths.extend(recursively_find_pd_paths(G, X, temp_arr, Y)) + new_paths.update(recursively_find_pd_paths(G, X, temp_set, Y)) return new_paths @@ -950,5 +949,6 @@ def possibly_directed_path(G, X: Optional[Set] = None, Y: Optional[Set] = None): x_neighbors.append(temp) path_list = recursively_find_pd_paths(G, X, x_neighbors, Y) + print(path_list) return path_list diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index 3ae6c059..42c2c7fa 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -509,9 +509,9 @@ def test_possibly_directed(): Y = {"H"} X = {"Y"} - correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}] + correct = {('Y', 'X', 'Z', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] + assert correct == out admg = ADMG() admg.add_edge("A", "X", admg.directed_edge_name) @@ -522,10 +522,9 @@ def test_possibly_directed(): Y = {"H"} X = {"Y", "A"} - correct = [{0: "A", 1: "X", 2: "Z", 3: "H"}, {0: "Y", 1: "X", 2: "Z", 3: "H"}] + correct = {('Y', 'X', 'Z', 'H'), ('A', 'X', 'Z', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] + assert correct == out admg = ADMG() admg.add_edge("X", "A", admg.directed_edge_name) @@ -536,9 +535,9 @@ def test_possibly_directed(): Y = {"H"} X = {"Y", "A"} - correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}] + correct = {('Y', 'X', 'Z', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] + assert correct == out admg = ADMG() admg.add_edge("X", "A", admg.directed_edge_name) @@ -550,9 +549,9 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [{0: "Y", 1: "X", 2: "Z", 3: "H"}] + correct = {('Y', 'X', 'Z', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] + assert correct == out admg = ADMG() admg.add_edge("A", "X", admg.directed_edge_name) @@ -564,17 +563,9 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [ - {0: "A", 1: "X", 2: "Z", 3: "H"}, - {0: "A", 1: "X", 2: "Z", 3: "K"}, - {0: "Y", 1: "X", 2: "Z", 3: "H"}, - {0: "Y", 1: "X", 2: "Z", 3: "K"}, - ] + correct = {('Y', 'X', 'Z', 'K'), ('A', 'X', 'Z', 'K'), ('Y', 'X', 'Z', 'H'), ('A', 'X', 'Z', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] - assert correct[2] == out[2] - assert correct[3] == out[3] + assert correct == out admg = ADMG() admg.add_edge("A", "G", admg.directed_edge_name) @@ -587,10 +578,9 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [{0: "A", 1: "G", 2: "C", 3: "H"}, {0: "Y", 1: "X", 2: "Z", 3: "K"}] + correct = {('Y', 'X', 'Z', 'K'), ('A', 'G', 'C', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] + assert correct == out admg = ADMG() admg.add_edge("A", "G", admg.directed_edge_name) @@ -604,15 +594,9 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [ - {0: "A", 1: "G", 2: "C", 3: "H"}, - {0: "Y", 1: "X", 2: "Z", 3: "K"}, - {0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"}, - ] + correct = {('Y', 'X', 'Z', 'K'), ('Y', 'X', 'Z', 'C', 'H'), ('A', 'G', 'C', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] - assert correct[2] == out[2] + assert correct == out admg = ADMG() admg.add_edge("A", "G", admg.directed_edge_name) @@ -623,12 +607,9 @@ def test_possibly_directed(): Y = {"G", "H"} X = {"A", "K"} - correct = [{0: "K", 1: "G"}, {0: "K", 1: "H"}, {0: "A", 1: "G"}, {0: "A", 1: "H"}] + correct = {('K', 'H'), ('K', 'G'), ('A', 'G'), ('A', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] - assert correct[2] == out[2] - assert correct[3] == out[3] + assert correct == out admg = ADMG() admg.add_edge("A", "G", admg.directed_edge_name) @@ -642,13 +623,12 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [ - {0: "A", 1: "G", 2: "C", 3: "H"}, - {0: "Y", 1: "X", 2: "Z", 3: "K"}, - ] + correct = { + ("A","G","C","H"), + ("Y","X","Z","K"), + } out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] + assert correct == out admg = ADMG() admg.add_edge("A", "G", admg.directed_edge_name) @@ -662,16 +642,9 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [ - {0: "A", 1: "G", 2: "C", 3: "H"}, - {0: "Y", 1: "X", 2: "Z", 3: "K"}, - {0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"}, - {0: "A", 1: "G", 2: "C", 3: "Z", 4: "K"}, - ] + correct = {('Y', 'X', 'Z', 'K'), ('A', 'G', 'C', 'H')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] - assert correct[2] == out[2] + assert correct == out admg = PAG() admg.add_edge("A", "G", admg.directed_edge_name) @@ -686,17 +659,6 @@ def test_possibly_directed(): Y = {"H", "K"} X = {"Y", "A"} - correct = [ - [ - {0: "Y", 1: "X", 2: "Z", 3: "K"}, - {0: "Y", 1: "X", 2: "Z", 3: "C", 4: "H"}, - {0: "A", 1: "G", 2: "C", 3: "H"}, - {0: "A", 1: "G", 2: "C", 3: "Z", 4: "K"}, - ] - ] - + correct = {('Y', 'X', 'Z', 'K'), ('Y', 'X', 'Z', 'C', 'H'), ('A', 'G', 'C', 'H'), ('A', 'G', 'C', 'Z', 'K')} out = pywhy_graphs.possibly_directed_path(admg, X, Y) - assert correct[0] == out[0] - assert correct[1] == out[1] - assert correct[2] == out[2] - assert correct[3] == out[3] + assert correct == out