From 6136ae0f058fa484913926dc2dd2efcf78b7a3ff Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 21 Dec 2023 11:44:22 -0500 Subject: [PATCH 01/16] fix bug where a dict with an id causes contained refs to fail to resolve --- CHANGES.rst | 3 +++ asdf/_tests/_regtests/test_1715.py | 28 ++++++++++++++++++++++++++++ asdf/reference.py | 4 ++-- 3 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 asdf/_tests/_regtests/test_1715.py diff --git a/CHANGES.rst b/CHANGES.rst index 62a76753b..2222c346b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -21,6 +21,9 @@ The ASDF Standard is at v1.6.0 is removed in an upcoming asdf release will be ``False`` and asdf will no longer by-default memory map arrays. [#1667] +- Fix bug where a dictionary containing a key ``id`` caused + any contained references to fail to resolve [#1716] + 3.0.1 (2023-10-30) ------------------ diff --git a/asdf/_tests/_regtests/test_1715.py b/asdf/_tests/_regtests/test_1715.py new file mode 100644 index 000000000..18c140eaa --- /dev/null +++ b/asdf/_tests/_regtests/test_1715.py @@ -0,0 +1,28 @@ +import pytest + +import asdf + + +def test_id_in_tree_breaks_ref(tmp_path): + """ + a dict containing id will break contained References + + https://github.com/asdf-format/asdf/issues/1715 + """ + external_fn = tmp_path / "external.asdf" + + external_tree = {"thing": 42} + + asdf.AsdfFile(external_tree).write_to(external_fn) + + main_fn = tmp_path / "main.asdf" + + af = asdf.AsdfFile({}) + af["id"] = "bogus" + af["myref"] = {"$ref": "external.asdf#/thing"} + af.write_to(main_fn) + + with pytest.warns(asdf.exceptions.AsdfDeprecationWarning, match="find_references"): + with asdf.open(main_fn) as af: + af.resolve_references() + assert af["myref"] == 42 diff --git a/asdf/reference.py b/asdf/reference.py index b408653e0..73ace1998 100644 --- a/asdf/reference.py +++ b/asdf/reference.py @@ -112,11 +112,11 @@ def find_references(tree, ctx, _warning_msg=False): `Reference` objects. 
""" - def do_find(tree, json_id): + def do_find(tree): if isinstance(tree, dict) and "$ref" in tree: if _warning_msg: warnings.warn(_warning_msg, AsdfDeprecationWarning) - return Reference(tree["$ref"], json_id, asdffile=ctx) + return Reference(tree["$ref"], asdffile=ctx) return tree return treeutil.walk_and_modify(tree, do_find, ignore_implicit_conversion=ctx._ignore_implicit_conversion) From 4130fe8070c737008f92f8953f1df382a41fcd1d Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 21 Dec 2023 16:26:59 -0500 Subject: [PATCH 02/16] deprecate asdf.treeutil.is_container and get_children --- asdf/_node_info.py | 4 ++-- asdf/_tests/test_deprecated.py | 10 ++++++++++ asdf/_tests/test_treeutil.py | 14 +++++++------- asdf/search.py | 6 +++--- asdf/treeutil.py | 22 +++++++++++++++++++--- 5 files changed, 41 insertions(+), 15 deletions(-) diff --git a/asdf/_node_info.py b/asdf/_node_info.py index e61b22241..6b1acc16f 100644 --- a/asdf/_node_info.py +++ b/asdf/_node_info.py @@ -2,7 +2,7 @@ from collections import namedtuple from .schema import load_schema -from .treeutil import get_children +from .treeutil import _get_children def _filter_tree(info, filters): @@ -290,7 +290,7 @@ def from_root_node(cls, key, root_identifier, root_node, schema=None, refresh_ex if parent is None: info.schema = schema - for child_identifier, child_node in get_children(t_node): + for child_identifier, child_node in _get_children(t_node): next_nodes.append((info, child_identifier, child_node)) if len(next_nodes) == 0: diff --git a/asdf/_tests/test_deprecated.py b/asdf/_tests/test_deprecated.py index 0df6b1f7e..fb9eecf40 100644 --- a/asdf/_tests/test_deprecated.py +++ b/asdf/_tests/test_deprecated.py @@ -67,3 +67,13 @@ def test_find_references_during_open_deprecation(tmp_path): with pytest.warns(AsdfDeprecationWarning, match="find_references during open"): with asdf.open(fn) as af: pass + + +def test_get_children_deprecation(): + with pytest.warns(AsdfDeprecationWarning, match="get_children is deprecated"): + asdf.treeutil.get_children({}) + + +def test_is_container_deprecation(): + with pytest.warns(AsdfDeprecationWarning, match="is_container is deprecated"): + asdf.treeutil.is_container({}) diff --git a/asdf/_tests/test_treeutil.py b/asdf/_tests/test_treeutil.py index 5e9139ee8..147a12da8 100644 --- a/asdf/_tests/test_treeutil.py +++ b/asdf/_tests/test_treeutil.py @@ -3,27 +3,27 @@ def test_get_children(): parent = ["foo", "bar"] - assert treeutil.get_children(parent) == [(0, "foo"), (1, "bar")] + assert treeutil._get_children(parent) == [(0, "foo"), (1, "bar")] parent = ("foo", "bar") - assert treeutil.get_children(parent) == [(0, "foo"), (1, "bar")] + assert treeutil._get_children(parent) == [(0, "foo"), (1, "bar")] parent = {"foo": "bar", "ding": "dong"} - assert sorted(treeutil.get_children(parent)) == sorted([("foo", "bar"), ("ding", "dong")]) + assert sorted(treeutil._get_children(parent)) == sorted([("foo", "bar"), ("ding", "dong")]) parent = "foo" - assert treeutil.get_children(parent) == [] + assert treeutil._get_children(parent) == [] parent = None - assert treeutil.get_children(parent) == [] + assert treeutil._get_children(parent) == [] def test_is_container(): for value in [[], {}, ()]: - assert treeutil.is_container(value) is True + assert treeutil._is_container(value) is True for value in ["foo", 12, 13.9827]: - assert treeutil.is_container(value) is False + assert treeutil._is_container(value) is False def test_walk_and_modify_shared_references(): diff --git a/asdf/search.py b/asdf/search.py index 
08f0950a6..d8507c5d8 100644 --- a/asdf/search.py +++ b/asdf/search.py @@ -8,7 +8,7 @@ from ._display import DEFAULT_MAX_COLS, DEFAULT_MAX_ROWS, DEFAULT_SHOW_VALUES, format_faint, format_italic, render_tree from ._node_info import NodeSchemaInfo, collect_schema_info -from .treeutil import get_children, is_container +from .treeutil import _get_children, _is_container from .util import NotSet __all__ = ["AsdfSearchResult"] @@ -184,7 +184,7 @@ def _filter(node, identifier): return False if isinstance(value, typing.Pattern): - if is_container(node): + if _is_container(node): # The string representation of a container object tends to # include the child object values, but that's probably not # what searchers want. @@ -383,7 +383,7 @@ def _walk_tree_breadth_first(root_identifiers, root_node, callback): if (isinstance(node, (dict, list, tuple)) or NodeSchemaInfo.traversable(node)) and id(node) in seen: continue tnode = node.__asdf_traverse__() if NodeSchemaInfo.traversable(node) else node - children = get_children(tnode) + children = _get_children(tnode) callback(identifiers, parent, node, [c for _, c in children]) next_nodes.extend([([*identifiers, i], node, c) for i, c in children]) seen.add(id(node)) diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 27b2989cb..17018a257 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from . import tagged -from .exceptions import AsdfWarning +from .exceptions import AsdfDeprecationWarning, AsdfWarning __all__ = ["walk", "iter_tree", "walk_and_modify", "get_children", "is_container", "PendingValue", "RemoveNode"] @@ -420,7 +420,7 @@ def _recurse(node, json_id=None): # call to walk_and_modify. -def get_children(node): +def _get_children(node): """ Retrieve the children (and their dict keys or list/tuple indices) of an ASDF tree node. @@ -446,7 +446,15 @@ def get_children(node): return [] -def is_container(node): +def get_children(node): + warnings.warn("asdf.treeutil.get_children is deprecated", AsdfDeprecationWarning) + return _get_children(node) + + +get_children.__doc__ = _get_children.__doc__ + + +def _is_container(node): """ Determine if an ASDF tree node is an instance of a "container" type (i.e., value may contain child nodes). 
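
# Downstream code that used the now-deprecated public helpers can inline
# the equivalent logic instead of reaching for the private replacements.
# A minimal sketch (helper names here are illustrative, not part of asdf):
def get_node_children(node):
    # mirrors asdf.treeutil._get_children: (key, child) pairs, [] for leaves
    if isinstance(node, dict):
        return list(node.items())
    if isinstance(node, (list, tuple)):
        return list(enumerate(node))
    return []

def is_container_node(node):
    # mirrors asdf.treeutil._is_container
    return isinstance(node, (dict, list, tuple))
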
@@ -462,3 +470,11 @@ def is_container(node): True if node is a container, False otherwise """ return isinstance(node, (dict, list, tuple)) + + +def is_container(node): + warnings.warn("asdf.treeutil.is_container is deprecated", AsdfDeprecationWarning) + return _is_container(node) + + +is_container.__doc__ = _is_container.__doc__ From 1a81242f02e34cd7859e0ad95998dee5910e6844 Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 21 Dec 2023 17:13:14 -0500 Subject: [PATCH 03/16] deprecate support for json_id in walk_and_modify callbacks --- asdf/_tests/test_deprecated.py | 8 ++++++++ asdf/schema.py | 2 +- asdf/treeutil.py | 15 ++++++++++----- docs/asdf/deprecations.rst | 7 +++++++ 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/asdf/_tests/test_deprecated.py b/asdf/_tests/test_deprecated.py index fb9eecf40..608bb4da4 100644 --- a/asdf/_tests/test_deprecated.py +++ b/asdf/_tests/test_deprecated.py @@ -77,3 +77,11 @@ def test_get_children_deprecation(): def test_is_container_deprecation(): with pytest.warns(AsdfDeprecationWarning, match="is_container is deprecated"): asdf.treeutil.is_container({}) + + +def test_json_id_deprecation(): + def callback(obj, json_id): + return obj + + with pytest.warns(AsdfDeprecationWarning, match="the json_id callback argument is deprecated"): + asdf.treeutil.walk_and_modify({"a": 1}, callback) diff --git a/asdf/schema.py b/asdf/schema.py index c09bad792..e5af67121 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -488,7 +488,7 @@ def resolve_refs(node, json_id): return node - schema = treeutil.walk_and_modify(schema, resolve_refs) + schema = treeutil.walk_and_modify(schema, resolve_refs, _track_id=True) return schema diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 17018a257..975887985 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -220,7 +220,7 @@ def __repr__(self): RemoveNode = _RemoveNode() -def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _context=None): +def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _context=None, _track_id=False): """Modify a tree by walking it with a callback function. It also has the effect of doing a deep copy. @@ -234,14 +234,16 @@ def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=T one or two arguments: - an instance from the tree - - a json id (optional) + - a json id (optional) DEPRECATED It may return a different instance in order to modify the tree. If the singleton instance `~asdf.treeutil._RemoveNode` is returned, the node will be removed from the tree. - The json id is the context under which any relative URLs - should be resolved. It may be `None` if no ids are in the file + The json id optional argument is deprecated. This function + will no longer track ids. The json id is the context under which + any relative URLs should be resolved. It may be `None` if no + ids are in the file The tree is traversed depth-first, with order specified by the ``postorder`` argument. 
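
# A minimal sketch of the two callback forms, using only the public API
# (function names are illustrative): the one-argument form is supported,
# while a callback that also accepts json_id now triggers
# AsdfDeprecationWarning.
import asdf.treeutil

def strip_none(node):
    # supported: one-argument callback; returning RemoveNode drops the node
    return asdf.treeutil.RemoveNode if node is None else node

def with_json_id(node, json_id):
    # deprecated: the json_id argument will no longer be tracked
    return node

cleaned = asdf.treeutil.walk_and_modify({"a": None, "b": 1}, strip_none)
assert cleaned == {"b": 1}
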
@@ -266,7 +268,10 @@ def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=T """ callback_arity = callback.__code__.co_argcount - if callback_arity < 1 or callback_arity > 2: + if callback_arity == 2: + if not _track_id: + warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) + elif callback_arity != 1: msg = "Expected callback to accept one or two arguments" raise ValueError(msg) diff --git a/docs/asdf/deprecations.rst b/docs/asdf/deprecations.rst index 547ac9296..f975b2cec 100644 --- a/docs/asdf/deprecations.rst +++ b/docs/asdf/deprecations.rst @@ -20,6 +20,13 @@ Automatic calling of ``AsdfFile.find_references`` during calls to ``AsdfFile.__init__`` and ``asdf.open``. Call ``AsdfFile.find_references`` to find references. +``asdf.treeutil.get_children`` and ``asdf.treeutil.is_container`` are deprecated. +These were never intended to be public. + +The support for callbacks with a ``json_id`` to ``asdf.treeutil.walk_and_modify`` +is deprecated. Please use a different library or your own code to track ids during +tree modifications. + Version 3.0 =========== From 614bbfabe91ce954057d831141e68e63586116f7 Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 21 Dec 2023 17:20:35 -0500 Subject: [PATCH 04/16] add changelog --- CHANGES.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 2222c346b..9fd08601a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -24,6 +24,10 @@ The ASDF Standard is at v1.6.0 - Fix bug where a dictionary containing a key ``id`` caused any contained references to fail to resolve [#1716] +- Deprecate the following in ``asdf.treeutil``. + ``get_children``, ``is_container``, the ``json_id`` argument to + callbacks provided to ``walk_and_modify`` [#1719] + 3.0.1 (2023-10-30) ------------------ From fe48678e2e7139247cde9b86865988d90bb87601 Mon Sep 17 00:00:00 2001 From: Brett Date: Tue, 2 Jan 2024 12:02:12 -0500 Subject: [PATCH 05/16] walk_and_modify replacement --- asdf/_itertree.py | 441 ++++++++++++++++++++++++++++ asdf/_tests/test_array_blocks.py | 5 +- asdf/_tests/test_block_converter.py | 9 +- asdf/_tests/test_itertree.py | 376 ++++++++++++++++++++++++ asdf/_tests/test_schema.py | 2 + asdf/treeutil.py | 102 ++++--- asdf/yamlutil.py | 31 +- 7 files changed, 918 insertions(+), 48 deletions(-) create mode 100644 asdf/_itertree.py create mode 100644 asdf/_tests/test_itertree.py diff --git a/asdf/_itertree.py b/asdf/_itertree.py new file mode 100644 index 000000000..396bc06cf --- /dev/null +++ b/asdf/_itertree.py @@ -0,0 +1,441 @@ +""" +For modification, order is important +for a tree + + a + / \ +b c + / \ + d e + +when walking breadth-first down the tree, modify: +- a first +- b, c (any order) +- then d, e (any order) + +this means that a might get modified changing +where b and c come from + +when walking depth-first down the tree, modify: +- a first +- b or c +- if b, then c +- if c then d, e (any order) + +when walking leaf-first up the tree, modify: +- d, e (any order) +- c, b (any order) +- a +(note that this is the inverse of depth-first) + +""" +import collections + + +class _Edge: + __slots__ = ["parent", "key", "node"] + + def __init__(self, parent, key, node): + self.parent = parent + self.key = key # can be used to make path + self.node = node # can be used to get things like 'json_id', duplicate of obj in callback + + +class _RemoveNode: + """ + Class of the RemoveNode singleton instance. 
This instance is used + as a signal for `asdf.treeutil.walk_and_modify` to remove the + node received by the callback. + """ + + def __repr__(self): + return "RemoveNode" + + +RemoveNode = _RemoveNode() + + +def edge_to_keys(edge): + keys = [] + while edge.key is not None: + keys.append(edge.key) + edge = edge.parent + return tuple(keys[::-1]) + + +class _ShowValue: + __slots__ = ["obj", "obj_id"] + + def __init__(self, obj, obj_id): + self.obj = obj + self.obj_id = obj_id + + +def _default_get_children(obj): + if isinstance(obj, dict): + return obj.items() + elif isinstance(obj, (list, tuple)): + return enumerate(obj) + else: + return None + + +def breadth_first(d, get_children=None, skip_ids=None): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.popleft() + obj = edge.node + if id(obj) in skip_ids: + continue + yield obj, edge + children = get_children(obj) + if children: + skip_ids.add(id(obj)) + for key, value in children: + dq.append(_Edge(edge, key, value)) + + +def depth_first(d, get_children=None, skip_ids=None): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + obj = edge.node + obj_id = id(obj) + if obj_id in skip_ids: + continue + yield obj, edge + children = get_children(obj) + if children: + skip_ids.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + + +def leaf_first(d, get_children=None, skip_ids=None): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + if isinstance(edge, _ShowValue): + edge = edge.obj + obj = edge.node + yield obj, edge + continue + obj = edge.node + obj_id = id(obj) + if obj_id in skip_ids: + continue + children = get_children(obj) + if children: + skip_ids.add(obj_id) + dq.append(_ShowValue(edge, obj_id)) + for key, value in children: + dq.append(_Edge(edge, key, value)) + continue + yield obj, edge + + +def _default_setitem(obj, key, value): + obj.__setitem__(key, value) + + +def _default_delitem(obj, key): + if key in obj: + obj.__delitem__(key) + + +def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=None, skip_ids=None): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + cache = {} # TODO fix + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.popleft() + obj = edge.node + obj_id = id(obj) + if obj_id in skip_ids: + continue + if obj_id not in cache: + cache[obj_id] = callback(obj, edge) + obj = cache[obj_id] + if edge.parent is not None: + if obj is RemoveNode: + delitem(edge.parent.node, edge.key) + continue + setitem(edge.parent.node, edge.key, obj) + children = get_children(obj) + if children: + skip_ids.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + + +def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=None, skip_ids=None): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + cache = {} # TODO fix + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + obj = 
edge.node + obj_id = id(obj) + if obj_id in skip_ids: + continue + if obj_id not in cache: + cache[obj_id] = callback(obj, edge) + obj = cache[obj_id] + if edge.parent is not None: + if obj is RemoveNode: + delitem(edge.parent.node, edge.key) + continue + setitem(edge.parent.node, edge.key, obj) + children = get_children(obj) + if children: + skip_ids.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + + +def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None, skip_ids=None): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + cache = {} # TODO fix + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + if isinstance(edge, _ShowValue): + edge = edge.obj + obj = edge.node + obj_id = id(obj) + if obj_id not in cache: + cache[obj_id] = callback(obj, edge) + obj = cache[obj_id] + if edge.parent is not None: + if obj is RemoveNode: + delitem(edge.parent.node, edge.key) + else: + setitem(edge.parent.node, edge.key, obj) + continue + obj = edge.node + obj_id = id(obj) + if obj_id in skip_ids: + continue + children = get_children(obj) + if children: + skip_ids.add(obj_id) + dq.append(_ShowValue(edge, obj_id)) + for key, value in children: + dq.append(_Edge(edge, key, value)) + continue + + if obj_id not in cache: + cache[obj_id] = callback(obj, edge) + obj = cache[obj_id] + if edge.parent is not None: + if obj is RemoveNode: + delitem(edge.parent.node, edge.key) + else: + setitem(edge.parent.node, edge.key, obj) + + +def _default_container_factory(obj): + if isinstance(obj, dict): + return dict() + elif isinstance(obj, (list, tuple)): + return [None] * len(obj) + raise NotImplementedError() + + +def breadth_first_modify_and_copy( + d, callback, get_children=None, setitem=None, delitem=None, container_factory=None, skip_ids=None +): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + cache = {} # TODO fix + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + container_factory = container_factory or _default_container_factory + result = None + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.popleft() + obj = edge.node + obj_id = id(obj) + if obj_id in skip_ids: + continue + if False and obj_id in cache: + obj = cache[obj_id] + else: + obj = callback(obj, edge) + if edge.parent is not None and obj is RemoveNode: + # TODO handle multiple list key deletion + delitem(edge.parent.node, edge.key) + continue + children = get_children(obj) + if children: + obj = container_factory(obj) + edge.node = obj + skip_ids.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + # cache[obj_id] = obj + if result is None: + result = obj + if edge.parent is not None: + setitem(edge.parent.node, edge.key, obj) + return result + + +def depth_first_modify_and_copy( + d, callback, get_children=None, setitem=None, delitem=None, container_factory=None, skip_ids=None +): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + cache = {} + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + container_factory = container_factory or _default_container_factory + result = None + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + obj = edge.node + # print(obj) + obj_id = id(obj) + # if obj_id in skip_ids: + # 
#print(f"\tskip because of id {obj_id}") + # continue + if obj_id in cache: + # print("\tfrom cache") + new_obj = cache[obj_id][1] + else: + # print("\tstepping in") + new_obj = callback(obj, edge) + if edge.parent is not None and new_obj is RemoveNode: + # TODO handle multiple list key deletion + delitem(edge.parent.node, edge.key) + continue + children = get_children(new_obj) + if children: + container = container_factory(new_obj) + edge.node = container + # print(f"\tadding id {obj_id} to skips") + # skip_ids.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + new_obj = container + cache[obj_id] = (obj, new_obj) + if result is None: + result = new_obj + if edge.parent is not None: + setitem(edge.parent.node, edge.key, new_obj) + # print(result) + return result + + +def leaf_first_modify_and_copy( + d, callback, get_children=None, setitem=None, delitem=None, container_factory=None, skip_ids=None +): + get_children = get_children or _default_get_children + skip_ids = skip_ids or set() + pending = {} + cache = {} + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + container_factory = container_factory or _default_container_factory + result = None + dq = collections.deque() + # print(f"Input obj id: {id(d['obj'])}") + # print(f"Input inverse id: {id(d['obj']['inverse'])}") + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + # print(f"Processing {edge}") + # print(f"\tdq={dq}") + # print(f"\tcache={cache}") + if isinstance(edge, _ShowValue): + obj_id = edge.obj_id + edge = edge.obj + obj = edge.node + if obj is result: + result = None + if obj_id not in cache: + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + else: + new_obj = cache[obj_id][1] + if result is None: + result = new_obj + if edge.parent is not None: + if new_obj is RemoveNode: + # TODO handle multiple list key deletion + delitem(edge.parent.node, edge.key) + else: + setitem(edge.parent.node, edge.key, new_obj) + if obj_id in pending: + for edge in pending[obj_id]: + if new_obj is RemoveNode: + # TODO handle multiple list key deletion + delitem(edge.parent.node, edge.key) + else: + setitem(edge.parent.node, edge.key, new_obj) + del pending[obj_id] + continue + # print(f"\tNode id {id(edge.node)} at {edge.key} of {edge.parent}") + obj = edge.node + obj_id = id(obj) + # if obj_id in skip_ids: + # print("\tskipping") + # continue + if obj_id in cache: + # print("\tfrom cache") + new_obj = cache[obj_id][1] + else: + children = get_children(obj) + if children: + skip_ids.add(obj_id) + container = container_factory(obj) + pending[obj_id] = [] + if result is None: + result = container + edge.node = container + dq.append(_ShowValue(edge, obj_id)) + for key, value in children: + if id(value) in pending: + pending[id(value)].append(_Edge(edge, key, value)) + else: + dq.append(_Edge(edge, key, value)) + # if id(value) not in pending: + continue + # cache[obj_id] = callback(obj, edge) + # obj = cache[obj_id] + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + if result is None: + result = new_obj + if edge.parent is not None: + if new_obj is RemoveNode: + # TODO handle multiple list key deletion + delitem(edge.parent.node, edge.key) + else: + setitem(edge.parent.node, edge.key, new_obj) + return result diff --git a/asdf/_tests/test_array_blocks.py b/asdf/_tests/test_array_blocks.py index b6163ce08..1f79d7a2c 100644 --- a/asdf/_tests/test_array_blocks.py +++ b/asdf/_tests/test_array_blocks.py @@ -589,10 +589,11 @@ def test_block_index(): 
assert ff2._blocks.blocks[99].loaded # Force the loading of one array - ff2.tree["arrays"][50] * 2 + arr = ff2.tree["arrays"][50] + arr * 2 for i in range(2, 99): - if i == 50: + if i == arr._source: assert ff2._blocks.blocks[i].loaded else: assert not ff2._blocks.blocks[i].loaded diff --git a/asdf/_tests/test_block_converter.py b/asdf/_tests/test_block_converter.py index 462f7c363..8e5ce688c 100644 --- a/asdf/_tests/test_block_converter.py +++ b/asdf/_tests/test_block_converter.py @@ -1,4 +1,5 @@ import contextlib +from io import BytesIO import numpy as np from numpy.testing import assert_array_equal @@ -125,8 +126,12 @@ def test_block_data_callback_converter(tmp_path): # id(arr) would change every time a = BlockDataCallback(lambda: np.zeros(3, dtype="uint8")) - b = helpers.roundtrip_object(a) - assert_array_equal(a.data, b.data) + bs = BytesIO() + af = asdf.AsdfFile({"obj": a}) + af.write_to(bs) + bs.seek(0) + with asdf.open(bs, lazy_load=False, memmap=False) as af: + assert_array_equal(a.data, af["obj"].data) # make a tree without the BlockData instance to avoid # the initial validate which will trigger block allocation diff --git a/asdf/_tests/test_itertree.py b/asdf/_tests/test_itertree.py new file mode 100644 index 000000000..e5995150b --- /dev/null +++ b/asdf/_tests/test_itertree.py @@ -0,0 +1,376 @@ +""" +Reorganize these tests into a fixture that generates test trees +Each test tree returns: + - the tree + - the breadth-first order (need to account for multiple valid paths) + - the depth-first order (need to account for multiple valid paths) +The orderings above need to be reversible to allow postorder tests +to also check the path. + +For modification tests the callbacks can also check the order +and can modify the tree in a way that should impact later callbacks. +""" +import copy + +from asdf import _itertree + + +def test_breadth_first_traversal(): + tree = { + "a": { + "b": [1, 2, {"c": 3}], + "d": 4, + }, + "e": [5, 6, [7, 8, {"f": 9}]], + } + # It is ok for results to come in any order as long as + # all nodes closer to the root come before any more distant + # node. 
Track those here by ordering expected results by 'layer' + expected_results = [ + [ + tree, + ], + [tree["a"], tree["e"]], + [tree["a"]["b"], tree["a"]["d"], tree["e"][0], tree["e"][1], tree["e"][2]], + [tree["a"]["b"][0], tree["a"]["b"][1], tree["a"]["b"][2], tree["e"][2][0], tree["e"][2][1], tree["e"][2][2]], + [tree["a"]["b"][2]["c"], tree["e"][2][2]["f"]], + ] + + expected = [] + for node, edge in _itertree.breadth_first(tree): + if not len(expected): + expected = expected_results.pop(0) + assert node in expected + expected.remove(node) + assert not expected_results + + +def test_recursive_breadth_first_traversal(): + tree = { + "a": {}, + "b": {}, + } + tree["a"]["b"] = tree["b"] + tree["b"]["a"] = tree["a"] + + expected_results = [ + [ + tree, + ], + [tree["a"], tree["b"]], + ] + + expected = [] + for node, edge in _itertree.breadth_first(tree): + if not len(expected): + expected = expected_results.pop(0) + assert node in expected + expected.remove(node) + assert not expected_results + + +def test_leaf_first_traversal(): + tree = { + "a": { + "b": [1, 2, {"c": 3}], + "d": 4, + }, + "e": [5, 6, [7, 8, {"f": 9}]], + } + seen_keys = set() + reverse_paths = { + ("e", 2, 2, "f"): [("e", 2, 2), ("e", 2, 1), ("e", 2, 0)], + ("e", 2, 2): [("e", 0), ("e", 1), ("e", 2)], + ("e", 2, 1): [("e", 0), ("e", 1), ("e", 2)], + ("e", 2, 0): [("e", 0), ("e", 1), ("e", 2)], + ("e", 2): [("e",), ("a",)], + ("e", 1): [("e",), ("a",)], + ("e", 0): [("e",), ("a",)], + ("e",): [()], + ("a", "b", 2, "c"): [("a", "b", 0), ("a", "b", 1), ("a", "b", 2)], + ("a", "b", 2): [("a", "b"), ("a", "d")], + ("a", "b", 1): [("a", "b"), ("a", "d")], + ("a", "b", 0): [("a", "b"), ("a", "d")], + ("a", "b"): [("a",), ("e",)], + ("a", "d"): [("a",), ("e",)], + ("a",): [()], + (): [], + } + expected = { + ("e", 2, 2, "f"), + ("a", "b", 2, "c"), + ("a", "d"), + } + for node, edge in _itertree.leaf_first(tree): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + assert obj is node + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in reverse_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert not expected + + +def test_recursive_leaf_first_traversal(): + tree = { + "a": {}, + "b": {}, + } + tree["a"]["b"] = tree["b"] + tree["b"]["a"] = tree["a"] + + seen_keys = set() + visit_ids = { + id(tree), + id(tree["a"]), + id(tree["b"]), + } + reverse_paths = { + ("a", "b"): [("a",), ("b",)], + ("b", "a"): [("a",), ("b",)], + ("a",): [()], + ("b",): [()], + (): [], + } + expected = { + ("a", "b"), + ("b", "a"), + } + for node, edge in _itertree.leaf_first(tree): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + assert obj is node + visit_ids.remove(id(node)) + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in reverse_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert not visit_ids + + +def test_depth_first_traversal(): + tree = { + "a": { + "b": [1, 2, {"c": 3}], + "d": 4, + }, + "e": [5, 6, [7, 8, {"f": 9}]], + } + forward_paths = { + (): [("a",), ("e",)], + ("a",): [("a", "b"), ("a", "d")], + ("a", "b"): [("a", "b", 0), ("a", "b", 1), ("a", "b", 2)], + ("a", "b", 0): [], + ("a", "b", 1): [], + ("a", "b", 2): [("a", "b", 2, "c")], + ("a", "b", 2, "c"): [], + ("a", "d"): [], + ("e",): [("e", 0), ("e", 1), ("e", 2)], + ("e", 0): [], + ("e", 1): [], + ("e", 2): [("e", 2, 0), ("e", 
2, 1), ("e", 2, 2)], + ("e", 2, 0): [], + ("e", 2, 1): [], + ("e", 2, 2): [("e", 2, 2, "f")], + ("e", 2, 2, "f"): [], + } + expected = {()} + seen_keys = set() + + for node, edge in _itertree.depth_first(tree): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + assert obj is node + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in forward_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert not expected + + +def test_recursive_depth_first_traversal(): + tree = { + "a": {}, + "b": {}, + } + tree["a"]["b"] = tree["b"] + tree["b"]["a"] = tree["a"] + + seen_keys = set() + visit_ids = { + id(tree), + id(tree["a"]), + id(tree["b"]), + } + forward_paths = { + (): [("a",), ("b",)], + ("a",): [("a", "b")], + ("b",): [("b", "a")], + ("a", "b"): [], + ("b", "a"): [], + } + expected = { + (), + } + for node, edge in _itertree.depth_first(tree): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + assert obj is node + visit_ids.remove(id(node)) + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in forward_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert not visit_ids + + +def test_breadth_first_modify(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + return [1, 2, 3] + if isinstance(obj, dict): + assert "b" not in obj + return obj + + _itertree.breadth_first_modify(tree, callback) + assert tree["a"] == [1, 2, 3] + + +def test_depth_first_modify(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, dict) and "d" in obj: + obj["d"] = [42] + if isinstance(obj, list) and 42 in obj: + assert len(obj) == 1 + return obj + + _itertree.depth_first_modify(tree, callback) + assert tree["c"]["d"] == [42] + + +def test_leaf_first_modify(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + _itertree.leaf_first_modify(tree, callback) + assert tree["a"] == [1, 2, 42] + + +def test_breadth_first_modify_and_copy(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 not in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + # copy the tree to make sure it's not modified + copied_tree = copy.deepcopy(tree) + result = _itertree.breadth_first_modify_and_copy(copied_tree, callback) + assert result["a"] == [1, 2, 42] + assert copied_tree == tree + + +def test_depth_first_modify_and_copy(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 not in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + # copy the tree to make sure it's not modified + copied_tree = copy.deepcopy(tree) + result = _itertree.depth_first_modify_and_copy(copied_tree, callback) + assert result["a"] == [1, 2, 42] + assert copied_tree == tree + + +def test_leaf_first_modify_and_copy(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if 
isinstance(obj, list) and 1 in obj: + assert 42 in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + # copy the tree to make sure it's not modified + copied_tree = copy.deepcopy(tree) + result = _itertree.leaf_first_modify_and_copy(copied_tree, callback) + assert result["a"] == [1, 2, 42] + assert copied_tree == tree diff --git a/asdf/_tests/test_schema.py b/asdf/_tests/test_schema.py index 82cea9238..996ed8fbe 100644 --- a/asdf/_tests/test_schema.py +++ b/asdf/_tests/test_schema.py @@ -1216,6 +1216,8 @@ def test_validator_visit_repeat_nodes(): node = asdf.tags.core.Software(name="Minesweeper") tree = yamlutil.custom_tree_to_tagged_tree({"node": node, "other_node": node, "nested": {"node": node}}, ctx) + assert tree["node"] is tree["other_node"] + assert tree["node"] is tree["nested"]["node"] visited_nodes = [] def _test_validator(validator, value, instance, schema): diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 975887985..a951094af 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -6,7 +6,7 @@ import warnings from contextlib import contextmanager -from . import tagged +from . import _itertree, tagged from .exceptions import AsdfDeprecationWarning, AsdfWarning __all__ = ["walk", "iter_tree", "walk_and_modify", "get_children", "is_container", "PendingValue", "RemoveNode"] @@ -57,28 +57,8 @@ def iter_tree(top): tree : object The modified tree. """ - seen = set() - - def recurse(tree): - tree_id = id(tree) - - if tree_id in seen: - return - - if isinstance(tree, (list, tuple)): - seen.add(tree_id) - for val in tree: - yield from recurse(val) - seen.remove(tree_id) - elif isinstance(tree, dict): - seen.add(tree_id) - for val in tree.values(): - yield from recurse(val) - seen.remove(tree_id) - - yield tree - - return recurse(top) + for node, edge in _itertree.depth_first(top): + yield node class _TreeModificationContext: @@ -206,21 +186,77 @@ def __repr__(self): PendingValue = _PendingValue() -class _RemoveNode: - """ - Class of the RemoveNode singleton instance. This instance is used - as a signal for `asdf.treeutil.walk_and_modify` to remove the - node received by the callback. 
- """ +RemoveNode = _itertree.RemoveNode + + +def _get_json_id(top, edge): + keys = [] + while edge and edge.key is not None: + keys.append(edge.key) + edge = edge.parent + json_id = None + node = top + for key in keys[::-1]: + if hasattr(node, "get") and isinstance(node.get("id", None), str): + json_id = node["id"] + node = node[key] + return json_id + + +def _container_factory(obj): + if isinstance(obj, tagged.TaggedDict): + result = tagged.TaggedDict() + result._tag = obj._tag + elif isinstance(obj, tagged.TaggedList): + result = tagged.TaggedList([None] * len(obj)) + result._tag = obj._tag + elif isinstance(obj, dict): + result = obj.__class__() + elif isinstance(obj, list): + result = obj.__class__([None] * len(obj)) + elif isinstance(obj, tuple): + result = [None] * len(obj) + else: + raise NotImplementedError() + return result - def __repr__(self): - return "RemoveNode" +def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _context=None, _track_id=False): + if postorder: + modify = _itertree.leaf_first_modify_and_copy + else: + modify = _itertree.depth_first_modify_and_copy -RemoveNode = _RemoveNode() + if callback.__code__.co_argcount == 2 and not _track_id: + _track_id = True + warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) + if _track_id: -def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _context=None, _track_id=False): + def wrapped_callback(obj, edge): + json_id = _get_json_id(top, edge) + return callback(obj, json_id) + + else: + + def wrapped_callback(obj, edge): + return callback(obj) + + if ignore_implicit_conversion: + container_factory = _container_factory + else: + + def container_factory(obj): + if isinstance(obj, tuple) and type(obj) != tuple: + warnings.warn(f"Failed to serialize instance of {type(obj)}, converting to list instead", AsdfWarning) + return _container_factory(obj) + + return modify(top, wrapped_callback, container_factory=container_factory) + + +def old_walk_and_modify( + top, callback, ignore_implicit_conversion=False, postorder=True, _context=None, _track_id=False +): """Modify a tree by walking it with a callback function. It also has the effect of doing a deep copy. 
diff --git a/asdf/yamlutil.py b/asdf/yamlutil.py index e4b1d5049..9021a051c 100644 --- a/asdf/yamlutil.py +++ b/asdf/yamlutil.py @@ -1,5 +1,5 @@ import warnings -from collections import OrderedDict +from collections import OrderedDict, deque from types import GeneratorType import numpy as np @@ -221,6 +221,8 @@ def custom_tree_to_tagged_tree(tree, ctx, _serialization_context=None): extension_manager = _serialization_context.extension_manager + generators = deque() + def _convert_obj(obj, converter): tag = converter.select_tag(obj, _serialization_context) # if select_tag returns None, converter.to_yaml_tree should return a new @@ -233,8 +235,7 @@ def _convert_obj(obj, converter): converter = extension_manager.get_converter_for_type(type(obj)) except KeyError: # no converter supports this type, return it as-is - yield obj - return + return obj if converter in converters_used: msg = "Conversion cycle detected" raise TypeError(msg) @@ -244,10 +245,8 @@ def _convert_obj(obj, converter): _serialization_context.assign_blocks() if isinstance(node, GeneratorType): - generator = node + generators.append(generator) node = next(generator) - else: - generator = None if isinstance(node, dict): tagged_node = tagged.TaggedDict(node, tag) @@ -262,9 +261,7 @@ def _convert_obj(obj, converter): _serialization_context._mark_extension_used(converter.extension) - yield tagged_node - if generator is not None: - yield from generator + return tagged_node cfg = config.get_config() convert_ndarray_subclasses = cfg.convert_unknown_ndarray_subclasses @@ -292,7 +289,7 @@ def _walker(obj): converters_cache[typ] = lambda obj: obj return obj - return treeutil.walk_and_modify( + new_tree = treeutil.walk_and_modify( tree, _walker, ignore_implicit_conversion=ctx._ignore_implicit_conversion, @@ -301,6 +298,10 @@ def _walker(obj): postorder=False, _context=ctx._tree_modification_context, ) + for generator in generators: + for _ in generator: + pass + return new_tree def tagged_tree_to_custom_tree(tree, ctx, force_raw_types=False, _serialization_context=None): @@ -312,6 +313,7 @@ def tagged_tree_to_custom_tree(tree, ctx, force_raw_types=False, _serialization_ _serialization_context = ctx._create_serialization_context(BlockAccess.READ) extension_manager = _serialization_context.extension_manager + generators = deque() def _walker(node): if force_raw_types: @@ -327,6 +329,9 @@ def _walker(node): _serialization_context.assign_object(obj) _serialization_context.assign_blocks() _serialization_context._mark_extension_used(converter.extension) + if isinstance(obj, GeneratorType): + generators.append(obj) + obj = next(obj) return obj if not ctx._ignore_unrecognized_tag: @@ -336,7 +341,7 @@ def _walker(node): ) return node - return treeutil.walk_and_modify( + new_tree = treeutil.walk_and_modify( tree, _walker, ignore_implicit_conversion=ctx._ignore_implicit_conversion, @@ -345,6 +350,10 @@ def _walker(node): postorder=True, _context=ctx._tree_modification_context, ) + for generator in generators: + for _ in generator: + pass + return new_tree def load_tree(stream): From 3970c987ada13321813047313ee86e3c397c2705 Mon Sep 17 00:00:00 2001 From: Brett Date: Tue, 2 Jan 2024 14:53:13 -0500 Subject: [PATCH 06/16] init dict keys to retain order --- asdf/_itertree.py | 3 ++- asdf/treeutil.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/asdf/_itertree.py b/asdf/_itertree.py index 396bc06cf..fdbfecb3e 100644 --- a/asdf/_itertree.py +++ b/asdf/_itertree.py @@ -257,7 +257,8 @@ def leaf_first_modify(d, callback, 
get_children=None, setitem=None, delitem=None def _default_container_factory(obj): if isinstance(obj, dict): - return dict() + # init with keys to retain order + return dict({k: None for k in obj}) elif isinstance(obj, (list, tuple)): return [None] * len(obj) raise NotImplementedError() diff --git a/asdf/treeutil.py b/asdf/treeutil.py index a951094af..fbbbf9749 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -205,13 +205,13 @@ def _get_json_id(top, edge): def _container_factory(obj): if isinstance(obj, tagged.TaggedDict): - result = tagged.TaggedDict() + result = tagged.TaggedDict({k: None for k in obj}) result._tag = obj._tag elif isinstance(obj, tagged.TaggedList): result = tagged.TaggedList([None] * len(obj)) result._tag = obj._tag elif isinstance(obj, dict): - result = obj.__class__() + result = obj.__class__({k: None for k in obj}) elif isinstance(obj, list): result = obj.__class__([None] * len(obj)) elif isinstance(obj, tuple): From 699f11679dd58a7e8b5881554cc18d9182edb805 Mon Sep 17 00:00:00 2001 From: Brett Date: Tue, 2 Jan 2024 15:18:35 -0500 Subject: [PATCH 07/16] use top id for json_id --- asdf/treeutil.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asdf/treeutil.py b/asdf/treeutil.py index fbbbf9749..424145def 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -194,8 +194,11 @@ def _get_json_id(top, edge): while edge and edge.key is not None: keys.append(edge.key) edge = edge.parent - json_id = None node = top + if hasattr(node, "get") and isinstance(node.get("id", None), str): + json_id = node["id"] + else: + json_id = None for key in keys[::-1]: if hasattr(node, "get") and isinstance(node.get("id", None), str): json_id = node["id"] From c1253796a342edeacaa658f613fcc10a597dfa72 Mon Sep 17 00:00:00 2001 From: Brett Date: Tue, 2 Jan 2024 16:33:16 -0500 Subject: [PATCH 08/16] remove old walk_and_modify --- asdf/_asdf.py | 7 +- asdf/treeutil.py | 337 +++++------------------------------------------ asdf/yamlutil.py | 2 - 3 files changed, 33 insertions(+), 313 deletions(-) diff --git a/asdf/_asdf.py b/asdf/_asdf.py index 69da27f9e..c44dfc31e 100644 --- a/asdf/_asdf.py +++ b/asdf/_asdf.py @@ -12,7 +12,7 @@ from . import _display as display from . import _node_info as node_info from . import _version as version -from . import constants, generic_io, reference, schema, treeutil, util, versioning, yamlutil +from . import constants, generic_io, reference, schema, util, versioning, yamlutil from ._block.manager import Manager as BlockManager from ._helpers import validate_version from .config import config_context, get_config @@ -169,11 +169,6 @@ def __init__( self._file_format_version = None - # Context of a call to treeutil.walk_and_modify, needed in the AsdfFile - # in case walk_and_modify is re-entered by extension code (via - # custom_tree_to_tagged_tree or tagged_tree_to_custom_tree). - self._tree_modification_context = treeutil._TreeModificationContext() - self._fd = None self._closed = False self._external_asdf_by_uri = {} diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 424145def..92f7a5421 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -2,9 +2,7 @@ Utility functions for managing tree-like data structures. """ -import types import warnings -from contextlib import contextmanager from . 
import _itertree, tagged from .exceptions import AsdfDeprecationWarning, AsdfWarning @@ -61,117 +59,6 @@ def iter_tree(top): yield node -class _TreeModificationContext: - """ - Context of a call to walk_and_modify, which includes a map - of already modified nodes, a list of generators to drain - before exiting the call, and a set of node object ids that - are currently pending modification. - - Instances of this class are context managers that track - how many times they have been entered, and only drain - generators and reset themselves when exiting the outermost - context. They are also collections that map unmodified - nodes to the corresponding modified result. - """ - - def __init__(self): - self._map = {} - self._generators = [] - self._depth = 0 - self._pending = set() - - def add_generator(self, generator): - """ - Add a generator that should be drained before exiting - the outermost call to walk_and_modify. - """ - self._generators.append(generator) - - def is_pending(self, node): - """ - Return True if the node is already being modified. - This will not be the case unless the node contains a - reference to itself somewhere among its descendents. - """ - return id(node) in self._pending - - @contextmanager - def pending(self, node): - """ - Context manager that marks a node as pending for the - duration of the context. - """ - if id(node) in self._pending: - msg = ( - "Unhandled cycle in tree. This is possibly a bug " - "in extension code, which should be yielding " - "nodes that may contain reference cycles." - ) - raise RuntimeError(msg) - - self._pending.add(id(node)) - try: - yield self - finally: - self._pending.remove(id(node)) - - def __enter__(self): - self._depth += 1 - return self - - def __exit__(self, exc_type, exc_value, traceback): - self._depth -= 1 - - if self._depth == 0: - # If we're back to 0 depth, then we're exiting - # the outermost context, so it's time to drain - # the generators and reset this object for next - # time. - if exc_type is None: - self._drain_generators() - self._generators = [] - self._map = {} - self._pending = set() - - def _drain_generators(self): - """ - Drain each generator we've accumulated during this - call to walk_and_modify. - """ - # Generator code may add yet more generators - # to the list, so we need to loop until the - # list is empty. - while len(self._generators) > 0: - generators = self._generators - self._generators = [] - for generator in generators: - for _ in generator: - # Subsequent yields of the generator should - # always return the same value. What we're - # really doing here is executing the generator's - # remaining code, to further modify that first - # yielded object. - pass - - def __contains__(self, node): - return id(node) in self._map - - def __getitem__(self, node): - return self._map[id(node)][1] - - def __setitem__(self, node, result): - if id(node) in self._map: - # This indicates that an already defined - # modified node is being replaced, which is an - # error because it breaks references within the - # tree. - msg = "Node already has an associated result" - raise RuntimeError(msg) - - self._map[id(node)] = (node, result) - - class _PendingValue: """ Class of the PendingValue singleton instance. 
The presence of the instance @@ -224,42 +111,7 @@ def _container_factory(obj): return result -def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _context=None, _track_id=False): - if postorder: - modify = _itertree.leaf_first_modify_and_copy - else: - modify = _itertree.depth_first_modify_and_copy - - if callback.__code__.co_argcount == 2 and not _track_id: - _track_id = True - warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) - - if _track_id: - - def wrapped_callback(obj, edge): - json_id = _get_json_id(top, edge) - return callback(obj, json_id) - - else: - - def wrapped_callback(obj, edge): - return callback(obj) - - if ignore_implicit_conversion: - container_factory = _container_factory - else: - - def container_factory(obj): - if isinstance(obj, tuple) and type(obj) != tuple: - warnings.warn(f"Failed to serialize instance of {type(obj)}, converting to list instead", AsdfWarning) - return _container_factory(obj) - - return modify(top, wrapped_callback, container_factory=container_factory) - - -def old_walk_and_modify( - top, callback, ignore_implicit_conversion=False, postorder=True, _context=None, _track_id=False -): +def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _track_id=False): """Modify a tree by walking it with a callback function. It also has the effect of doing a deep copy. @@ -306,162 +158,37 @@ def old_walk_and_modify( The modified tree. """ - callback_arity = callback.__code__.co_argcount - if callback_arity == 2: - if not _track_id: - warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) - elif callback_arity != 1: - msg = "Expected callback to accept one or two arguments" - raise ValueError(msg) - - def _handle_generator(result): - # If the result is a generator, generate one value to - # extract the true result, then register the generator - # to be drained later. - if isinstance(result, types.GeneratorType): - generator = result - result = next(generator) - _context.add_generator(generator) - - return result - - def _handle_callback(node, json_id): - result = callback(node) if callback_arity == 1 else callback(node, json_id) - - return _handle_generator(result) - - def _handle_mapping(node, json_id): - result = node.__class__() - if isinstance(node, tagged.Tagged): - result._tag = node._tag - - pending_items = {} - for key, value in node.items(): - if _context.is_pending(value): - # The child node is pending modification, which means - # it must be its own ancestor. Assign the special - # PendingValue instance for now, and note that we'll - # need to fill in the real value later. - pending_items[key] = value - result[key] = PendingValue - - elif (val := _recurse(value, json_id)) is not RemoveNode: - result[key] = val - - yield result - - if len(pending_items) > 0: - # Now that we've yielded, the pending children should - # be available. - for key, value in pending_items.items(): - if (val := _recurse(value, json_id)) is not RemoveNode: - result[key] = val - else: - # The callback may have decided to delete - # this node after all. - del result[key] - - def _handle_mutable_sequence(node, json_id): - result = node.__class__() - if isinstance(node, tagged.Tagged): - result._tag = node._tag - - pending_items = {} - for i, value in enumerate(node): - if _context.is_pending(value): - # The child node is pending modification, which means - # it must be its own ancestor. 
Assign the special - # PendingValue instance for now, and note that we'll - # need to fill in the real value later. - pending_items[i] = value - result.append(PendingValue) - else: - result.append(_recurse(value, json_id)) - - yield result - - for i, value in pending_items.items(): - # Now that we've yielded, the pending children should - # be available. - result[i] = _recurse(value, json_id) - - def _handle_immutable_sequence(node, json_id): - # Immutable sequences containing themselves are impossible - # to construct (well, maybe possible in a C extension, but - # we're not going to worry about that), so we don't need - # to yield here. - contents = [_recurse(value, json_id) for value in node] - - try: - result = node.__class__(contents) - if isinstance(node, tagged.Tagged): - result._tag = node._tag - except TypeError: - # The derived class signature is different, so simply store the - # list representing the contents. Currently this is primarily - # intended to handle namedtuple and NamedTuple instances. - if not ignore_implicit_conversion: - warnings.warn(f"Failed to serialize instance of {type(node)}, converting to list instead", AsdfWarning) - result = contents - - return result - - def _handle_children(node, json_id): - if isinstance(node, dict): - result = _handle_mapping(node, json_id) - elif isinstance(node, tuple): - result = _handle_immutable_sequence(node, json_id) - elif isinstance(node, list): - result = _handle_mutable_sequence(node, json_id) - else: - result = node - - return _handle_generator(result) - - def _recurse(node, json_id=None): - if node in _context: - # The node's modified result has already been - # created, all we need to do is return it. This - # occurs when the tree contains multiple references - # to the same object id. - return _context[node] - - # Inform the context that we're going to start modifying - # this node. - with _context.pending(node): - # Take note of the "id" field, in case we're modifying - # a schema and need to know the namespace for resolving - # URIs. Ignore an id that is not a string, since it may - # be an object defining an id property and not an id - # itself (this is common in metaschemas). - if isinstance(node, dict) and "id" in node and isinstance(node["id"], str): - json_id = node["id"] - - if postorder: - # If this is a postorder modification, invoke the - # callback on this node's children first. - result = _handle_children(node, json_id) - result = _handle_callback(result, json_id) - else: - # Otherwise, invoke the callback on the node first, - # then its children. - result = _handle_callback(node, json_id) - result = _handle_children(result, json_id) - - # Store the result in the context, in case there are - # additional references to the same node elsewhere in - # the tree. - _context[node] = result - - return result - - if _context is None: - _context = _TreeModificationContext() - - with _context: - return _recurse(top) - # Generators will be drained here, if this is the outermost - # call to walk_and_modify. 
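
# A minimal sketch of the preorder path (postorder=False), whose semantics
# the iterative replacement below preserves: the callback sees a node
# before its children, so a returned container is itself walked afterwards
# (names are illustrative):
import asdf.treeutil

def expand_placeholder(node):
    # replace a placeholder dict first; its new children are visited next
    if isinstance(node, dict) and node.get("expand"):
        return {"values": [1, 2, 3]}
    return node

result = asdf.treeutil.walk_and_modify({"expand": True}, expand_placeholder, postorder=False)
assert result == {"values": [1, 2, 3]}
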
+ + if postorder: + modify = _itertree.leaf_first_modify_and_copy + else: + modify = _itertree.depth_first_modify_and_copy + + if callback.__code__.co_argcount == 2 and not _track_id: + _track_id = True + warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) + + if _track_id: + + def wrapped_callback(obj, edge): + json_id = _get_json_id(top, edge) + return callback(obj, json_id) + + else: + + def wrapped_callback(obj, edge): + return callback(obj) + + if ignore_implicit_conversion: + container_factory = _container_factory + else: + + def container_factory(obj): + if isinstance(obj, tuple) and type(obj) != tuple: + warnings.warn(f"Failed to serialize instance of {type(obj)}, converting to list instead", AsdfWarning) + return _container_factory(obj) + + return modify(top, wrapped_callback, container_factory=container_factory) def _get_children(node): diff --git a/asdf/yamlutil.py b/asdf/yamlutil.py index 9021a051c..2262de692 100644 --- a/asdf/yamlutil.py +++ b/asdf/yamlutil.py @@ -296,7 +296,6 @@ def _walker(obj): # Walk the tree in preorder, so that extensions can return # container nodes with unserialized children. postorder=False, - _context=ctx._tree_modification_context, ) for generator in generators: for _ in generator: @@ -348,7 +347,6 @@ def _walker(node): # Walk the tree in postorder, so that extensions receive # container nodes with children already deserialized. postorder=True, - _context=ctx._tree_modification_context, ) for generator in generators: for _ in generator: From 2197e342ae3fd41f1253e7f46f6c273c3cb1bf77 Mon Sep 17 00:00:00 2001 From: Brett Date: Tue, 2 Jan 2024 17:00:19 -0500 Subject: [PATCH 09/16] cleaning up unused code --- asdf/_itertree.py | 122 +++++++++++++++------------------------------- 1 file changed, 39 insertions(+), 83 deletions(-) diff --git a/asdf/_itertree.py b/asdf/_itertree.py index fdbfecb3e..06e5db514 100644 --- a/asdf/_itertree.py +++ b/asdf/_itertree.py @@ -80,46 +80,47 @@ def _default_get_children(obj): return None -def breadth_first(d, get_children=None, skip_ids=None): +def breadth_first(d, get_children=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() + seen = set() dq = collections.deque() dq.append(_Edge(None, None, d)) while dq: edge = dq.popleft() obj = edge.node - if id(obj) in skip_ids: + obj_id = id(obj) + if obj_id in seen: continue yield obj, edge children = get_children(obj) if children: - skip_ids.add(id(obj)) + seen.add(obj_id) for key, value in children: dq.append(_Edge(edge, key, value)) -def depth_first(d, get_children=None, skip_ids=None): +def depth_first(d, get_children=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() + seen = set() dq = collections.deque() dq.append(_Edge(None, None, d)) while dq: edge = dq.pop() obj = edge.node obj_id = id(obj) - if obj_id in skip_ids: + if obj_id in seen: continue yield obj, edge children = get_children(obj) if children: - skip_ids.add(obj_id) + seen.add(obj_id) for key, value in children: dq.append(_Edge(edge, key, value)) -def leaf_first(d, get_children=None, skip_ids=None): +def leaf_first(d, get_children=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() + seen = set() dq = collections.deque() dq.append(_Edge(None, None, d)) while dq: @@ -131,11 +132,11 @@ def leaf_first(d, get_children=None, skip_ids=None): continue obj = edge.node obj_id = id(obj) - if obj_id in skip_ids: + if obj_id in seen: continue children = get_children(obj) if 
children: - skip_ids.add(obj_id) + seen.add(obj_id) dq.append(_ShowValue(edge, obj_id)) for key, value in children: dq.append(_Edge(edge, key, value)) @@ -152,10 +153,9 @@ def _default_delitem(obj, key): obj.__delitem__(key) -def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=None, skip_ids=None): +def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() - cache = {} # TODO fix + cache = {} setitem = setitem or _default_setitem delitem = delitem or _default_delitem dq = collections.deque() @@ -164,11 +164,9 @@ def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=N edge = dq.popleft() obj = edge.node obj_id = id(obj) - if obj_id in skip_ids: - continue if obj_id not in cache: - cache[obj_id] = callback(obj, edge) - obj = cache[obj_id] + cache[obj_id] = (obj, callback(obj, edge)) + obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: delitem(edge.parent.node, edge.key) @@ -176,15 +174,13 @@ def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=N setitem(edge.parent.node, edge.key, obj) children = get_children(obj) if children: - skip_ids.add(obj_id) for key, value in children: dq.append(_Edge(edge, key, value)) -def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=None, skip_ids=None): +def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() - cache = {} # TODO fix + cache = {} setitem = setitem or _default_setitem delitem = delitem or _default_delitem dq = collections.deque() @@ -193,11 +189,9 @@ def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=Non edge = dq.pop() obj = edge.node obj_id = id(obj) - if obj_id in skip_ids: - continue if obj_id not in cache: - cache[obj_id] = callback(obj, edge) - obj = cache[obj_id] + cache[obj_id] = (obj, callback(obj, edge)) + obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: delitem(edge.parent.node, edge.key) @@ -205,15 +199,13 @@ def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=Non setitem(edge.parent.node, edge.key, obj) children = get_children(obj) if children: - skip_ids.add(obj_id) for key, value in children: dq.append(_Edge(edge, key, value)) -def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None, skip_ids=None): +def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() - cache = {} # TODO fix + cache = {} setitem = setitem or _default_setitem delitem = delitem or _default_delitem dq = collections.deque() @@ -225,8 +217,8 @@ def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None obj = edge.node obj_id = id(obj) if obj_id not in cache: - cache[obj_id] = callback(obj, edge) - obj = cache[obj_id] + cache[obj_id] = (obj, callback(obj, edge)) + obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: delitem(edge.parent.node, edge.key) @@ -235,19 +227,16 @@ def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None continue obj = edge.node obj_id = id(obj) - if obj_id in skip_ids: - continue children = get_children(obj) if children: - skip_ids.add(obj_id) dq.append(_ShowValue(edge, obj_id)) for key, value in children: 
dq.append(_Edge(edge, key, value)) continue if obj_id not in cache: - cache[obj_id] = callback(obj, edge) - obj = cache[obj_id] + cache[obj_id] = (obj, callback(obj, edge)) + obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: delitem(edge.parent.node, edge.key) @@ -264,12 +253,9 @@ def _default_container_factory(obj): raise NotImplementedError() -def breadth_first_modify_and_copy( - d, callback, get_children=None, setitem=None, delitem=None, container_factory=None, skip_ids=None -): +def breadth_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() - cache = {} # TODO fix + cache = {} setitem = setitem or _default_setitem delitem = delitem or _default_delitem container_factory = container_factory or _default_container_factory @@ -280,24 +266,23 @@ def breadth_first_modify_and_copy( edge = dq.popleft() obj = edge.node obj_id = id(obj) - if obj_id in skip_ids: - continue if False and obj_id in cache: obj = cache[obj_id] else: - obj = callback(obj, edge) - if edge.parent is not None and obj is RemoveNode: + cobj = callback(obj, edge) + if edge.parent is not None and cobj is RemoveNode: # TODO handle multiple list key deletion delitem(edge.parent.node, edge.key) continue - children = get_children(obj) + children = get_children(cobj) if children: - obj = container_factory(obj) - edge.node = obj - skip_ids.add(obj_id) + container = container_factory(cobj) + edge.node = container for key, value in children: dq.append(_Edge(edge, key, value)) - # cache[obj_id] = obj + cobj = container + cache[obj_id] = (obj, cobj) + obj = cobj if result is None: result = obj if edge.parent is not None: @@ -305,11 +290,8 @@ def breadth_first_modify_and_copy( return result -def depth_first_modify_and_copy( - d, callback, get_children=None, setitem=None, delitem=None, container_factory=None, skip_ids=None -): +def depth_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() cache = {} setitem = setitem or _default_setitem delitem = delitem or _default_delitem @@ -320,16 +302,10 @@ def depth_first_modify_and_copy( while dq: edge = dq.pop() obj = edge.node - # print(obj) obj_id = id(obj) - # if obj_id in skip_ids: - # #print(f"\tskip because of id {obj_id}") - # continue if obj_id in cache: - # print("\tfrom cache") new_obj = cache[obj_id][1] else: - # print("\tstepping in") new_obj = callback(obj, edge) if edge.parent is not None and new_obj is RemoveNode: # TODO handle multiple list key deletion @@ -339,8 +315,6 @@ def depth_first_modify_and_copy( if children: container = container_factory(new_obj) edge.node = container - # print(f"\tadding id {obj_id} to skips") - # skip_ids.add(obj_id) for key, value in children: dq.append(_Edge(edge, key, value)) new_obj = container @@ -349,15 +323,11 @@ def depth_first_modify_and_copy( result = new_obj if edge.parent is not None: setitem(edge.parent.node, edge.key, new_obj) - # print(result) return result -def leaf_first_modify_and_copy( - d, callback, get_children=None, setitem=None, delitem=None, container_factory=None, skip_ids=None -): +def leaf_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): get_children = get_children or _default_get_children - skip_ids = skip_ids or set() pending = {} cache = {} setitem = setitem or 
_default_setitem @@ -365,14 +335,9 @@ def leaf_first_modify_and_copy( container_factory = container_factory or _default_container_factory result = None dq = collections.deque() - # print(f"Input obj id: {id(d['obj'])}") - # print(f"Input inverse id: {id(d['obj']['inverse'])}") dq.append(_Edge(None, None, d)) while dq: edge = dq.pop() - # print(f"Processing {edge}") - # print(f"\tdq={dq}") - # print(f"\tcache={cache}") if isinstance(edge, _ShowValue): obj_id = edge.obj_id edge = edge.obj @@ -401,19 +366,13 @@ def leaf_first_modify_and_copy( setitem(edge.parent.node, edge.key, new_obj) del pending[obj_id] continue - # print(f"\tNode id {id(edge.node)} at {edge.key} of {edge.parent}") obj = edge.node obj_id = id(obj) - # if obj_id in skip_ids: - # print("\tskipping") - # continue if obj_id in cache: - # print("\tfrom cache") new_obj = cache[obj_id][1] else: children = get_children(obj) if children: - skip_ids.add(obj_id) container = container_factory(obj) pending[obj_id] = [] if result is None: @@ -425,10 +384,7 @@ def leaf_first_modify_and_copy( pending[id(value)].append(_Edge(edge, key, value)) else: dq.append(_Edge(edge, key, value)) - # if id(value) not in pending: continue - # cache[obj_id] = callback(obj, edge) - # obj = cache[obj_id] new_obj = callback(obj, edge) cache[obj_id] = (obj, new_obj) if result is None: From 83e31d63d218a05eff3ab1c9f729e1e1a6eb38d1 Mon Sep 17 00:00:00 2001 From: Brett Date: Wed, 3 Jan 2024 09:23:57 -0500 Subject: [PATCH 10/16] key removal for lists and dicts --- asdf/_itertree.py | 52 +++++++++++++++++-------- asdf/_tests/test_itertree.py | 75 ++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 16 deletions(-) diff --git a/asdf/_itertree.py b/asdf/_itertree.py index 06e5db514..4acbb1859 100644 --- a/asdf/_itertree.py +++ b/asdf/_itertree.py @@ -149,13 +149,27 @@ def _default_setitem(obj, key, value): def _default_delitem(obj, key): - if key in obj: - obj.__delitem__(key) + obj.__delitem__(key) + + +def _delete_items(edges, delitem): + # index all deletions by the parent node id + by_parent_id = {} + for edge in edges: + parent_id = id(edge.parent.node) + if parent_id not in by_parent_id: + by_parent_id[parent_id] = [] + by_parent_id[parent_id].append(edge) + for parent_id in by_parent_id: + # delete with highest/last key first + for edge in sorted(by_parent_id[parent_id], key=lambda edge: edge.key, reverse=True): + delitem(edge.parent.node, edge.key) def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children cache = {} + to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem dq = collections.deque() @@ -169,18 +183,20 @@ def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=N obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: - delitem(edge.parent.node, edge.key) + to_delete.append(edge) continue setitem(edge.parent.node, edge.key, obj) children = get_children(obj) if children: for key, value in children: dq.append(_Edge(edge, key, value)) + _delete_items(to_delete, delitem) def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children cache = {} + to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem dq = collections.deque() @@ -194,18 +210,20 @@ def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=Non obj = 
cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: - delitem(edge.parent.node, edge.key) + to_delete.append(edge) continue setitem(edge.parent.node, edge.key, obj) children = get_children(obj) if children: for key, value in children: dq.append(_Edge(edge, key, value)) + _delete_items(to_delete, delitem) def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children cache = {} + to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem dq = collections.deque() @@ -221,7 +239,7 @@ def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: - delitem(edge.parent.node, edge.key) + to_delete.append(edge) else: setitem(edge.parent.node, edge.key, obj) continue @@ -239,9 +257,10 @@ def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None obj = cache[obj_id][1] if edge.parent is not None: if obj is RemoveNode: - delitem(edge.parent.node, edge.key) + to_delete.append(edge) else: setitem(edge.parent.node, edge.key, obj) + _delete_items(to_delete, delitem) def _default_container_factory(obj): @@ -256,6 +275,7 @@ def _default_container_factory(obj): def breadth_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): get_children = get_children or _default_get_children cache = {} + to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem container_factory = container_factory or _default_container_factory @@ -271,8 +291,7 @@ def breadth_first_modify_and_copy(d, callback, get_children=None, setitem=None, else: cobj = callback(obj, edge) if edge.parent is not None and cobj is RemoveNode: - # TODO handle multiple list key deletion - delitem(edge.parent.node, edge.key) + to_delete.append(edge) continue children = get_children(cobj) if children: @@ -287,12 +306,14 @@ def breadth_first_modify_and_copy(d, callback, get_children=None, setitem=None, result = obj if edge.parent is not None: setitem(edge.parent.node, edge.key, obj) + _delete_items(to_delete, delitem) return result def depth_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): get_children = get_children or _default_get_children cache = {} + to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem container_factory = container_factory or _default_container_factory @@ -308,8 +329,7 @@ def depth_first_modify_and_copy(d, callback, get_children=None, setitem=None, de else: new_obj = callback(obj, edge) if edge.parent is not None and new_obj is RemoveNode: - # TODO handle multiple list key deletion - delitem(edge.parent.node, edge.key) + to_delete.append(edge) continue children = get_children(new_obj) if children: @@ -323,6 +343,7 @@ def depth_first_modify_and_copy(d, callback, get_children=None, setitem=None, de result = new_obj if edge.parent is not None: setitem(edge.parent.node, edge.key, new_obj) + _delete_items(to_delete, delitem) return result @@ -330,6 +351,7 @@ def leaf_first_modify_and_copy(d, callback, get_children=None, setitem=None, del get_children = get_children or _default_get_children pending = {} cache = {} + to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem container_factory = container_factory or 
_default_container_factory @@ -353,15 +375,13 @@ def leaf_first_modify_and_copy(d, callback, get_children=None, setitem=None, del result = new_obj if edge.parent is not None: if new_obj is RemoveNode: - # TODO handle multiple list key deletion - delitem(edge.parent.node, edge.key) + to_delete.append(edge) else: setitem(edge.parent.node, edge.key, new_obj) if obj_id in pending: for edge in pending[obj_id]: if new_obj is RemoveNode: - # TODO handle multiple list key deletion - delitem(edge.parent.node, edge.key) + to_delete.append(edge) else: setitem(edge.parent.node, edge.key, new_obj) del pending[obj_id] @@ -391,8 +411,8 @@ def leaf_first_modify_and_copy(d, callback, get_children=None, setitem=None, del result = new_obj if edge.parent is not None: if new_obj is RemoveNode: - # TODO handle multiple list key deletion - delitem(edge.parent.node, edge.key) + to_delete.append(edge) else: setitem(edge.parent.node, edge.key, new_obj) + _delete_items(to_delete, delitem) return result diff --git a/asdf/_tests/test_itertree.py b/asdf/_tests/test_itertree.py index e5995150b..cfd1a8d48 100644 --- a/asdf/_tests/test_itertree.py +++ b/asdf/_tests/test_itertree.py @@ -12,6 +12,8 @@ """ import copy +import pytest + from asdf import _itertree @@ -374,3 +376,76 @@ def callback(obj, keys): result = _itertree.leaf_first_modify_and_copy(copied_tree, callback) assert result["a"] == [1, 2, 42] assert copied_tree == tree + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_node_removal(traversal): + tree = { + "a": [1, 2, 3], + "b": 4, + } + + def callback(obj, edge): + if obj in (1, 3, 4): + return _itertree.RemoveNode + return obj + + result = traversal(tree, callback) + if result is not None: # this is a copy + original = tree + else: + original = None + result = tree + assert result["a"] == [ + 2, + ] + assert "b" not in result + if original is not None: + assert original["a"] == [1, 2, 3] + assert original["b"] == 4 + assert set(original.keys()) == {"a", "b"} + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_key_order(traversal): + """ + All traversal and modification functions should preserve + the order of keys in a dictionary + """ + tree = {} + tree["a"] = [1, 2] + tree["b"] = [3, 4] + tree["c"] = {} + tree["c"]["d"] = [5, 6] + tree["c"]["e"] = [7, 8] + + result = traversal(tree, lambda obj, edge: obj) + if result is None: + result = tree + + assert list(result.keys()) == ["a", "b", "c"] + assert result["a"] == [1, 2] + assert result["b"] == [3, 4] + assert list(result["c"].keys()) == ["d", "e"] + assert result["c"]["d"] == [5, 6] + assert result["c"]["e"] == [7, 8] From 4992d530921d645cea3ed6d62b77e26c7634ef73 Mon Sep 17 00:00:00 2001 From: Brett Date: Wed, 3 Jan 2024 12:44:01 -0500 Subject: [PATCH 11/16] fix recursive obj cache --- asdf/_itertree.py | 74 ++++++++++++----------- asdf/_tests/test_itertree.py | 111 ++++++++++++++++++++++++++++++++--- 2 files changed, 144 insertions(+), 41 deletions(-) diff --git a/asdf/_itertree.py b/asdf/_itertree.py index 4acbb1859..e21597295 100644 --- a/asdf/_itertree.py +++ b/asdf/_itertree.py @@ 
-178,18 +178,20 @@ def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=N edge = dq.popleft() obj = edge.node obj_id = id(obj) - if obj_id not in cache: - cache[obj_id] = (obj, callback(obj, edge)) - obj = cache[obj_id][1] + if obj_id in cache: + new_obj = cache[obj_id][1] + else: + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + children = get_children(new_obj) + if children: + for key, value in children: + dq.append(_Edge(edge, key, value)) if edge.parent is not None: - if obj is RemoveNode: + if new_obj is RemoveNode: to_delete.append(edge) continue - setitem(edge.parent.node, edge.key, obj) - children = get_children(obj) - if children: - for key, value in children: - dq.append(_Edge(edge, key, value)) + setitem(edge.parent.node, edge.key, new_obj) _delete_items(to_delete, delitem) @@ -205,24 +207,27 @@ def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=Non edge = dq.pop() obj = edge.node obj_id = id(obj) - if obj_id not in cache: - cache[obj_id] = (obj, callback(obj, edge)) - obj = cache[obj_id][1] + if obj_id in cache: + new_obj = cache[obj_id][1] + else: + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + children = get_children(new_obj) + if children: + for key, value in children: + dq.append(_Edge(edge, key, value)) if edge.parent is not None: - if obj is RemoveNode: + if new_obj is RemoveNode: to_delete.append(edge) continue - setitem(edge.parent.node, edge.key, obj) - children = get_children(obj) - if children: - for key, value in children: - dq.append(_Edge(edge, key, value)) + setitem(edge.parent.node, edge.key, new_obj) _delete_items(to_delete, delitem) def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None): get_children = get_children or _default_get_children cache = {} + pending = {} to_delete = collections.deque() setitem = setitem or _default_setitem delitem = delitem or _default_delitem @@ -231,9 +236,9 @@ def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None while dq: edge = dq.pop() if isinstance(edge, _ShowValue): + obj_id = edge.obj_id edge = edge.obj obj = edge.node - obj_id = id(obj) if obj_id not in cache: cache[obj_id] = (obj, callback(obj, edge)) obj = cache[obj_id][1] @@ -242,31 +247,34 @@ def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None to_delete.append(edge) else: setitem(edge.parent.node, edge.key, obj) + if obj_id in pending: + for edge in pending[obj_id]: + if obj is RemoveNode: + to_delete.append(edge) + else: + setitem(edge.parent.node, edge.key, obj) + del pending[obj_id] continue obj = edge.node obj_id = id(obj) + if obj_id not in pending: + pending[obj_id] = [] children = get_children(obj) + dq.append(_ShowValue(edge, obj_id)) if children: - dq.append(_ShowValue(edge, obj_id)) for key, value in children: - dq.append(_Edge(edge, key, value)) + if id(value) in pending: + pending[id(value)].append(_Edge(edge, key, value)) + else: + dq.append(_Edge(edge, key, value)) continue - - if obj_id not in cache: - cache[obj_id] = (obj, callback(obj, edge)) - obj = cache[obj_id][1] - if edge.parent is not None: - if obj is RemoveNode: - to_delete.append(edge) - else: - setitem(edge.parent.node, edge.key, obj) _delete_items(to_delete, delitem) def _default_container_factory(obj): if isinstance(obj, dict): # init with keys to retain order - return dict({k: None for k in obj}) + return {k: None for k in obj} elif isinstance(obj, (list, tuple)): return [None] * len(obj) raise NotImplementedError() @@ 
-286,8 +294,8 @@ def breadth_first_modify_and_copy(d, callback, get_children=None, setitem=None, edge = dq.popleft() obj = edge.node obj_id = id(obj) - if False and obj_id in cache: - obj = cache[obj_id] + if obj_id in cache: + obj = cache[obj_id][1] else: cobj = callback(obj, edge) if edge.parent is not None and cobj is RemoveNode: diff --git a/asdf/_tests/test_itertree.py b/asdf/_tests/test_itertree.py index cfd1a8d48..6457567a6 100644 --- a/asdf/_tests/test_itertree.py +++ b/asdf/_tests/test_itertree.py @@ -17,7 +17,27 @@ from asdf import _itertree -def test_breadth_first_traversal(): +def _traversal_to_generator(tree, traversal): + if "modify" not in traversal.__name__: + return traversal(tree) + + def make_generator(tree): + values = [] + + def callback(obj, edge): + values.append((obj, edge)) + return obj + + traversal(tree, callback) + yield from values + + return make_generator(tree) + + +@pytest.mark.parametrize( + "traversal", [_itertree.breadth_first, _itertree.breadth_first_modify, _itertree.breadth_first_modify_and_copy] +) +def test_breadth_first_traversal(traversal): tree = { "a": { "b": [1, 2, {"c": 3}], @@ -39,7 +59,8 @@ def test_breadth_first_traversal(): ] expected = [] - for node, edge in _itertree.breadth_first(tree): + + for node, edge in _traversal_to_generator(tree, traversal): if not len(expected): expected = expected_results.pop(0) assert node in expected @@ -71,7 +92,10 @@ def test_recursive_breadth_first_traversal(): assert not expected_results -def test_leaf_first_traversal(): +@pytest.mark.parametrize( + "traversal", [_itertree.leaf_first, _itertree.leaf_first_modify, _itertree.leaf_first_modify_and_copy] +) +def test_leaf_first_traversal(traversal): tree = { "a": { "b": [1, 2, {"c": 3}], @@ -103,13 +127,13 @@ def test_leaf_first_traversal(): ("a", "b", 2, "c"), ("a", "d"), } - for node, edge in _itertree.leaf_first(tree): + for node, edge in _traversal_to_generator(tree, traversal): keys = _itertree.edge_to_keys(edge) assert keys in expected obj = tree for key in keys: obj = obj[key] - assert obj is node + assert obj == node # updated expected seen_keys.add(keys) @@ -165,7 +189,10 @@ def test_recursive_leaf_first_traversal(): assert not visit_ids -def test_depth_first_traversal(): +@pytest.mark.parametrize( + "traversal", [_itertree.depth_first, _itertree.depth_first_modify, _itertree.depth_first_modify_and_copy] +) +def test_depth_first_traversal(traversal): tree = { "a": { "b": [1, 2, {"c": 3}], @@ -194,13 +221,13 @@ def test_depth_first_traversal(): expected = {()} seen_keys = set() - for node, edge in _itertree.depth_first(tree): + for node, edge in _traversal_to_generator(tree, traversal): keys = _itertree.edge_to_keys(edge) assert keys in expected obj = tree for key in keys: obj = obj[key] - assert obj is node + assert obj == node # updated expected seen_keys.add(keys) @@ -449,3 +476,71 @@ def test_key_order(traversal): assert list(result["c"].keys()) == ["d", "e"] assert result["c"]["d"] == [5, 6] assert result["c"]["e"] == [7, 8] + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_cache_callback(traversal): + class Foo: + pass + + obj = Foo() + obj.count = 0 + + tree = {} + tree["a"] = obj + tree["b"] = obj + tree["c"] = {"d": obj} + + def callback(obj, edge): + if isinstance(obj, Foo): + obj.count += 1 + return obj + 
+ result = traversal(tree, callback) + if result is None: + result = tree + + assert result["a"].count == 1 + assert result["b"].count == 1 + assert result["c"]["d"].count == 1 + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_recursive_object(traversal): + tree = {} + tree["a"] = {"count": 0} + tree["b"] = {"a": tree["a"]} + tree["a"]["b"] = tree["b"] + + def callback(obj, edge): + if isinstance(obj, dict) and "count" in obj: + obj["count"] += 1 + return obj + + result = traversal(tree, callback) + if result is None: + result = tree + + assert result["a"]["count"] == 1 + assert result["b"]["a"]["count"] == 1 + assert result["a"] is result["b"]["a"] From 3215f3861557abcd4c246b79663f8cac02eec239 Mon Sep 17 00:00:00 2001 From: Brett Date: Wed, 3 Jan 2024 13:26:54 -0500 Subject: [PATCH 12/16] cleaning up tests --- asdf/_tests/test_itertree.py | 43 ++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/asdf/_tests/test_itertree.py b/asdf/_tests/test_itertree.py index 6457567a6..1fbdbaafe 100644 --- a/asdf/_tests/test_itertree.py +++ b/asdf/_tests/test_itertree.py @@ -68,7 +68,10 @@ def test_breadth_first_traversal(traversal): assert not expected_results -def test_recursive_breadth_first_traversal(): +@pytest.mark.parametrize( + "traversal", [_itertree.breadth_first, _itertree.breadth_first_modify, _itertree.breadth_first_modify_and_copy] +) +def test_recursive_breadth_first_traversal(traversal): tree = { "a": {}, "b": {}, @@ -84,7 +87,7 @@ def test_recursive_breadth_first_traversal(): ] expected = [] - for node, edge in _itertree.breadth_first(tree): + for node, edge in _traversal_to_generator(tree, traversal): if not len(expected): expected = expected_results.pop(0) assert node in expected @@ -145,7 +148,10 @@ def test_leaf_first_traversal(traversal): assert not expected -def test_recursive_leaf_first_traversal(): +@pytest.mark.parametrize( + "traversal", [_itertree.leaf_first, _itertree.leaf_first_modify, _itertree.leaf_first_modify_and_copy] +) +def test_recursive_leaf_first_traversal(traversal): tree = { "a": {}, "b": {}, @@ -154,11 +160,6 @@ def test_recursive_leaf_first_traversal(): tree["b"]["a"] = tree["a"] seen_keys = set() - visit_ids = { - id(tree), - id(tree["a"]), - id(tree["b"]), - } reverse_paths = { ("a", "b"): [("a",), ("b",)], ("b", "a"): [("a",), ("b",)], @@ -170,14 +171,14 @@ def test_recursive_leaf_first_traversal(): ("a", "b"), ("b", "a"), } - for node, edge in _itertree.leaf_first(tree): + visits = [] + for node, edge in _traversal_to_generator(tree, traversal): keys = _itertree.edge_to_keys(edge) assert keys in expected obj = tree for key in keys: obj = obj[key] - assert obj is node - visit_ids.remove(id(node)) + visits.append((obj, edge)) # updated expected seen_keys.add(keys) @@ -186,7 +187,7 @@ def test_recursive_leaf_first_traversal(): if new_keys in seen_keys: continue expected.add(new_keys) - assert not visit_ids + assert len(visits) == 3 @pytest.mark.parametrize( @@ -239,7 +240,10 @@ def test_depth_first_traversal(traversal): assert not expected -def test_recursive_depth_first_traversal(): +@pytest.mark.parametrize( + "traversal", [_itertree.depth_first, _itertree.depth_first_modify, _itertree.depth_first_modify_and_copy] +) +def test_recursive_depth_first_traversal(traversal): tree = 
{ "a": {}, "b": {}, @@ -248,11 +252,6 @@ def test_recursive_depth_first_traversal(): tree["b"]["a"] = tree["a"] seen_keys = set() - visit_ids = { - id(tree), - id(tree["a"]), - id(tree["b"]), - } forward_paths = { (): [("a",), ("b",)], ("a",): [("a", "b")], @@ -263,14 +262,14 @@ def test_recursive_depth_first_traversal(): expected = { (), } - for node, edge in _itertree.depth_first(tree): + visits = [] + for node, edge in _traversal_to_generator(tree, traversal): keys = _itertree.edge_to_keys(edge) assert keys in expected obj = tree for key in keys: obj = obj[key] - assert obj is node - visit_ids.remove(id(node)) + visits.append((node, edge)) # updated expected seen_keys.add(keys) @@ -279,7 +278,7 @@ def test_recursive_depth_first_traversal(): if new_keys in seen_keys: continue expected.add(new_keys) - assert not visit_ids + assert len(visits) == 3 def test_breadth_first_modify(): From 320da109f3596f348a0b59844e4c7869349151b9 Mon Sep 17 00:00:00 2001 From: Brett Date: Wed, 3 Jan 2024 14:00:30 -0500 Subject: [PATCH 13/16] replace _walk_tree_breadth_first --- asdf/search.py | 54 +++++++++++++++----------------------------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/asdf/search.py b/asdf/search.py index d8507c5d8..531183d45 100644 --- a/asdf/search.py +++ b/asdf/search.py @@ -6,14 +6,21 @@ import re import typing +from . import _itertree from ._display import DEFAULT_MAX_COLS, DEFAULT_MAX_ROWS, DEFAULT_SHOW_VALUES, format_faint, format_italic, render_tree from ._node_info import NodeSchemaInfo, collect_schema_info -from .treeutil import _get_children, _is_container +from .treeutil import _is_container from .util import NotSet __all__ = ["AsdfSearchResult"] +def _get_children(obj): + if hasattr(obj, "__asdf_traverse__"): + obj = obj.__asdf_traverse__() + return _itertree._default_get_children(obj) + + class AsdfSearchResult: """ Result of a call to AsdfFile.search. 
@@ -222,11 +229,9 @@ def replace(self, value): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): - results.append((identifiers[-1], parent)) - - _walk_tree_breadth_first(self._identifiers, self._node, _callback) + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): + results.append((edge.key, edge.parent.node)) for identifier, parent in results: parent[identifier] = value @@ -284,11 +289,10 @@ def nodes(self): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): results.append(node) - _walk_tree_breadth_first(self._identifiers, self._node, _callback) return results @property @@ -303,11 +307,10 @@ def paths(self): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): - results.append(_build_path(identifiers)) + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): + results.append(_build_path(self._identifiers + list(_itertree.edge_to_keys(edge)))) - _walk_tree_breadth_first(self._identifiers, self._node, _callback) return results def __repr__(self): @@ -369,31 +372,6 @@ def __getitem__(self, key): ) -def _walk_tree_breadth_first(root_identifiers, root_node, callback): - """ - Walk the tree in breadth-first order (useful for prioritizing - lower-depth nodes). - """ - current_nodes = [(root_identifiers, None, root_node)] - seen = set() - while True: - next_nodes = [] - - for identifiers, parent, node in current_nodes: - if (isinstance(node, (dict, list, tuple)) or NodeSchemaInfo.traversable(node)) and id(node) in seen: - continue - tnode = node.__asdf_traverse__() if NodeSchemaInfo.traversable(node) else node - children = _get_children(tnode) - callback(identifiers, parent, node, [c for _, c in children]) - next_nodes.extend([([*identifiers, i], node, c) for i, c in children]) - seen.add(id(node)) - - if len(next_nodes) == 0: - break - - current_nodes = next_nodes - - def _build_path(identifiers): """ Generate the Python code needed to extract the identified node. From 4144d4ee37ca385806a7a67811a8bffcf468299b Mon Sep 17 00:00:00 2001 From: Brett Date: Wed, 3 Jan 2024 16:46:41 -0500 Subject: [PATCH 14/16] cleanup comments --- asdf/_itertree.py | 1 - asdf/_tests/test_itertree.py | 12 ------------ 2 files changed, 13 deletions(-) diff --git a/asdf/_itertree.py b/asdf/_itertree.py index e21597295..17801709d 100644 --- a/asdf/_itertree.py +++ b/asdf/_itertree.py @@ -27,7 +27,6 @@ - c, b (any order) - a (note that this is the inverse of depth-first) - """ import collections diff --git a/asdf/_tests/test_itertree.py b/asdf/_tests/test_itertree.py index 1fbdbaafe..26662c1dc 100644 --- a/asdf/_tests/test_itertree.py +++ b/asdf/_tests/test_itertree.py @@ -1,15 +1,3 @@ -""" -Reorganize these tests into a fixture that generates test trees -Each test tree returns: - - the tree - - the breadth-first order (need to account for multiple valid paths) - - the depth-first order (need to account for multiple valid paths) -The orderings above need to be reversible to allow postorder tests -to also check the path. 
- -For modification tests the callbacks can also check the order -and can modify the tree in a way that should impact later callbacks. -""" import copy import pytest From 5f7d86bad17771c2c47508212eb7824403c9165a Mon Sep 17 00:00:00 2001 From: Brett Date: Wed, 3 Jan 2024 16:48:15 -0500 Subject: [PATCH 15/16] fix walk_and_modify note about copy --- asdf/treeutil.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 92f7a5421..a7fb2efc3 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -113,7 +113,8 @@ def _container_factory(obj): def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _track_id=False): """Modify a tree by walking it with a callback function. It also has - the effect of doing a deep copy. + the effect of doing a copy. Only "containers" (dict, list, etc) will be + copied, all "leaf" nodes will not be copied. Parameters ---------- From 5136cae7ef7c25740c66f93712b14034aba90be7 Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 4 Jan 2024 11:58:58 -0500 Subject: [PATCH 16/16] remove _track_id --- asdf/schema.py | 16 +++++++++------- asdf/treeutil.py | 7 ++----- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/asdf/schema.py b/asdf/schema.py index e5af67121..061d0e8f4 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -14,7 +14,7 @@ from asdf._jsonschema import validators as mvalidators from asdf._jsonschema.exceptions import RefResolutionError, ValidationError -from . import constants, generic_io, reference, tagged, treeutil, util, versioning, yamlutil +from . import _itertree, constants, generic_io, reference, tagged, treeutil, util, versioning, yamlutil from .config import get_config from .exceptions import AsdfDeprecationWarning, AsdfWarning from .util import _patched_urllib_parse @@ -471,11 +471,10 @@ def _load_schema_cached(url, resolver, resolve_references): if resolve_references: - def resolve_refs(node, json_id): - if json_id is None: - json_id = url - + def resolve_refs(node, edge): if isinstance(node, dict) and "$ref" in node: + json_id = treeutil._get_json_id(schema, edge) or url + suburl_base, suburl_fragment = _safe_resolve(resolver, json_id, node["$ref"]) if suburl_base == url or suburl_base == schema.get("id"): @@ -485,10 +484,13 @@ def resolve_refs(node, json_id): subschema = load_schema(suburl_base, resolver, True) return reference.resolve_fragment(subschema, suburl_fragment) - return node - schema = treeutil.walk_and_modify(schema, resolve_refs, _track_id=True) + # We need to copy here so that we don't end up with a recursive tree. + # When full resolution results in a recursive tree the returned recursive + # schema would cause check_schema to fail. This means some local $ref + # instances are not resolved. + schema = _itertree.leaf_first_modify_and_copy(schema, resolve_refs) return schema diff --git a/asdf/treeutil.py b/asdf/treeutil.py index a7fb2efc3..5e0ef153e 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -111,7 +111,7 @@ def _container_factory(obj): return result -def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _track_id=False): +def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True): """Modify a tree by walking it with a callback function. It also has the effect of doing a copy. Only "containers" (dict, list, etc) will be copied, all "leaf" nodes will not be copied. 
@@ -165,12 +165,9 @@ def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=T else: modify = _itertree.depth_first_modify_and_copy - if callback.__code__.co_argcount == 2 and not _track_id: - _track_id = True + if callback.__code__.co_argcount == 2: warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) - if _track_id: - def wrapped_callback(obj, edge): json_id = _get_json_id(top, edge) return callback(obj, json_id)