Skip to content

Commit

Permalink
upgrade black and add torch min version
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Oct 30, 2023
1 parent 3619dd5 commit e03591f
Show file tree
Hide file tree
Showing 26 changed files with 63 additions and 66 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ jobs:
with:
poetry-version: ${{ matrix.poetry-version }}
- name: Install Poetry Dynamic Versioning Plugin
run: pip install poetry-dynamic-versioning
run: |
pip install --upgrade pip
pip install poetry-dynamic-versioning
- name: Install packages via poetry
run: |
poetry install --with test
Expand Down
2 changes: 1 addition & 1 deletion doc/reference/functional/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ model. Currently, we only support linear models, but we plan to support non-line
and we also do not support latent confounders yet.

To add a latent confounder, one can add a confounder explicitly, generate the data
and then drop the confounder varialble in the final dataset. In the roadmap of this submodule,
and then drop the confounder variable in the final dataset. In the roadmap of this submodule,
the plan is to represent any bidirected edge as a uniformly randomly distributed variable
that has an additive noise effect on both variables simultaneously.

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ causal-learn = { version = "^0.1.2.8" }
ananke-causal = { version = "^0.3.3" }
pre-commit = "^3.0.4"
pandas = { version = "^1.1" } # needed for simulation
torch = { version = "^2.0" }
torch = { version = "^2.0.0" }

[tool.poetry.group.style]
optional = true
Expand Down
16 changes: 8 additions & 8 deletions pywhy_graphs/algorithms/generic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Set, Union
from typing import List, Optional, Set, Union

import networkx as nx

Expand All @@ -20,7 +20,9 @@
]


def is_node_common_cause(G: nx.DiGraph, node: Node, exclude_nodes: List[Node] = None) -> bool:
def is_node_common_cause(
G: nx.DiGraph, node: Node, exclude_nodes: Optional[List[Node]] = None
) -> bool:
"""Check if a node is a common cause within the graph.
Parameters
Expand Down Expand Up @@ -519,7 +521,7 @@ def _shortest_valid_path(
return (path_exists, path)


def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
def inducing_path(G, node_x: Node, node_y: Node, L: Optional[Set] = None, S: Optional[Set] = None):
"""Checks if an inducing path exists between two nodes.
An inducing path is defined in :footcite:`Zhang2008`.
Expand Down Expand Up @@ -599,7 +601,6 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):

path_exists = False
for elem in x_neighbors:

visited = {node_x}
if elem not in visited:
path_exists, temp_path = _shortest_valid_path(
Expand Down Expand Up @@ -646,7 +647,7 @@ def has_adc(G):
return adc_present


def valid_mag(G: ADMG, L: set = None, S: set = None):
def valid_mag(G: ADMG, L: Optional[set] = None, S: Optional[set] = None):
"""Checks if the provided graph is a valid maximal ancestral graph (MAG).
A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that
Expand Down Expand Up @@ -710,7 +711,7 @@ def valid_mag(G: ADMG, L: set = None, S: set = None):
return True


def dag_to_mag(G, L: Set = None, S: Set = None):
def dag_to_mag(G, L: Optional[Set] = None, S: Optional[Set] = None):
"""Converts a DAG to a valid MAG.
The algorithm is defined in :footcite:`Zhang2008` on page 1877.
Expand Down Expand Up @@ -755,7 +756,6 @@ def dag_to_mag(G, L: Set = None, S: Set = None):
mag = ADMG()

for A, B in adj_nodes:

AuS = S.union(A)
BuS = S.union(B)

Expand Down Expand Up @@ -787,7 +787,7 @@ def dag_to_mag(G, L: Set = None, S: Set = None):
return mag


def is_maximal(G, L: Set = None, S: Set = None):
def is_maximal(G, L: Optional[Set] = None, S: Optional[Set] = None):
"""Checks to see if the graph is maximal.
Parameters:
Expand Down
12 changes: 6 additions & 6 deletions pywhy_graphs/algorithms/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def uncovered_pd_path(


def pds(
graph: PAG, node_x: Node, node_y: Node = None, max_path_length: Optional[int] = None
graph: PAG, node_x: Node, node_y: Optional[Node] = None, max_path_length: Optional[int] = None
) -> Set[Node]:
"""Find all PDS sets between node_x and node_y.
Expand Down Expand Up @@ -712,7 +712,7 @@ def pds_path(
for comp in biconn_comp:
if (node_x, node_y) in comp or (node_y, node_x) in comp:
# add all unique nodes in the biconnected component
for (x, y) in comp:
for x, y in comp:
found_component.add(x)
found_component.add(y)
break
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool:
if graph.has_edge(i, j, graph.undirected_edge_name):
# For all the pairs of nodes adjacent to i,
# look for (k, l), such that j -> l and k -> l
for (k, l) in combinations(graph.neighbors(i), 2):
for k, l in combinations(graph.neighbors(i), 2):
# Skip if k and l are adjacent.
if l in graph.neighbors(k):
continue
Expand Down Expand Up @@ -1157,7 +1157,7 @@ def pag_to_mag(graph):
while flag:
undedges = temp_cpdag.undirected_edges
if len(undedges) != 0:
for (u, v) in undedges:
for u, v in undedges:
temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name)
temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name)
_apply_meek_rules(temp_cpdag)
Expand All @@ -1169,10 +1169,10 @@ def pag_to_mag(graph):

