diff --git a/src/cyclebane/graph.py b/src/cyclebane/graph.py index a2fbfb8..f242c86 100644 --- a/src/cyclebane/graph.py +++ b/src/cyclebane/graph.py @@ -195,8 +195,13 @@ def map( *, value_attr: str = 'value', ) -> Graph: - """For every value, create a new graph with all successors renamed, merge all - resulting graphs.""" + """ + Map the graph over the given values by associating source nodes with values. + + All successors of the mapped source nodes are replaced with new nodes, one for + each index value. The value is set as an attribute on the new source nodes + (but not their successors). + """ root_nodes = tuple(node_values.keys()) indices = self._get_indices(node_values) named = tuple(name for name, _ in indices if name is not None) @@ -231,7 +236,28 @@ def reduce( name: str, attrs: None | dict[str, Any] = None, ) -> Graph: - """Add edges from all nodes (key, index) to new node.""" + """ + Reduce over the given index or axis previously created with :py:meth:`map`. + ` + + If neither index nor axis is given, all axes are reduced. + + Parameters + ---------- + key: + The name of the source node to reduce. This is the original name prior to + mapping. Note that there is ambiguity if the same was used as 'name' in + a previous reduce operation over a subset of indices/axes. + index: + The name of the index to reduce over. Only one of index and axis can be + given. + axis: + The axis to reduce over. Only one of index and axis can be given. + name: + The name of the new node. + attrs: + Attributes to set on the new node. + """ attrs = attrs or {} if index is not None and axis is not None: raise ValueError('Only one of index and axis can be given') diff --git a/tests/graph_test.py b/tests/graph_test.py index 87741e2..6888474 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -10,6 +10,17 @@ import cyclebane as cb +def test_map_raises_if_mapping_nonexistent_node() -> None: + g = nx.DiGraph() + g.add_edge('a', 'b') + + graph = cb.Graph(g) + with pytest.raises(ValueError): + graph.map({'c': [1, 2]}) + with pytest.raises(ValueError): + graph.map({'a': [1, 2], 'c': [1, 2]}) + + def test_map_raises_if_mapping_non_source_node() -> None: g = nx.DiGraph() g.add_edge('a', 'b')