diff --git a/asdf/search.py b/asdf/search.py index d8507c5d8..531183d45 100644 --- a/asdf/search.py +++ b/asdf/search.py @@ -6,14 +6,21 @@ import re import typing +from . import _itertree from ._display import DEFAULT_MAX_COLS, DEFAULT_MAX_ROWS, DEFAULT_SHOW_VALUES, format_faint, format_italic, render_tree from ._node_info import NodeSchemaInfo, collect_schema_info -from .treeutil import _get_children, _is_container +from .treeutil import _is_container from .util import NotSet __all__ = ["AsdfSearchResult"] +def _get_children(obj): + if hasattr(obj, "__asdf_traverse__"): + obj = obj.__asdf_traverse__() + return _itertree._default_get_children(obj) + + class AsdfSearchResult: """ Result of a call to AsdfFile.search. @@ -222,11 +229,9 @@ def replace(self, value): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): - results.append((identifiers[-1], parent)) - - _walk_tree_breadth_first(self._identifiers, self._node, _callback) + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): + results.append((edge.key, edge.parent.node)) for identifier, parent in results: parent[identifier] = value @@ -284,11 +289,10 @@ def nodes(self): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): results.append(node) - _walk_tree_breadth_first(self._identifiers, self._node, _callback) return results @property @@ -303,11 +307,10 @@ def paths(self): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): - results.append(_build_path(identifiers)) + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): + results.append(_build_path(self._identifiers + list(_itertree.edge_to_keys(edge)))) - _walk_tree_breadth_first(self._identifiers, self._node, _callback) return results def __repr__(self): @@ -369,31 +372,6 @@ def __getitem__(self, key): ) -def _walk_tree_breadth_first(root_identifiers, root_node, callback): - """ - Walk the tree in breadth-first order (useful for prioritizing - lower-depth nodes). - """ - current_nodes = [(root_identifiers, None, root_node)] - seen = set() - while True: - next_nodes = [] - - for identifiers, parent, node in current_nodes: - if (isinstance(node, (dict, list, tuple)) or NodeSchemaInfo.traversable(node)) and id(node) in seen: - continue - tnode = node.__asdf_traverse__() if NodeSchemaInfo.traversable(node) else node - children = _get_children(tnode) - callback(identifiers, parent, node, [c for _, c in children]) - next_nodes.extend([([*identifiers, i], node, c) for i, c in children]) - seen.add(id(node)) - - if len(next_nodes) == 0: - break - - current_nodes = next_nodes - - def _build_path(identifiers): """ Generate the Python code needed to extract the identified node.