# construct the final MAG

for (u, v) in copy_graph.directed_edges:
for u, v in copy_graph.directed_edges:
mag.add_edge(u, v, mag.directed_edge_name)

for (u, v) in temp_cpdag.directed_edges:
for u, v in temp_cpdag.directed_edges:
mag.add_edge(u, v, mag.directed_edge_name)

return mag
2 changes: 1 addition & 1 deletion pywhy_graphs/algorithms/tests/test_cyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_sigma_separated():
cyclic_G = pywhy_nx.MixedEdgeGraph(graphs=[cyclic_G], edge_types=["directed"])
cyclic_G.add_edge_type(nx.Graph(), edge_type="bidirected")

for (u, v) in combinations(cyclic_G.nodes, 2):
for u, v in combinations(cyclic_G.nodes, 2):
other_nodes = set(cyclic_G.nodes)
other_nodes.remove(u)
other_nodes.remove(v)
Expand Down
4 changes: 0 additions & 4 deletions pywhy_graphs/algorithms/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def test_convert_to_latent_confounder(graph_func):


def test_inducing_path():

admg = ADMG()

admg.add_edge("X", "Y", admg.directed_edge_name)
Expand Down Expand Up @@ -93,7 +92,6 @@ def test_inducing_path():


def test_inducing_path_wihtout_LandS():

admg = ADMG()

admg.add_edge("X", "Y", admg.directed_edge_name)
Expand All @@ -113,7 +111,6 @@ def test_inducing_path_wihtout_LandS():


def test_inducing_path_one_direction():

admg = ADMG()

admg.add_edge("A", "B", admg.directed_edge_name)
Expand Down Expand Up @@ -375,7 +372,6 @@ def test_valid_mag():


def test_dag_to_mag():

# A -> E -> S
# H -> E , H -> R
admg = ADMG()
Expand Down
5 changes: 2 additions & 3 deletions pywhy_graphs/algorithms/tests/test_pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_discriminating_path():
)

for u in pag.nodes:
for (a, c) in permutations(pag.neighbors(u), 2):
for a, c in permutations(pag.neighbors(u), 2):
found_discriminating_path, disc_path, _ = discriminating_path(
pag, u, a, c, max_path_length=100
)
Expand All @@ -193,7 +193,7 @@ def test_discriminating_path():
pag.remove_edge("x2", "x5", pag.directed_edge_name)
pag.add_edge("x5", "x2", pag.bidirected_edge_name)
for u in pag.nodes:
for (a, c) in permutations(pag.neighbors(u), 2):
for a, c in permutations(pag.neighbors(u), 2):
found_discriminating_path, disc_path, _ = discriminating_path(
pag, u, a, c, max_path_length=100
)
Expand Down Expand Up @@ -650,7 +650,6 @@ def test_pdst(pdst_graph):


