diff --git a/README.md b/README.md index 1820715..29cf3e8 100644 --- a/README.md +++ b/README.md @@ -127,14 +127,14 @@ map_reduce_sort_recursive.to_graph(partition_counts=4).to_dot().write_png("map_r ![Map-Reduce Sort Recursive](docs/_static/map_reduce_sort_recursive.png) -Use the `to_dask` method to convert the generated graph to a Dask task graph. +Use the `to_dict` method to convert the generated graph to a dict graph. ```python import numpy as np from distributed import Client with Client() as client: - client.get(map_reduce_sort.to_graph(partition_count=4).to_dask(array=np.random.rand(20)))[0] + client.get(map_reduce_sort.to_graph(partition_count=4).to_dict(array=np.random.rand(20)))[0] # [0.06253707 0.06795382 0.11492823 0.14512393 0.20183152 0.41109117 # 0.42613798 0.45156214 0.4714821 0.54000373 0.54902451 0.62671881 diff --git a/pargraph/about.py b/pargraph/about.py index fa3ddd8..3e2f46a 100644 --- a/pargraph/about.py +++ b/pargraph/about.py @@ -1 +1 @@ -__version__ = "0.8.4" +__version__ = "0.9.0" diff --git a/pargraph/engine/engine.py b/pargraph/engine/engine.py index 3bd1d4a..402fa39 100644 --- a/pargraph/engine/engine.py +++ b/pargraph/engine/engine.py @@ -39,11 +39,11 @@ def set_parallel_backend(self, backend: Backend) -> None: """ self.backend = backend - def get(self, dsk: Dict, keys: Any, **kwargs) -> Any: + def get(self, graph: Dict, keys: Any, **kwargs) -> Any: """ - Compute task graph + Compute dict graph - :param dsk: dask-compatible task graph + :param graph: dict graph :param keys: keys to compute (e.g. ``"x"``, ``["x", "y", "z"]``, etc) :param kwargs: keyword arguments to forward to the parallel backend :return: results in the same structure as keys @@ -51,11 +51,11 @@ def get(self, dsk: Dict, keys: Any, **kwargs) -> Any: keyset = set(self._flatten_iter([keys])) # cull graph to remove any unnecessary dependencies - graphlib_graph = self._cull_graph(self._convert_dsk_to_graph(dsk), keyset) + graphlib_graph = self._cull_graph(self._get_graph_dependencies(graph), keyset) ref_count_graph = self._create_ref_count_graph(graphlib_graph) - graph = TopologicalSorter(graphlib_graph) - graph.prepare() + topological_sorter = TopologicalSorter(graphlib_graph) + topological_sorter.prepare() results: Dict[Hashable, Any] = {} future_to_key: Dict[Future[Any], Hashable] = {} @@ -95,14 +95,14 @@ def wait_for_completed_futures(): future_to_key.pop(done_future, None) done_keys.append(key) - graph.done(*done_keys) + topological_sorter.done(*done_keys) for done_key in done_keys: dereference_key(done_key) # while there are still unscheduled tasks - while graph.is_active(): + while topological_sorter.is_active(): # get in vertices - in_keys = graph.get_ready() + in_keys = topological_sorter.get_ready() # if there are no in-vertices, wait for a future to resolve # IMPORTANT: we make the assumption that the graph is acyclic @@ -111,7 +111,7 @@ def wait_for_completed_futures(): continue for in_key in in_keys: - computation = dsk[in_key] + computation = graph[in_key] if self._is_submittable_function_computation(computation): future = self._submit_function_computation(computation, results, **kwargs) @@ -119,7 +119,7 @@ def wait_for_completed_futures(): else: result = self._evaluate_computation(computation, results) results[in_key] = result - graph.done(in_key) + topological_sorter.done(in_key) dereference_key(in_key) # resolve all pending futures @@ -183,8 +183,8 @@ def _evaluate_computation(cls, computation: Any, results: Dict) -> Optional[Any] return computation @staticmethod - def 
_convert_dsk_to_graph(dsk: Dict) -> Dict: - keys = set(dsk.keys()) + def _get_graph_dependencies(graph: Dict) -> Dict: + keys = set(graph.keys()) def flatten(value: Any) -> Set[Any]: # handle tasks as tuples @@ -209,7 +209,7 @@ def flatten(value: Any) -> Set[Any]: return set() - return {key: flatten(value) for key, value in dsk.items()} + return {key: flatten(value) for key, value in graph.items()} @staticmethod def _create_ref_count_graph(graph: Dict) -> Dict: diff --git a/pargraph/graph/objects.py b/pargraph/graph/objects.py index b8b0681..49dacd0 100644 --- a/pargraph/graph/objects.py +++ b/pargraph/graph/objects.py @@ -97,10 +97,10 @@ def __post_init__(self): assert isinstance(self.value, str), f"Value must be a string; got type '{type(self.value)}'" @staticmethod - def from_dict(data: Dict) -> "Const": + def from_json(data: Dict) -> "Const": return Const(**data) - def to_dict(self) -> Dict[str, Any]: + def to_json(self) -> Dict[str, Any]: return {"type": self.type, "value": self.value} @staticmethod @@ -234,7 +234,7 @@ def __post_init__(self): ), f"Arg '{arg}' must ConstKey, InputKey, or NodeOutputKey; got type '{type(arg)}'" @staticmethod - def from_dict(data: Dict) -> "FunctionCall": + def from_json(data: Dict) -> "FunctionCall": data = data.copy() function = data.pop("function") return FunctionCall( @@ -247,7 +247,7 @@ def from_dict(data: Dict) -> "FunctionCall": **data, ) - def to_dict(self) -> Dict[str, Any]: + def to_json(self) -> Dict[str, Any]: return { "function": ( base64.b64encode(cloudpickle.dumps(self.function)).decode("ascii") @@ -279,16 +279,16 @@ def __post_init__(self): ), f"Arg '{arg}' must ConstKey, InputKey, or NodeOutputKey; got type '{type(arg)}'" @staticmethod - def from_dict(data: Dict) -> "GraphCall": + def from_json(data: Dict) -> "GraphCall": data = data.copy() return GraphCall( - graph=Graph.from_dict(data.pop("graph")), + graph=Graph.from_json(data.pop("graph")), args={arg: _get_key_from_str(key_str) for arg, key_str in data.pop("args").items()}, **data, ) - def to_dict(self) -> Dict[str, Any]: - dct: dict = {"graph": self.graph.to_dict(), "args": {arg: key.to_str() for arg, key in self.args.items()}} + def to_json(self) -> Dict[str, Any]: + dct: dict = {"graph": self.graph.to_json(), "args": {arg: key.to_str() for arg, key in self.args.items()}} if self.graph_name is not None: dct["graph_name"] = self.graph_name return dct @@ -342,22 +342,22 @@ def __post_init__(self): ), f"Output '{output}' must be type '{ConstKey}', '{InputKey}', or '{NodeOutputKey}'" @staticmethod - def from_dict(data: Dict) -> "Graph": + def from_json(data: Dict) -> "Graph": """ - Create graph from graph dict by inferring the graph dict type + Create graph from json serializable dictionary by inferring the graph type :param data: graph dict :return: graph """ if "edges" in data: - return Graph.from_dict_with_edge_list(data) + return Graph.from_json_with_edge_list(data) - return Graph.from_dict_with_node_arguments(data) + return Graph.from_json_with_node_arguments(data) @staticmethod - def from_dict_with_edge_list(data: Dict) -> "Graph": + def from_json_with_edge_list(data: Dict) -> "Graph": """ - Create graph from graph dict with edge list + Create graph from json serializable dictionary with edge list :param data: graph dict with edge list :return: graph @@ -411,53 +411,53 @@ def from_dict_with_edge_list(data: Dict) -> "Graph": outputs[key] = new_output - return Graph.from_dict_with_node_arguments(data) + return Graph.from_json_with_node_arguments(data) @staticmethod - def 
from_dict_with_node_arguments(data: Dict) -> "Graph":
+    def from_json_with_node_arguments(data: Dict) -> "Graph":
         """
-        Create graph from graph dict with node arguments
+        Create graph from json serializable dictionary with node arguments
 
         :param data: graph dict with node arguments
         :return: graph
         """
 
-        def _graph_node_from_dict(data: Union[Dict, str]) -> Union[FunctionCall, "GraphCall"]:
+        def _graph_node_from_json(data: Union[Dict, str]) -> Union[FunctionCall, "GraphCall"]:
             if isinstance(data, dict) and "function" in data:
-                return FunctionCall.from_dict(data)
+                return FunctionCall.from_json(data)
             elif isinstance(data, dict) and "graph" in data:
-                return GraphCall.from_dict(data)
+                return GraphCall.from_json(data)
 
             raise ValueError(f"invalid graph node dict '{data}'")
 
         data = data.copy()
         return Graph(
-            consts={ConstKey(key=key): Const.from_dict(value) for key, value in data.pop("consts").items()},
+            consts={ConstKey(key=key): Const.from_json(value) for key, value in data.pop("consts").items()},
             inputs={
                 InputKey(key=key): cast(ConstKey, _get_key_from_str(value)) if value is not None else None
                 for key, value in data.pop("inputs").items()
             },
-            nodes={NodeKey(key=key): _graph_node_from_dict(value) for key, value in data.pop("nodes").items()},
+            nodes={NodeKey(key=key): _graph_node_from_json(value) for key, value in data.pop("nodes").items()},
             outputs={OutputKey(key=key): _get_key_from_str(value) for key, value in data.pop("outputs").items()},
             **data,
         )
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_json(self) -> Dict[str, Any]:
         """
-        Convert graph representation to serializable dictionary
+        Convert graph representation to json serializable dictionary
 
-        :return: graph dictionary
+        :return: json serializable dictionary
         """
         graph_dict: GraphDict = {"consts": {}, "inputs": {}, "nodes": {}, "edges": [], "outputs": {}}
 
         for const_node_key, const_node in self.consts.items():
-            graph_dict["consts"][const_node_key.key] = const_node.to_dict()
+            graph_dict["consts"][const_node_key.key] = const_node.to_json()
 
         for input_node_key, input_node in self.inputs.items():
             graph_dict["inputs"][input_node_key.key] = input_node.to_str() if input_node is not None else None
 
         for func_node_key, func_node in self.nodes.items():
-            func_node_dict = func_node.to_dict()
+            func_node_dict = func_node.to_json()
             func_node_dict.pop("args")
             graph_dict["nodes"][func_node_key.key] = func_node_dict
 
@@ -483,16 +483,50 @@ def to_dict(self) -> Dict[str, Any]:
 
         return cast(dict, graph_dict)
 
+    def to_dict(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
+        """
+        Convert graph to dict graph
+
+        Dict graph representation:
+
+        .. code-block:: python
+
+            {
+                "a": 1,
+                "b": 2,
+                "sum": (add, "a", "b")
+            }
+
+        Values can be:
+
+        - Tasks: represented as tuples with the format ``(fn, *args)``
+        - Constants: all other values
+
+        :param args: positional arguments
+        :param kwargs: keyword arguments
+        :return: dict graph and output keys
+        """
+        inputs: dict = {**dict(zip((key.key for key in self.inputs.keys()), args)), **kwargs}
+        return self._convert_graph_to_dict(inputs=inputs)
+
     def to_dask(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
         """
         Convert graph to dask graph
 
+        .. warning::
+
+            This method is deprecated and will be removed in a future release.
+            Please use :func:`to_dict` instead.
+ :param args: positional arguments :param kwargs: keyword arguments :return: dask graph and output keys """ - inputs: dict = {**dict(zip(self.inputs.keys(), args)), **kwargs} - return self._convert_graph_to_dask_graph(inputs=inputs) + warnings.warn( + "This method is deprecated and will be removed in a future release. Please use 'to_dict' instead.", + DeprecationWarning, + ) + return self.to_dict(*args, **kwargs) def to_dot( self, @@ -779,29 +813,29 @@ def _create_dot_edge(src: str, dst: str) -> pydot.Edge: return edge - def _convert_graph_to_dask_graph( + def _convert_graph_to_dict( self, inputs: Optional[Dict[str, Any]] = None, input_mapping: Optional[Dict[InputKey, str]] = None, output_mapping: Optional[Dict[OutputKey, str]] = None, ) -> Tuple[Dict[str, Any], List[str]]: """ - Convert our own graph format to a dask graph. + Convert our own graph format to a dict graph. :param inputs: inputs dictionary :param input_mapping: input mapping for subgraphs :param output_mapping: output mapping for subgraphs - :return: tuple containing dask graph and targets + :return: tuple containing dict graph and targets """ assert inputs is None or input_mapping is None, "cannot specify both inputs and input_mapping" - dask_graph: dict = {} + dict_graph: dict = {} key_to_uuid: dict = {} # create constants for const_key, const in self.consts.items(): graph_key = f"const_{self._get_const_label(const)}_{uuid.uuid4().hex}" - dask_graph[graph_key] = const.to_value() + dict_graph[graph_key] = const.to_value() key_to_uuid[const_key] = graph_key # create inputs @@ -809,7 +843,7 @@ def _convert_graph_to_dask_graph( for input_key in self.inputs.keys(): graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}" # if input key is not in inputs, use the default value - dask_graph[graph_key] = ( + dict_graph[graph_key] = ( inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value() ) key_to_uuid[input_key] = graph_key @@ -845,7 +879,7 @@ def _convert_graph_to_dask_graph( else: key_to_uuid[input_key] = key_to_uuid[const_path] - # build dask graph + # build dict graph for node_key, node in self.nodes.items(): if isinstance(node, FunctionCall): assert callable(node.function) @@ -862,7 +896,7 @@ def _convert_graph_to_dask_graph( # handle default arguments if param_name not in node.args: graph_key = f"const_{self._get_const_label(input_annotation.default)}_{uuid.uuid4().hex}" - dask_graph[graph_key] = input_annotation.default + dict_graph[graph_key] = input_annotation.default args.append(graph_key) continue @@ -884,10 +918,10 @@ def _convert_graph_to_dask_graph( break constant_key = f"const_{self._get_const_label(output_position)}_{uuid.uuid4().hex}" - dask_graph[constant_key] = output_position - dask_graph[graph_key] = (_unpack_tuple, node_uuid, constant_key) + dict_graph[constant_key] = output_position + dict_graph[graph_key] = (_unpack_tuple, node_uuid, constant_key) - dask_graph[node_uuid] = (node.function,) + tuple(args) + dict_graph[node_uuid] = (node.function,) + tuple(args) elif isinstance(node, GraphCall): new_input_mapping = { @@ -897,12 +931,12 @@ def _convert_graph_to_dask_graph( output_key: key_to_uuid[NodeOutputKey(key=node_key.key, output=output_key.key)] for output_key in node.graph.outputs } - dask_subgraph, _ = node.graph._convert_graph_to_dask_graph( + dict_subgraph, _ = node.graph._convert_graph_to_dict( input_mapping=new_input_mapping, output_mapping=new_output_mapping ) - dask_graph.update(dask_subgraph) + dict_graph.update(dict_subgraph) - return dask_graph, 
[key_to_uuid[output_path] for output_path in self.outputs.values()] + return dict_graph, [key_to_uuid[output_path] for output_path in self.outputs.values()] def _scramble_keys( self, old_to_new: Optional[bidict[Union[ConstKey, NodeKey], Union[ConstKey, NodeKey]]] = None diff --git a/tests/test_graph_generation.py b/tests/test_graph_generation.py index 85c57d6..d1c8614 100644 --- a/tests/test_graph_generation.py +++ b/tests/test_graph_generation.py @@ -30,7 +30,31 @@ def sample_graph(w: int, x: int, y: int, z: int) -> int: return add(add(w, x), add(y, z)) self.assertEqual( - self.engine.get(*sample_graph.to_graph().to_dask(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) + self.engine.get(*sample_graph.to_graph().to_dict(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) + ) + + def test_task_graph_positional_arguments(self): + @delayed + def add(x: int, y: int) -> int: + return x + y + + @graph + def sample_graph(w: int, x: int, y: int, z: int) -> int: + return add(add(w, x), add(y, z)) + + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dict(1, 2, 3, 4))[0], sample_graph(1, 2, 3, 4)) + + def test_task_graph_positional_and_keyword_arguments(self): + @delayed + def add(x: int, y: int) -> int: + return x + y + + @graph + def sample_graph(w: int, x: int, y: int, z: int) -> int: + return add(add(w, x), add(y, z)) + + self.assertEqual( + self.engine.get(*sample_graph.to_graph().to_dict(1, 2, y=3, z=4))[0], sample_graph(1, 2, y=3, z=4) ) def test_subgraph(self): @@ -47,7 +71,7 @@ def sample_graph(w: int, x: int, y: int, z: int) -> int: return sample_subgraph(sample_subgraph(w, x), sample_subgraph(y, z)) self.assertEqual( - self.engine.get(*sample_graph.to_graph().to_dask(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) + self.engine.get(*sample_graph.to_graph().to_dict(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) ) def test_basic_partial(self): @@ -60,7 +84,7 @@ def sample_graph(w: int, x: int, y: int, z: int) -> int: return add(add(w, x), add(y, z)) self.assertEqual( - self.engine.get(*sample_graph.to_graph(w=1, x=2).to_dask(y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) + self.engine.get(*sample_graph.to_graph(w=1, x=2).to_dict(y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) ) def test_variadic_arguments(self): @@ -72,7 +96,7 @@ def add(*args: int) -> int: def sample_graph(w: int, x: int, y: int, z: int) -> int: return add(w, x, y, z) - self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(w=1, x=2, y=3, z=4))[0], add(1, 2, 3, 4)) + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dict(w=1, x=2, y=3, z=4))[0], add(1, 2, 3, 4)) def test_operator_override(self): @graph @@ -80,7 +104,7 @@ def sample_graph(w: int, x: int, y: int, z: int) -> int: return (w + x) + (y + z) self.assertEqual( - self.engine.get(*sample_graph.to_graph().to_dask(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) + self.engine.get(*sample_graph.to_graph().to_dict(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4) ) def test_operator_override_complex(self): @@ -89,7 +113,7 @@ def fibonacci(n: int) -> int: phi = (1 + 5**0.5) / 2 return round(((phi**n) + ((1 - phi) ** n)) / 5**0.5) - self.assertEqual(self.engine.get(*fibonacci.to_graph().to_dask(n=6))[0], fibonacci(n=6)) + self.assertEqual(self.engine.get(*fibonacci.to_graph().to_dict(n=6))[0], fibonacci(n=6)) def test_getitem(self): @delayed @@ -100,7 +124,7 @@ def return_tuple(x: int, y: int) -> Any: def sample_graph(x: int, y: int) -> int: return return_tuple(x, y)[0] - 
self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=1, y=2))[0], sample_graph(x=1, y=2)) + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dict(x=1, y=2))[0], sample_graph(x=1, y=2)) @unittest.skipIf(not PANDAS_INSTALLED, "pandas must be installed") def test_call(self): @@ -111,7 +135,7 @@ def sample_graph(s: pd.Series) -> int: return s.sum() self.assertEqual( - self.engine.get(*sample_graph.to_graph().to_dask(s=pd.Series([1, 2, 3])))[0], + self.engine.get(*sample_graph.to_graph().to_dict(s=pd.Series([1, 2, 3])))[0], sample_graph(s=pd.Series([1, 2, 3])), ) @@ -124,7 +148,7 @@ def sample_graph(s: pd.Series) -> pd.Series: return s[s > s.mean()] pd.testing.assert_series_equal( - self.engine.get(*sample_graph.to_graph().to_dask(s=pd.Series([1, 2, 3])))[0], + self.engine.get(*sample_graph.to_graph().to_dict(s=pd.Series([1, 2, 3])))[0], sample_graph(s=pd.Series([1, 2, 3])), ) @@ -138,7 +162,7 @@ def sample_graph(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame: pd.testing.assert_frame_equal( self.engine.get( - *sample_graph.to_graph().to_dask( + *sample_graph.to_graph().to_dict( df1=pd.DataFrame({"a": ["foo", "bar"], "b": [1, 2]}), df2=pd.DataFrame({"a": ["foo", "baz"], "c": [3, 4]}), ) @@ -163,7 +187,7 @@ def sample_graph(w: int, x: int, y: int, z: int) -> int: ) self.assertEqual( - self.engine.get(*sample_graph.to_graph().explode_subgraphs().to_dask(w=1, x=2, y=3, z=4))[0], + self.engine.get(*sample_graph.to_graph().explode_subgraphs().to_dict(w=1, x=2, y=3, z=4))[0], sample_graph(w=1, x=2, y=3, z=4), ) @@ -177,12 +201,12 @@ def sample_graph(w: int, x: int, y: int, z: int) -> int: return add(add(w, x), add(y, z)) self.assertNotEqual( - json.dumps(sample_graph.to_graph().to_dict()), json.dumps(sample_graph.to_graph().to_dict()) + json.dumps(sample_graph.to_graph().to_json()), json.dumps(sample_graph.to_graph().to_json()) ) self.assertEqual( - json.dumps(sample_graph.to_graph().stabilize().to_dict()), - json.dumps(sample_graph.to_graph().stabilize().to_dict()), + json.dumps(sample_graph.to_graph().stabilize().to_json()), + json.dumps(sample_graph.to_graph().stabilize().to_json()), ) def test_valid_delayed_signature(self): @@ -271,25 +295,25 @@ def test_graph_default_argument(self): def sample_graph(x: int, y: int = 1) -> int: return x + y - self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2, y=3))[0], sample_graph(x=2, y=3)) + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dict(x=2, y=3))[0], sample_graph(x=2, y=3)) def test_graph_default_argument_missing(self): @graph def sample_graph(x: int, y: int = 1) -> int: return x + y - self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dask(x=2))[0], sample_graph(x=2)) + self.assertEqual(self.engine.get(*sample_graph.to_graph().to_dict(x=2))[0], sample_graph(x=2)) def test_function_default_argument(self): @graph def add(x: int, y: int = 1) -> int: return x + y - self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2, y=3))[0], add(x=2, y=3)) + self.assertEqual(self.engine.get(*add.to_graph().to_dict(x=2, y=3))[0], add(x=2, y=3)) def test_function_default_argument_missing(self): @graph def add(x: int, y: int = 1) -> int: return x + y - self.assertEqual(self.engine.get(*add.to_graph().to_dask(x=2))[0], add(x=2)) + self.assertEqual(self.engine.get(*add.to_graph().to_dict(x=2))[0], add(x=2))
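For illustration, the dict graph format documented in the new `Graph.to_dict` docstring is the same tuple-task mapping that the README example feeds to `distributed.Client.get`. The following is a minimal hand-built sketch rather than a graph produced by `to_dict`, with `operator.add` standing in for the docstring's illustrative `add`:

```python
import operator

from distributed import Client

# A hand-built graph in the format Graph.to_dict produces:
# tasks are tuples of (fn, *args); every other value is a constant.
dict_graph = {
    "a": 1,
    "b": 2,
    "sum": (operator.add, "a", "b"),
}

with Client() as client:
    # Graph.to_dict returns (dict_graph, output_keys), which the tests unpack
    # directly into the engine's get(); here the graph and the key are passed
    # explicitly instead.
    print(client.get(dict_graph, "sum"))  # 3
```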