From e03591f15a299c2878bb722322b3ccbe27c8a6ea Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 30 Oct 2023 17:36:10 -0400 Subject: [PATCH] upgrade black and add torch min version Signed-off-by: Adam Li --- .github/workflows/main.yml | 4 +++- doc/reference/functional/index.rst | 2 +- poetry.lock | 2 +- pyproject.toml | 2 +- pywhy_graphs/algorithms/generic.py | 16 ++++++++-------- pywhy_graphs/algorithms/pag.py | 12 ++++++------ pywhy_graphs/algorithms/tests/test_cyclic.py | 2 +- pywhy_graphs/algorithms/tests/test_generic.py | 4 ---- pywhy_graphs/algorithms/tests/test_pag.py | 5 ++--- pywhy_graphs/array/api.py | 4 ++-- pywhy_graphs/classes/augmented.py | 2 +- pywhy_graphs/classes/timeseries/conversion.py | 8 +++++--- pywhy_graphs/classes/timeseries/mixededge.py | 4 +++- pywhy_graphs/export/pcalg.py | 2 +- pywhy_graphs/export/tests/test_ananke.py | 2 -- pywhy_graphs/functional/base.py | 8 ++++++-- pywhy_graphs/functional/discrete.py | 6 +++--- pywhy_graphs/functional/linear.py | 10 +++++----- pywhy_graphs/functional/multidomain.py | 4 ++-- pywhy_graphs/functional/utils.py | 3 ++- .../networkx/algorithms/causal/m_separation.py | 2 -- .../algorithms/causal/mixed_edge_moral.py | 1 - .../algorithms/causal/tests/test_convert.py | 1 - .../algorithms/causal/tests/test_m_separation.py | 3 --- pywhy_graphs/simulate.py | 8 ++++---- pywhy_graphs/viz/draw.py | 12 ++++++------ 26 files changed, 63 insertions(+), 66 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1b9b5898a..bcf6d611c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -136,7 +136,9 @@ jobs: with: poetry-version: ${{ matrix.poetry-version }} - name: Install Poetry Dynamic Versioning Plugin - run: pip install poetry-dynamic-versioning + run: | + pip install --upgrade pip + pip install poetry-dynamic-versioning - name: Install packages via poetry run: | poetry install --with test diff --git a/doc/reference/functional/index.rst b/doc/reference/functional/index.rst index 23b3f1f1a..eca175ae7 100644 --- a/doc/reference/functional/index.rst +++ b/doc/reference/functional/index.rst @@ -13,7 +13,7 @@ model. Currently, we only support linear models, but we plan to support non-line and we also do not support latent confounders yet. To add a latent confounder, one can add a confounder explicitly, generate the data -and then drop the confounder varialble in the final dataset. In the roadmap of this submodule, +and then drop the confounder variable in the final dataset. In the roadmap of this submodule, the plan is to represent any bidirected edge as a uniformly randomly distributed variable that has an additive noise effect on both variables simultaneously. diff --git a/poetry.lock b/poetry.lock index d00b7701c..93e6d217a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3751,4 +3751,4 @@ viz = ["pygraphviz"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "39c39952b3fe9418e79a9030dda27c7fc479acc06d70a03f22a332aa43d731f5" +content-hash = "627fa77ee29c72139600edcaf138535041aa77716ed5c2f26e38dee167c6bdb5" diff --git a/pyproject.toml b/pyproject.toml index 82ad92dc9..00cc58256 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ causal-learn = { version = "^0.1.2.8" } ananke-causal = { version = "^0.3.3" } pre-commit = "^3.0.4" pandas = { version = "^1.1" } # needed for simulation -torch = { version = "^2.0" } +torch = { version = "^2.0.0" } [tool.poetry.group.style] optional = true diff --git a/pywhy_graphs/algorithms/generic.py b/pywhy_graphs/algorithms/generic.py index cdee0be17..2c336df06 100644 --- a/pywhy_graphs/algorithms/generic.py +++ b/pywhy_graphs/algorithms/generic.py @@ -1,4 +1,4 @@ -from typing import List, Set, Union +from typing import List, Optional, Set, Union import networkx as nx @@ -20,7 +20,9 @@ ] -def is_node_common_cause(G: nx.DiGraph, node: Node, exclude_nodes: List[Node] = None) -> bool: +def is_node_common_cause( + G: nx.DiGraph, node: Node, exclude_nodes: Optional[List[Node]] = None +) -> bool: """Check if a node is a common cause within the graph. Parameters @@ -519,7 +521,7 @@ def _shortest_valid_path( return (path_exists, path) -def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None): +def inducing_path(G, node_x: Node, node_y: Node, L: Optional[Set] = None, S: Optional[Set] = None): """Checks if an inducing path exists between two nodes. An inducing path is defined in :footcite:`Zhang2008`. @@ -599,7 +601,6 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None): path_exists = False for elem in x_neighbors: - visited = {node_x} if elem not in visited: path_exists, temp_path = _shortest_valid_path( @@ -646,7 +647,7 @@ def has_adc(G): return adc_present -def valid_mag(G: ADMG, L: set = None, S: set = None): +def valid_mag(G: ADMG, L: Optional[set] = None, S: Optional[set] = None): """Checks if the provided graph is a valid maximal ancestral graph (MAG). A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that @@ -710,7 +711,7 @@ def valid_mag(G: ADMG, L: set = None, S: set = None): return True -def dag_to_mag(G, L: Set = None, S: Set = None): +def dag_to_mag(G, L: Optional[Set] = None, S: Optional[Set] = None): """Converts a DAG to a valid MAG. The algorithm is defined in :footcite:`Zhang2008` on page 1877. @@ -755,7 +756,6 @@ def dag_to_mag(G, L: Set = None, S: Set = None): mag = ADMG() for A, B in adj_nodes: - AuS = S.union(A) BuS = S.union(B) @@ -787,7 +787,7 @@ def dag_to_mag(G, L: Set = None, S: Set = None): return mag -def is_maximal(G, L: Set = None, S: Set = None): +def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None): """Checks to see if the graph is maximal. Parameters: diff --git a/pywhy_graphs/algorithms/pag.py b/pywhy_graphs/algorithms/pag.py index d6c69e452..6f9e4f284 100644 --- a/pywhy_graphs/algorithms/pag.py +++ b/pywhy_graphs/algorithms/pag.py @@ -511,7 +511,7 @@ def uncovered_pd_path( def pds( - graph: PAG, node_x: Node, node_y: Node = None, max_path_length: Optional[int] = None + graph: PAG, node_x: Node, node_y: Optional[Node] = None, max_path_length: Optional[int] = None ) -> Set[Node]: """Find all PDS sets between node_x and node_y. @@ -712,7 +712,7 @@ def pds_path( for comp in biconn_comp: if (node_x, node_y) in comp or (node_y, node_x) in comp: # add all unique nodes in the biconnected component - for (x, y) in comp: + for x, y in comp: found_component.add(x) found_component.add(y) break @@ -1030,7 +1030,7 @@ def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool: if graph.has_edge(i, j, graph.undirected_edge_name): # For all the pairs of nodes adjacent to i, # look for (k, l), such that j -> l and k -> l - for (k, l) in combinations(graph.neighbors(i), 2): + for k, l in combinations(graph.neighbors(i), 2): # Skip if k and l are adjacent. if l in graph.neighbors(k): continue @@ -1157,7 +1157,7 @@ def pag_to_mag(graph): while flag: undedges = temp_cpdag.undirected_edges if len(undedges) != 0: - for (u, v) in undedges: + for u, v in undedges: temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name) temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name) _apply_meek_rules(temp_cpdag) @@ -1169,10 +1169,10 @@ def pag_to_mag(graph): # construct the final MAG - for (u, v) in copy_graph.directed_edges: + for u, v in copy_graph.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) - for (u, v) in temp_cpdag.directed_edges: + for u, v in temp_cpdag.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) return mag diff --git a/pywhy_graphs/algorithms/tests/test_cyclic.py b/pywhy_graphs/algorithms/tests/test_cyclic.py index ec5eaa436..f47d2969f 100644 --- a/pywhy_graphs/algorithms/tests/test_cyclic.py +++ b/pywhy_graphs/algorithms/tests/test_cyclic.py @@ -83,7 +83,7 @@ def test_sigma_separated(): cyclic_G = pywhy_nx.MixedEdgeGraph(graphs=[cyclic_G], edge_types=["directed"]) cyclic_G.add_edge_type(nx.Graph(), edge_type="bidirected") - for (u, v) in combinations(cyclic_G.nodes, 2): + for u, v in combinations(cyclic_G.nodes, 2): other_nodes = set(cyclic_G.nodes) other_nodes.remove(u) other_nodes.remove(v) diff --git a/pywhy_graphs/algorithms/tests/test_generic.py b/pywhy_graphs/algorithms/tests/test_generic.py index a9876b57b..8b7e6c375 100644 --- a/pywhy_graphs/algorithms/tests/test_generic.py +++ b/pywhy_graphs/algorithms/tests/test_generic.py @@ -43,7 +43,6 @@ def test_convert_to_latent_confounder(graph_func): def test_inducing_path(): - admg = ADMG() admg.add_edge("X", "Y", admg.directed_edge_name) @@ -93,7 +92,6 @@ def test_inducing_path(): def test_inducing_path_wihtout_LandS(): - admg = ADMG() admg.add_edge("X", "Y", admg.directed_edge_name) @@ -113,7 +111,6 @@ def test_inducing_path_wihtout_LandS(): def test_inducing_path_one_direction(): - admg = ADMG() admg.add_edge("A", "B", admg.directed_edge_name) @@ -375,7 +372,6 @@ def test_valid_mag(): def test_dag_to_mag(): - # A -> E -> S # H -> E , H -> R admg = ADMG() diff --git a/pywhy_graphs/algorithms/tests/test_pag.py b/pywhy_graphs/algorithms/tests/test_pag.py index 1f0deb537..117210ea8 100644 --- a/pywhy_graphs/algorithms/tests/test_pag.py +++ b/pywhy_graphs/algorithms/tests/test_pag.py @@ -179,7 +179,7 @@ def test_discriminating_path(): ) for u in pag.nodes: - for (a, c) in permutations(pag.neighbors(u), 2): + for a, c in permutations(pag.neighbors(u), 2): found_discriminating_path, disc_path, _ = discriminating_path( pag, u, a, c, max_path_length=100 ) @@ -193,7 +193,7 @@ def test_discriminating_path(): pag.remove_edge("x2", "x5", pag.directed_edge_name) pag.add_edge("x5", "x2", pag.bidirected_edge_name) for u in pag.nodes: - for (a, c) in permutations(pag.neighbors(u), 2): + for a, c in permutations(pag.neighbors(u), 2): found_discriminating_path, disc_path, _ = discriminating_path( pag, u, a, c, max_path_length=100 ) @@ -650,7 +650,6 @@ def test_pdst(pdst_graph): def test_pag_to_mag(): - # C o- A o-> D <-o B # B o-o A o-o C o-> D diff --git a/pywhy_graphs/array/api.py b/pywhy_graphs/array/api.py index 37f9548b0..e209ead77 100644 --- a/pywhy_graphs/array/api.py +++ b/pywhy_graphs/array/api.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set import numpy as np from numpy.typing import NDArray @@ -80,7 +80,7 @@ def get_summary_graph(arr: NDArray, arr_enum: str = "clearn"): def array_to_lagged_links( - arr: NDArray, arr_idx: List[Node] = None, include_weights: bool = True + arr: NDArray, arr_idx: Optional[List[Node]] = None, include_weights: bool = True ) -> Dict[Node, List[Set]]: """Convert a time-series 3D array to a dictionary of lagged links. diff --git a/pywhy_graphs/classes/augmented.py b/pywhy_graphs/classes/augmented.py index 3bf0c543f..c91f8da9d 100644 --- a/pywhy_graphs/classes/augmented.py +++ b/pywhy_graphs/classes/augmented.py @@ -154,7 +154,7 @@ def s_nodes(self) -> List[Node]: """Return set of S-nodes.""" return list(self.graph["S-nodes"].keys()) - def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None): + def add_s_node(self, domain_ids: Tuple, node_changes: Optional[Set[Node]] = None): if isinstance(node_changes, str) or not isinstance(node_changes, Iterable): raise RuntimeError("The intervention set nodes must be an iterable set of node(s).") diff --git a/pywhy_graphs/classes/timeseries/conversion.py b/pywhy_graphs/classes/timeseries/conversion.py index 94ec50472..b84ea4c99 100644 --- a/pywhy_graphs/classes/timeseries/conversion.py +++ b/pywhy_graphs/classes/timeseries/conversion.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import numpy as np @@ -7,7 +7,7 @@ from .graph import StationaryTimeSeriesGraph -def tsgraph_to_numpy(G, var_order: List[Node] = None): +def tsgraph_to_numpy(G, var_order: Optional[List[Node]] = None): """Convert stationary timeseries graph to numpy array. Parameters @@ -44,7 +44,9 @@ def tsgraph_to_numpy(G, var_order: List[Node] = None): return ts_graph_arr -def numpy_to_tsgraph(arr, var_order: List[Node] = None, create_using=StationaryTimeSeriesGraph): +def numpy_to_tsgraph( + arr, var_order: Optional[List[Node]] = None, create_using=StationaryTimeSeriesGraph +): """Convert 3D numpy array into a stationary time-series graph. Parameters diff --git a/pywhy_graphs/classes/timeseries/mixededge.py b/pywhy_graphs/classes/timeseries/mixededge.py index b873b901e..7f72922e5 100644 --- a/pywhy_graphs/classes/timeseries/mixededge.py +++ b/pywhy_graphs/classes/timeseries/mixededge.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import pywhy_graphs.networkx as pywhy_nx @@ -178,7 +180,7 @@ class StationaryTimeSeriesMixedEdgeGraph(TimeSeriesMixedEdgeGraph): # supported graph types graph_types = (StationaryTimeSeriesGraph, StationaryTimeSeriesDiGraph) - def __init__(self, graphs=None, edge_types=None, max_lag: int = None, **attr): + def __init__(self, graphs=None, edge_types=None, max_lag: Optional[int] = None, **attr): super().__init__(graphs, edge_types, max_lag=max_lag, **attr) def set_stationarity(self, stationary: bool): diff --git a/pywhy_graphs/export/pcalg.py b/pywhy_graphs/export/pcalg.py index f5e81b198..9800af389 100644 --- a/pywhy_graphs/export/pcalg.py +++ b/pywhy_graphs/export/pcalg.py @@ -166,7 +166,7 @@ def graph_to_pcalg(causal_graph): # now map all values to their respective pcalg values seen_idx = dict() - for (idx, jdx) in np.argwhere(clearn_arr != 0): + for idx, jdx in np.argwhere(clearn_arr != 0): if (idx, jdx) in seen_idx or (jdx, idx) in seen_idx: continue diff --git a/pywhy_graphs/export/tests/test_ananke.py b/pywhy_graphs/export/tests/test_ananke.py index 66dc84318..ccffa5730 100644 --- a/pywhy_graphs/export/tests/test_ananke.py +++ b/pywhy_graphs/export/tests/test_ananke.py @@ -9,7 +9,6 @@ def dag(): - vertices = ["A", "B", "C", "D"] di_edges = [("A", "B"), ("B", "C"), ("C", "D")] graph = DAG(vertices=vertices, di_edges=di_edges) @@ -19,7 +18,6 @@ def dag(): def admg(): - vertices = ["A", "B", "C", "D"] di_edges = [("A", "B"), ("B", "C"), ("C", "D")] bi_edges = [("A", "C"), ("B", "D")] diff --git a/pywhy_graphs/functional/base.py b/pywhy_graphs/functional/base.py index 1a36cf52c..ece2bf79f 100644 --- a/pywhy_graphs/functional/base.py +++ b/pywhy_graphs/functional/base.py @@ -40,7 +40,7 @@ def add_parent_function(G: nx.DiGraph, node: Node, func: Callable) -> nx.DiGraph def add_noise_function( - G: nx.DiGraph, node: Node, distr_func: Callable, func: Callable = None + G: nx.DiGraph, node: Node, distr_func: Callable, func: Optional[Callable] = None ) -> nx.DiGraph: """Add function and distribution for a node's exogenous variable into the graph. @@ -120,7 +120,11 @@ def add_soft_intervention_function( def add_domain_shift_function( - G: AugmentedGraph, node: Node, s_node: Node, func: Callable = None, distr_func: Callable = None + G: AugmentedGraph, + node: Node, + s_node: Node, + func: Optional[Callable] = None, + distr_func: Optional[Callable] = None, ): """Add domain shift function for a node into the graph assuming invariant graph structure. diff --git a/pywhy_graphs/functional/discrete.py b/pywhy_graphs/functional/discrete.py index cc21a2466..1c8f30716 100644 --- a/pywhy_graphs/functional/discrete.py +++ b/pywhy_graphs/functional/discrete.py @@ -184,9 +184,9 @@ def parent_func(*args): def make_random_discrete_graph( G: nx.DiGraph, - cardinality_lims: Dict[Any, List[int]] = None, - weight_lims: Dict[Any, List[int]] = None, - noise_ratio_lims: List[float] = None, + cardinality_lims: Optional[Dict[Any, List[int]]] = None, + weight_lims: Optional[Dict[Any, List[int]]] = None, + noise_ratio_lims: Optional[List[float]] = None, overwrite: bool = False, random_state=None, ) -> nx.DiGraph: diff --git a/pywhy_graphs/functional/linear.py b/pywhy_graphs/functional/linear.py index b5eb8b9c4..71cee724e 100644 --- a/pywhy_graphs/functional/linear.py +++ b/pywhy_graphs/functional/linear.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Set +from typing import Callable, List, Optional, Set import networkx as nx import numpy as np @@ -11,10 +11,10 @@ def make_graph_linear_gaussian( G: nx.DiGraph, - node_mean_lims: List[float] = None, - node_std_lims: List[float] = None, - edge_functions: List[Callable[[float], float]] = None, - edge_weight_lims: List[float] = None, + node_mean_lims: Optional[List[float]] = None, + node_std_lims: Optional[List[float]] = None, + edge_functions: Optional[List[Callable[[float], float]]] = None, + edge_weight_lims: Optional[List[float]] = None, random_state=None, ) -> nx.DiGraph: r"""Convert an existing DAG to a linear Gaussian graphical model. diff --git a/pywhy_graphs/functional/multidomain.py b/pywhy_graphs/functional/multidomain.py index cede33323..2f51927a9 100644 --- a/pywhy_graphs/functional/multidomain.py +++ b/pywhy_graphs/functional/multidomain.py @@ -23,7 +23,7 @@ def make_graph_multidomain( n_invariances_to_try: int = 1, node_mean_lims: Optional[List[float]] = None, node_std_lims: Optional[List[float]] = None, - edge_functions: List[Callable[[float], float]] = None, + edge_functions: Optional[List[Callable[[float], float]]] = None, edge_weight_lims: Optional[List[float]] = None, random_state=None, ) -> nx.DiGraph: @@ -263,7 +263,7 @@ def sample_multidomain_lin_functions( G: AugmentedGraph, node_mean_lims: Optional[List[float]] = None, node_std_lims: Optional[List[float]] = None, - edge_functions: List[Callable[[float], float]] = None, + edge_functions: Optional[List[Callable[[float], float]]] = None, edge_weight_lims: Optional[List[float]] = None, random_state=None, ): diff --git a/pywhy_graphs/functional/utils.py b/pywhy_graphs/functional/utils.py index 1bd424ff6..70fbb997f 100644 --- a/pywhy_graphs/functional/utils.py +++ b/pywhy_graphs/functional/utils.py @@ -1,4 +1,5 @@ import itertools +from typing import Optional import networkx as nx import numpy as np @@ -217,7 +218,7 @@ def _preprocess_parameter_inputs( edge_functions, edge_weight_lims, multi_domain: bool = False, - n_domains: int = None, + n_domains: Optional[int] = None, ): """Helper function to preprocess common parameter inputs for sampling functional graphs. diff --git a/pywhy_graphs/networkx/algorithms/causal/m_separation.py b/pywhy_graphs/networkx/algorithms/causal/m_separation.py index 84510e505..48308cdd1 100644 --- a/pywhy_graphs/networkx/algorithms/causal/m_separation.py +++ b/pywhy_graphs/networkx/algorithms/causal/m_separation.py @@ -110,7 +110,6 @@ def m_separated( G_bidirected = G.get_graphs(edge_type=bidirected_edge_name) while forward_deque or backward_deque: - if backward_deque: node = backward_deque.popleft() backward_visited.add(node) @@ -151,7 +150,6 @@ def m_separated( # Consider if *-> node <-* is opened due to conditioning on collider, # or descendant of collider if node in an_z: - if has_directed: # add <- edges to backward deque for x, _ in G_directed.in_edges(nbunch=node): diff --git a/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py b/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py index 6fee0634e..eeec16306 100644 --- a/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py +++ b/pywhy_graphs/networkx/algorithms/causal/mixed_edge_moral.py @@ -61,7 +61,6 @@ def mixed_edge_moral_graph( G_a = nx.compose(G_a, G_bidirected) for component in nx.connected_components(G_bidirected): - for u, v in itertools.combinations(component, 2): G_a.add_edge(u, v) all_parents = {parent for node in component for parent in G_directed.predecessors(node)} diff --git a/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py b/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py index b289372a4..7e3ff570e 100644 --- a/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py +++ b/pywhy_graphs/networkx/algorithms/causal/tests/test_convert.py @@ -4,7 +4,6 @@ def test_m_separation(): - # 0 -> 1 -> 2 -> 3 -> 4; 2 -> 4; 2 <-> 3 digraph = nx.path_graph(4, create_using=nx.DiGraph) digraph.add_edge(2, 4) diff --git a/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py b/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py index bc6aa56da..3aa19d26b 100644 --- a/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py +++ b/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py @@ -9,7 +9,6 @@ @pytest.fixture def fig5_vanderzander(): - nodes = ["V_1", "X", "V_2", "Y", "Z_1", "Z_2"] digraph = nx.DiGraph() @@ -35,7 +34,6 @@ def fig5_vanderzander(): @pytest.fixture def modified_fig5_vanderzander(): - nodes = ["V_1", "X", "V_2", "Y", "Z_1", "Z_2"] digraph = nx.DiGraph() @@ -239,7 +237,6 @@ def test_anterior(): def test_is_minimal_m_separator(fig5_vanderzander): - assert pywhy_nx.is_minimal_m_separator(fig5_vanderzander, "X", "Y", {"Z_1"}) assert pywhy_nx.is_minimal_m_separator(fig5_vanderzander, "X", "Y", {"Z_2"}) assert pywhy_nx.is_minimal_m_separator(fig5_vanderzander, "X", "Y", {"Z_2"}, r={"Z_1", "Z_2"}) diff --git a/pywhy_graphs/simulate.py b/pywhy_graphs/simulate.py index 1b1ebeb70..1ecc98837 100644 --- a/pywhy_graphs/simulate.py +++ b/pywhy_graphs/simulate.py @@ -11,7 +11,7 @@ def simulate_random_er_dag( - n_nodes: int, p: float = 0.5, seed: int = None, ensure_acyclic: bool = False + n_nodes: int, p: float = 0.5, seed: Optional[int] = None, ensure_acyclic: bool = False ): """Simulate a random Erdos-Renyi graph. @@ -109,7 +109,7 @@ def simulate_data_from_var( n_times: int = 1000, n_realizations: int = 1, var_names: Optional[List[Node]] = None, - random_state: int = None, + random_state: Optional[int] = None, ): """Simulate data from an already set VAR process. @@ -199,7 +199,7 @@ def simulate_linear_var_process( n_times: int = 1000, n_realizations: int = 1, weight_dist: Callable = scipy.stats.norm, - random_state: int = None, + random_state: Optional[int] = None, ): """Simulate a linear VAR process of a "stationary" causal graph. @@ -286,7 +286,7 @@ def simulate_linear_var_process( def simulate_var_process_from_summary_graph( - G: pywhy_nx.MixedEdgeGraph, max_lag=1, n_times=1000, random_state: int = None + G: pywhy_nx.MixedEdgeGraph, max_lag=1, n_times=1000, random_state: Optional[int] = None ): """Simulate a VAR(max_lag) process starting from a summary graph. diff --git a/pywhy_graphs/viz/draw.py b/pywhy_graphs/viz/draw.py index 7be5d561b..934fa39f4 100644 --- a/pywhy_graphs/viz/draw.py +++ b/pywhy_graphs/viz/draw.py @@ -5,10 +5,10 @@ def _draw_circle_edges( dot, - directed_edges: List[Tuple] = None, - circle_edges: List[Tuple] = None, - undirected_edges: List[Tuple] = None, - bidirected_edges: List[Tuple] = None, + directed_edges: Optional[List[Tuple]] = None, + circle_edges: Optional[List[Tuple]] = None, + undirected_edges: Optional[List[Tuple]] = None, + bidirected_edges: Optional[List[Tuple]] = None, **attrs, ): """Draw the PAG edges. @@ -52,7 +52,7 @@ def _draw_circle_edges( def _draw_un_edges( dot, - undirected_edges: List[Tuple] = None, + undirected_edges: Optional[List[Tuple]] = None, **attrs, ): """Draw undirected edges.""" @@ -65,7 +65,7 @@ def _draw_un_edges( def _draw_bi_edges( dot, - bidirected_edges: List[Tuple] = None, + bidirected_edges: Optional[List[Tuple]] = None, **attrs, ): """Draw bidirected edges."""