def test_pag_to_mag():

# C o- A o-> D <-o B
# B o-o A o-o C o-> D

Expand Down
4 changes: 2 additions & 2 deletions pywhy_graphs/array/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Set
from typing import Dict, List, Optional, Set

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -80,7 +80,7 @@ def get_summary_graph(arr: NDArray, arr_enum: str = "clearn"):


def array_to_lagged_links(
arr: NDArray, arr_idx: List[Node] = None, include_weights: bool = True
arr: NDArray, arr_idx: Optional[List[Node]] = None, include_weights: bool = True
) -> Dict[Node, List[Set]]:
"""Convert a time-series 3D array to a dictionary of lagged links.
Expand Down
2 changes: 1 addition & 1 deletion pywhy_graphs/classes/augmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def s_nodes(self) -> List[Node]:
"""Return set of S-nodes."""
return list(self.graph["S-nodes"].keys())

def add_s_node(self, domain_ids: Tuple, node_changes: Set[Node] = None):
def add_s_node(self, domain_ids: Tuple, node_changes: Optional[Set[Node]] = None):
if isinstance(node_changes, str) or not isinstance(node_changes, Iterable):
raise RuntimeError("The intervention set nodes must be an iterable set of node(s).")

Expand Down
8 changes: 5 additions & 3 deletions pywhy_graphs/classes/timeseries/conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

import numpy as np

Expand All @@ -7,7 +7,7 @@
from .graph import StationaryTimeSeriesGraph


def tsgraph_to_numpy(G, var_order: List[Node] = None):
def tsgraph_to_numpy(G, var_order: Optional[List[Node]] = None):
"""Convert stationary timeseries graph to numpy array.
Parameters
Expand Down Expand Up @@ -44,7 +44,9 @@ def tsgraph_to_numpy(G, var_order: List[Node] = None):
return ts_graph_arr


def numpy_to_tsgraph(arr, var_order: List[Node] = None, create_using=StationaryTimeSeriesGraph):
def numpy_to_tsgraph(
arr, var_order: Optional[List[Node]] = None, create_using=StationaryTimeSeriesGraph
):
"""Convert 3D numpy array into a stationary time-series graph.
Parameters
Expand Down
4 changes: 3 additions & 1 deletion pywhy_graphs/classes/timeseries/mixededge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np

import pywhy_graphs.networkx as pywhy_nx
Expand Down Expand Up @@ -178,7 +180,7 @@ class StationaryTimeSeriesMixedEdgeGraph(TimeSeriesMixedEdgeGraph):
# supported graph types
graph_types = (StationaryTimeSeriesGraph, StationaryTimeSeriesDiGraph)

def __init__(self, graphs=None, edge_types=None, max_lag: int = None, **attr):
def __init__(self, graphs=None, edge_types=None, max_lag: Optional[int] = None, **attr):
super().__init__(graphs, edge_types, max_lag=max_lag, **attr)

def set_stationarity(self, stationary: bool):
Expand Down
2 changes: 1 addition & 1 deletion pywhy_graphs/export/pcalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def graph_to_pcalg(causal_graph):

# now map all values to their respective pcalg values
seen_idx = dict()
for (idx, jdx) in np.argwhere(clearn_arr != 0):
for idx, jdx in np.argwhere(clearn_arr != 0):
if (idx, jdx) in seen_idx or (jdx, idx) in seen_idx:
continue

Expand Down
2 changes: 0 additions & 2 deletions pywhy_graphs/export/tests/test_ananke.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


def dag():

vertices = ["A", "B", "C", "D"]
di_edges = [("A", "B"), ("B", "C"), ("C", "D")]
graph = DAG(vertices=vertices, di_edges=di_edges)
Expand All @@ -19,7 +18,6 @@ def dag():


def admg():

vertices = ["A", "B", "C", "D"]
di_edges = [("A", "B"), ("B", "C"), ("C", "D")]
bi_edges = [("A", "C"), ("B", "D")]
Expand Down
8 changes: 6 additions & 2 deletions pywhy_graphs/functional/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def add_parent_function(G: nx.DiGraph, node: Node, func: Callable) -> nx.DiGraph


def add_noise_function(
G: nx.DiGraph, node: Node, distr_func: Callable, func: Callable = None
G: nx.DiGraph, node: Node, distr_func: Callable, func: Optional[Callable] = None
) -> nx.DiGraph:
"""Add function and distribution for a node's exogenous variable into the graph.
Expand Down Expand Up @@ -120,7 +120,11 @@ def add_soft_intervention_function(


def add_domain_shift_function(
G: AugmentedGraph, node: Node, s_node: Node, func: Callable = None, distr_func: Callable = None
G: AugmentedGraph,
node: Node,
s_node: Node,
func: Optional[Callable] = None,
distr_func: Optional[Callable] = None,
):
"""Add domain shift function for a node into the graph assuming invariant graph structure.
Expand Down
6 changes: 3 additions & 3 deletions pywhy_graphs/functional/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def parent_func(*args):

def make_random_discrete_graph(
G: nx.DiGraph,
cardinality_lims: Dict[Any, List[int]] = None,
weight_lims: Dict[Any, List[int]] = None,
noise_ratio_lims: List[float] = None,
cardinality_lims: Optional[Dict[Any, List[int]]] = None,
weight_lims: Optional[Dict[Any, List[int]]] = None,
noise_ratio_lims: Optional[List[float]] = None,
overwrite: bool = False,
random_state=None,
) -> nx.DiGraph:
Expand Down
10 changes: 5 additions & 5 deletions pywhy_graphs/functional/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Set
from typing import Callable, List, Optional, Set

import networkx as nx
import numpy as np
Expand All @@ -11,10 +11,10 @@

def make_graph_linear_gaussian(
G: nx.DiGraph,
node_mean_lims: List[float] = None,
node_std_lims: List[float] = None,
edge_functions: List[Callable[[float], float]] = None,
edge_weight_lims: List[float] = None,
node_mean_lims: Optional[List[float]] = None,
node_std_lims: Optional[List[float]] = None,
edge_functions: Optional[List[Callable[[float], float]]] = None,
edge_weight_lims: Optional[List[float]] = None,
random_state=None,
) -> nx.DiGraph:
r"""Convert an existing DAG to a linear Gaussian graphical model.
Expand Down
4 changes: 2 additions & 2 deletions pywhy_graphs/functional/multidomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def make_graph_multidomain(
n_invariances_to_try: int = 1,
node_mean_lims: Optional[List[float]] = None,
node_std_lims: Optional[List[float]] = None,
edge_functions: List[Callable[[float], float]] = None,
edge_functions: Optional[List[Callable[[float], float]]] = None,
edge_weight_lims: Optional[List[float]] = None,
random_state=None,
) -> nx.DiGraph:
Expand Down Expand Up @@ -263,7 +263,7 @@ def sample_multidomain_lin_functions(
G: AugmentedGraph,
node_mean_lims: Optional[List[float]] = None,
node_std_lims: Optional[List[float]] = None,
edge_functions: List[Callable[[float], float]] = None,
edge_functions: Optional[List[Callable[[float], float]]] = None,
edge_weight_lims: Optional[List[float]] = None,
random_state=None,
):
Expand Down
3 changes: 2 additions & 1 deletion pywhy_graphs/functional/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from typing import Optional

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -217,7 +218,7 @@ def _preprocess_parameter_inputs(
edge_functions,
edge_weight_lims,
multi_domain: bool = False,
n_domains: int = None,
n_domains: Optional[int] = None,
):
"""Helper function to preprocess common parameter inputs for sampling functional graphs.
Expand Down
Loading

0 comments on commit e03591f

Please sign in to comment.