diff --git a/CHANGES.rst b/CHANGES.rst index 62a76753b..9fd08601a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -21,6 +21,13 @@ The ASDF Standard is at v1.6.0 is removed in an upcoming asdf release will be ``False`` and asdf will no longer by-default memory map arrays. [#1667] +- Fix bug where a dictionary containing a key ``id`` caused + any contained references to fail to resolve [#1716] + +- Deprecate the following in ``asdf.treeutil``. + ``get_children``, ``is_container``, the ``json_id`` argument to + callbacks provided to ``walk_and_modify`` [#1719] + 3.0.1 (2023-10-30) ------------------ diff --git a/asdf/_asdf.py b/asdf/_asdf.py index 69da27f9e..c44dfc31e 100644 --- a/asdf/_asdf.py +++ b/asdf/_asdf.py @@ -12,7 +12,7 @@ from . import _display as display from . import _node_info as node_info from . import _version as version -from . import constants, generic_io, reference, schema, treeutil, util, versioning, yamlutil +from . import constants, generic_io, reference, schema, util, versioning, yamlutil from ._block.manager import Manager as BlockManager from ._helpers import validate_version from .config import config_context, get_config @@ -169,11 +169,6 @@ def __init__( self._file_format_version = None - # Context of a call to treeutil.walk_and_modify, needed in the AsdfFile - # in case walk_and_modify is re-entered by extension code (via - # custom_tree_to_tagged_tree or tagged_tree_to_custom_tree). 
- self._tree_modification_context = treeutil._TreeModificationContext() - self._fd = None self._closed = False self._external_asdf_by_uri = {} diff --git a/asdf/_itertree.py b/asdf/_itertree.py new file mode 100644 index 000000000..17801709d --- /dev/null +++ b/asdf/_itertree.py @@ -0,0 +1,425 @@ +""" +For modification, order is important +for a tree + + a + / \ +b c + / \ + d e + +when walking breadth-first down the tree, modify: +- a first +- b, c (any order) +- then d, e (any order) + +this means that a might get modified changing +where b and c come from + +when walking depth-first down the tree, modify: +- a first +- b or c +- if b, then c +- if c then d, e (any order) + +when walking leaf-first up the tree, modify: +- d, e (any order) +- c, b (any order) +- a +(note that this is the inverse of depth-first) +""" +import collections + + +class _Edge: + __slots__ = ["parent", "key", "node"] + + def __init__(self, parent, key, node): + self.parent = parent + self.key = key # can be used to make path + self.node = node # can be used to get things like 'json_id', duplicate of obj in callback + + +class _RemoveNode: + """ + Class of the RemoveNode singleton instance. This instance is used + as a signal for `asdf.treeutil.walk_and_modify` to remove the + node received by the callback. 
+ """ + + def __repr__(self): + return "RemoveNode" + + +RemoveNode = _RemoveNode() + + +def edge_to_keys(edge): + keys = [] + while edge.key is not None: + keys.append(edge.key) + edge = edge.parent + return tuple(keys[::-1]) + + +class _ShowValue: + __slots__ = ["obj", "obj_id"] + + def __init__(self, obj, obj_id): + self.obj = obj + self.obj_id = obj_id + + +def _default_get_children(obj): + if isinstance(obj, dict): + return obj.items() + elif isinstance(obj, (list, tuple)): + return enumerate(obj) + else: + return None + + +def breadth_first(d, get_children=None): + get_children = get_children or _default_get_children + seen = set() + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.popleft() + obj = edge.node + obj_id = id(obj) + if obj_id in seen: + continue + yield obj, edge + children = get_children(obj) + if children: + seen.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + + +def depth_first(d, get_children=None): + get_children = get_children or _default_get_children + seen = set() + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + obj = edge.node + obj_id = id(obj) + if obj_id in seen: + continue + yield obj, edge + children = get_children(obj) + if children: + seen.add(obj_id) + for key, value in children: + dq.append(_Edge(edge, key, value)) + + +def leaf_first(d, get_children=None): + get_children = get_children or _default_get_children + seen = set() + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + if isinstance(edge, _ShowValue): + edge = edge.obj + obj = edge.node + yield obj, edge + continue + obj = edge.node + obj_id = id(obj) + if obj_id in seen: + continue + children = get_children(obj) + if children: + seen.add(obj_id) + dq.append(_ShowValue(edge, obj_id)) + for key, value in children: + dq.append(_Edge(edge, key, value)) + continue + yield obj, edge + + +def _default_setitem(obj, key, value): 
+ obj.__setitem__(key, value) + + +def _default_delitem(obj, key): + obj.__delitem__(key) + + +def _delete_items(edges, delitem): + # index all deletions by the parent node id + by_parent_id = {} + for edge in edges: + parent_id = id(edge.parent.node) + if parent_id not in by_parent_id: + by_parent_id[parent_id] = [] + by_parent_id[parent_id].append(edge) + for parent_id in by_parent_id: + # delete with highest/last key first + for edge in sorted(by_parent_id[parent_id], key=lambda edge: edge.key, reverse=True): + delitem(edge.parent.node, edge.key) + + +def breadth_first_modify(d, callback, get_children=None, setitem=None, delitem=None): + get_children = get_children or _default_get_children + cache = {} + to_delete = collections.deque() + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.popleft() + obj = edge.node + obj_id = id(obj) + if obj_id in cache: + new_obj = cache[obj_id][1] + else: + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + children = get_children(new_obj) + if children: + for key, value in children: + dq.append(_Edge(edge, key, value)) + if edge.parent is not None: + if new_obj is RemoveNode: + to_delete.append(edge) + continue + setitem(edge.parent.node, edge.key, new_obj) + _delete_items(to_delete, delitem) + + +def depth_first_modify(d, callback, get_children=None, setitem=None, delitem=None): + get_children = get_children or _default_get_children + cache = {} + to_delete = collections.deque() + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + obj = edge.node + obj_id = id(obj) + if obj_id in cache: + new_obj = cache[obj_id][1] + else: + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + children = get_children(new_obj) + if children: + for key, value in children: + 
dq.append(_Edge(edge, key, value)) + if edge.parent is not None: + if new_obj is RemoveNode: + to_delete.append(edge) + continue + setitem(edge.parent.node, edge.key, new_obj) + _delete_items(to_delete, delitem) + + +def leaf_first_modify(d, callback, get_children=None, setitem=None, delitem=None): + get_children = get_children or _default_get_children + cache = {} + pending = {} + to_delete = collections.deque() + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + if isinstance(edge, _ShowValue): + obj_id = edge.obj_id + edge = edge.obj + obj = edge.node + if obj_id not in cache: + cache[obj_id] = (obj, callback(obj, edge)) + obj = cache[obj_id][1] + if edge.parent is not None: + if obj is RemoveNode: + to_delete.append(edge) + else: + setitem(edge.parent.node, edge.key, obj) + if obj_id in pending: + for edge in pending[obj_id]: + if obj is RemoveNode: + to_delete.append(edge) + else: + setitem(edge.parent.node, edge.key, obj) + del pending[obj_id] + continue + obj = edge.node + obj_id = id(obj) + if obj_id not in pending: + pending[obj_id] = [] + children = get_children(obj) + dq.append(_ShowValue(edge, obj_id)) + if children: + for key, value in children: + if id(value) in pending: + pending[id(value)].append(_Edge(edge, key, value)) + else: + dq.append(_Edge(edge, key, value)) + continue + _delete_items(to_delete, delitem) + + +def _default_container_factory(obj): + if isinstance(obj, dict): + # init with keys to retain order + return {k: None for k in obj} + elif isinstance(obj, (list, tuple)): + return [None] * len(obj) + raise NotImplementedError() + + +def breadth_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): + get_children = get_children or _default_get_children + cache = {} + to_delete = collections.deque() + setitem = setitem or _default_setitem + delitem = delitem or 
_default_delitem + container_factory = container_factory or _default_container_factory + result = None + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.popleft() + obj = edge.node + obj_id = id(obj) + if obj_id in cache: + obj = cache[obj_id][1] + else: + cobj = callback(obj, edge) + if edge.parent is not None and cobj is RemoveNode: + to_delete.append(edge) + continue + children = get_children(cobj) + if children: + container = container_factory(cobj) + edge.node = container + for key, value in children: + dq.append(_Edge(edge, key, value)) + cobj = container + cache[obj_id] = (obj, cobj) + obj = cobj + if result is None: + result = obj + if edge.parent is not None: + setitem(edge.parent.node, edge.key, obj) + _delete_items(to_delete, delitem) + return result + + +def depth_first_modify_and_copy(d, callback, get_children=None, setitem=None, delitem=None, container_factory=None): + get_children = get_children or _default_get_children + cache = {} + to_delete = collections.deque() + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + container_factory = container_factory or _default_container_factory + result = None + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + obj = edge.node + obj_id = id(obj) + if obj_id in cache: + new_obj = cache[obj_id][1] + else: + new_obj = callback(obj, edge) + if edge.parent is not None and new_obj is RemoveNode: + to_delete.append(edge) + continue + children = get_children(new_obj) + if children: + container = container_factory(new_obj) + edge.node = container + for key, value in children: + dq.append(_Edge(edge, key, value)) + new_obj = container + cache[obj_id] = (obj, new_obj) + if result is None: + result = new_obj + if edge.parent is not None: + setitem(edge.parent.node, edge.key, new_obj) + _delete_items(to_delete, delitem) + return result + + +def leaf_first_modify_and_copy(d, callback, get_children=None, setitem=None, 
delitem=None, container_factory=None): + get_children = get_children or _default_get_children + pending = {} + cache = {} + to_delete = collections.deque() + setitem = setitem or _default_setitem + delitem = delitem or _default_delitem + container_factory = container_factory or _default_container_factory + result = None + dq = collections.deque() + dq.append(_Edge(None, None, d)) + while dq: + edge = dq.pop() + if isinstance(edge, _ShowValue): + obj_id = edge.obj_id + edge = edge.obj + obj = edge.node + if obj is result: + result = None + if obj_id not in cache: + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + else: + new_obj = cache[obj_id][1] + if result is None: + result = new_obj + if edge.parent is not None: + if new_obj is RemoveNode: + to_delete.append(edge) + else: + setitem(edge.parent.node, edge.key, new_obj) + if obj_id in pending: + for edge in pending[obj_id]: + if new_obj is RemoveNode: + to_delete.append(edge) + else: + setitem(edge.parent.node, edge.key, new_obj) + del pending[obj_id] + continue + obj = edge.node + obj_id = id(obj) + if obj_id in cache: + new_obj = cache[obj_id][1] + else: + children = get_children(obj) + if children: + container = container_factory(obj) + pending[obj_id] = [] + if result is None: + result = container + edge.node = container + dq.append(_ShowValue(edge, obj_id)) + for key, value in children: + if id(value) in pending: + pending[id(value)].append(_Edge(edge, key, value)) + else: + dq.append(_Edge(edge, key, value)) + continue + new_obj = callback(obj, edge) + cache[obj_id] = (obj, new_obj) + if result is None: + result = new_obj + if edge.parent is not None: + if new_obj is RemoveNode: + to_delete.append(edge) + else: + setitem(edge.parent.node, edge.key, new_obj) + _delete_items(to_delete, delitem) + return result diff --git a/asdf/_node_info.py b/asdf/_node_info.py index e61b22241..6b1acc16f 100644 --- a/asdf/_node_info.py +++ b/asdf/_node_info.py @@ -2,7 +2,7 @@ from collections import namedtuple 
from .schema import load_schema -from .treeutil import get_children +from .treeutil import _get_children def _filter_tree(info, filters): @@ -290,7 +290,7 @@ def from_root_node(cls, key, root_identifier, root_node, schema=None, refresh_ex if parent is None: info.schema = schema - for child_identifier, child_node in get_children(t_node): + for child_identifier, child_node in _get_children(t_node): next_nodes.append((info, child_identifier, child_node)) if len(next_nodes) == 0: diff --git a/asdf/_tests/_regtests/test_1715.py b/asdf/_tests/_regtests/test_1715.py new file mode 100644 index 000000000..18c140eaa --- /dev/null +++ b/asdf/_tests/_regtests/test_1715.py @@ -0,0 +1,28 @@ +import pytest + +import asdf + + +def test_id_in_tree_breaks_ref(tmp_path): + """ + a dict containing id will break contained References + + https://github.com/asdf-format/asdf/issues/1715 + """ + external_fn = tmp_path / "external.asdf" + + external_tree = {"thing": 42} + + asdf.AsdfFile(external_tree).write_to(external_fn) + + main_fn = tmp_path / "main.asdf" + + af = asdf.AsdfFile({}) + af["id"] = "bogus" + af["myref"] = {"$ref": "external.asdf#/thing"} + af.write_to(main_fn) + + with pytest.warns(asdf.exceptions.AsdfDeprecationWarning, match="find_references"): + with asdf.open(main_fn) as af: + af.resolve_references() + assert af["myref"] == 42 diff --git a/asdf/_tests/test_array_blocks.py b/asdf/_tests/test_array_blocks.py index b6163ce08..1f79d7a2c 100644 --- a/asdf/_tests/test_array_blocks.py +++ b/asdf/_tests/test_array_blocks.py @@ -589,10 +589,11 @@ def test_block_index(): assert ff2._blocks.blocks[99].loaded # Force the loading of one array - ff2.tree["arrays"][50] * 2 + arr = ff2.tree["arrays"][50] + arr * 2 for i in range(2, 99): - if i == 50: + if i == arr._source: assert ff2._blocks.blocks[i].loaded else: assert not ff2._blocks.blocks[i].loaded diff --git a/asdf/_tests/test_block_converter.py b/asdf/_tests/test_block_converter.py index 462f7c363..8e5ce688c 100644 --- 
a/asdf/_tests/test_block_converter.py +++ b/asdf/_tests/test_block_converter.py @@ -1,4 +1,5 @@ import contextlib +from io import BytesIO import numpy as np from numpy.testing import assert_array_equal @@ -125,8 +126,12 @@ def test_block_data_callback_converter(tmp_path): # id(arr) would change every time a = BlockDataCallback(lambda: np.zeros(3, dtype="uint8")) - b = helpers.roundtrip_object(a) - assert_array_equal(a.data, b.data) + bs = BytesIO() + af = asdf.AsdfFile({"obj": a}) + af.write_to(bs) + bs.seek(0) + with asdf.open(bs, lazy_load=False, memmap=False) as af: + assert_array_equal(a.data, af["obj"].data) # make a tree without the BlockData instance to avoid # the initial validate which will trigger block allocation diff --git a/asdf/_tests/test_deprecated.py b/asdf/_tests/test_deprecated.py index 0df6b1f7e..608bb4da4 100644 --- a/asdf/_tests/test_deprecated.py +++ b/asdf/_tests/test_deprecated.py @@ -67,3 +67,21 @@ def test_find_references_during_open_deprecation(tmp_path): with pytest.warns(AsdfDeprecationWarning, match="find_references during open"): with asdf.open(fn) as af: pass + + +def test_get_children_deprecation(): + with pytest.warns(AsdfDeprecationWarning, match="get_children is deprecated"): + asdf.treeutil.get_children({}) + + +def test_is_container_deprecation(): + with pytest.warns(AsdfDeprecationWarning, match="is_container is deprecated"): + asdf.treeutil.is_container({}) + + +def test_json_id_deprecation(): + def callback(obj, json_id): + return obj + + with pytest.warns(AsdfDeprecationWarning, match="the json_id callback argument is deprecated"): + asdf.treeutil.walk_and_modify({"a": 1}, callback) diff --git a/asdf/_tests/test_itertree.py b/asdf/_tests/test_itertree.py new file mode 100644 index 000000000..26662c1dc --- /dev/null +++ b/asdf/_tests/test_itertree.py @@ -0,0 +1,533 @@ +import copy + +import pytest + +from asdf import _itertree + + +def _traversal_to_generator(tree, traversal): + if "modify" not in traversal.__name__: + 
return traversal(tree) + + def make_generator(tree): + values = [] + + def callback(obj, edge): + values.append((obj, edge)) + return obj + + traversal(tree, callback) + yield from values + + return make_generator(tree) + + +@pytest.mark.parametrize( + "traversal", [_itertree.breadth_first, _itertree.breadth_first_modify, _itertree.breadth_first_modify_and_copy] +) +def test_breadth_first_traversal(traversal): + tree = { + "a": { + "b": [1, 2, {"c": 3}], + "d": 4, + }, + "e": [5, 6, [7, 8, {"f": 9}]], + } + # It is ok for results to come in any order as long as + # all nodes closer to the root come before any more distant + # node. Track those here by ordering expected results by 'layer' + expected_results = [ + [ + tree, + ], + [tree["a"], tree["e"]], + [tree["a"]["b"], tree["a"]["d"], tree["e"][0], tree["e"][1], tree["e"][2]], + [tree["a"]["b"][0], tree["a"]["b"][1], tree["a"]["b"][2], tree["e"][2][0], tree["e"][2][1], tree["e"][2][2]], + [tree["a"]["b"][2]["c"], tree["e"][2][2]["f"]], + ] + + expected = [] + + for node, edge in _traversal_to_generator(tree, traversal): + if not len(expected): + expected = expected_results.pop(0) + assert node in expected + expected.remove(node) + assert not expected_results + + +@pytest.mark.parametrize( + "traversal", [_itertree.breadth_first, _itertree.breadth_first_modify, _itertree.breadth_first_modify_and_copy] +) +def test_recursive_breadth_first_traversal(traversal): + tree = { + "a": {}, + "b": {}, + } + tree["a"]["b"] = tree["b"] + tree["b"]["a"] = tree["a"] + + expected_results = [ + [ + tree, + ], + [tree["a"], tree["b"]], + ] + + expected = [] + for node, edge in _traversal_to_generator(tree, traversal): + if not len(expected): + expected = expected_results.pop(0) + assert node in expected + expected.remove(node) + assert not expected_results + + +@pytest.mark.parametrize( + "traversal", [_itertree.leaf_first, _itertree.leaf_first_modify, _itertree.leaf_first_modify_and_copy] +) +def 
test_leaf_first_traversal(traversal): + tree = { + "a": { + "b": [1, 2, {"c": 3}], + "d": 4, + }, + "e": [5, 6, [7, 8, {"f": 9}]], + } + seen_keys = set() + reverse_paths = { + ("e", 2, 2, "f"): [("e", 2, 2), ("e", 2, 1), ("e", 2, 0)], + ("e", 2, 2): [("e", 0), ("e", 1), ("e", 2)], + ("e", 2, 1): [("e", 0), ("e", 1), ("e", 2)], + ("e", 2, 0): [("e", 0), ("e", 1), ("e", 2)], + ("e", 2): [("e",), ("a",)], + ("e", 1): [("e",), ("a",)], + ("e", 0): [("e",), ("a",)], + ("e",): [()], + ("a", "b", 2, "c"): [("a", "b", 0), ("a", "b", 1), ("a", "b", 2)], + ("a", "b", 2): [("a", "b"), ("a", "d")], + ("a", "b", 1): [("a", "b"), ("a", "d")], + ("a", "b", 0): [("a", "b"), ("a", "d")], + ("a", "b"): [("a",), ("e",)], + ("a", "d"): [("a",), ("e",)], + ("a",): [()], + (): [], + } + expected = { + ("e", 2, 2, "f"), + ("a", "b", 2, "c"), + ("a", "d"), + } + for node, edge in _traversal_to_generator(tree, traversal): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + assert obj == node + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in reverse_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert not expected + + +@pytest.mark.parametrize( + "traversal", [_itertree.leaf_first, _itertree.leaf_first_modify, _itertree.leaf_first_modify_and_copy] +) +def test_recursive_leaf_first_traversal(traversal): + tree = { + "a": {}, + "b": {}, + } + tree["a"]["b"] = tree["b"] + tree["b"]["a"] = tree["a"] + + seen_keys = set() + reverse_paths = { + ("a", "b"): [("a",), ("b",)], + ("b", "a"): [("a",), ("b",)], + ("a",): [()], + ("b",): [()], + (): [], + } + expected = { + ("a", "b"), + ("b", "a"), + } + visits = [] + for node, edge in _traversal_to_generator(tree, traversal): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + visits.append((obj, edge)) + + # updated expected + seen_keys.add(keys) + 
expected.remove(keys) + for new_keys in reverse_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert len(visits) == 3 + + +@pytest.mark.parametrize( + "traversal", [_itertree.depth_first, _itertree.depth_first_modify, _itertree.depth_first_modify_and_copy] +) +def test_depth_first_traversal(traversal): + tree = { + "a": { + "b": [1, 2, {"c": 3}], + "d": 4, + }, + "e": [5, 6, [7, 8, {"f": 9}]], + } + forward_paths = { + (): [("a",), ("e",)], + ("a",): [("a", "b"), ("a", "d")], + ("a", "b"): [("a", "b", 0), ("a", "b", 1), ("a", "b", 2)], + ("a", "b", 0): [], + ("a", "b", 1): [], + ("a", "b", 2): [("a", "b", 2, "c")], + ("a", "b", 2, "c"): [], + ("a", "d"): [], + ("e",): [("e", 0), ("e", 1), ("e", 2)], + ("e", 0): [], + ("e", 1): [], + ("e", 2): [("e", 2, 0), ("e", 2, 1), ("e", 2, 2)], + ("e", 2, 0): [], + ("e", 2, 1): [], + ("e", 2, 2): [("e", 2, 2, "f")], + ("e", 2, 2, "f"): [], + } + expected = {()} + seen_keys = set() + + for node, edge in _traversal_to_generator(tree, traversal): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: + obj = obj[key] + assert obj == node + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in forward_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert not expected + + +@pytest.mark.parametrize( + "traversal", [_itertree.depth_first, _itertree.depth_first_modify, _itertree.depth_first_modify_and_copy] +) +def test_recursive_depth_first_traversal(traversal): + tree = { + "a": {}, + "b": {}, + } + tree["a"]["b"] = tree["b"] + tree["b"]["a"] = tree["a"] + + seen_keys = set() + forward_paths = { + (): [("a",), ("b",)], + ("a",): [("a", "b")], + ("b",): [("b", "a")], + ("a", "b"): [], + ("b", "a"): [], + } + expected = { + (), + } + visits = [] + for node, edge in _traversal_to_generator(tree, traversal): + keys = _itertree.edge_to_keys(edge) + assert keys in expected + obj = tree + for key in keys: 
+ obj = obj[key] + visits.append((node, edge)) + + # updated expected + seen_keys.add(keys) + expected.remove(keys) + for new_keys in forward_paths[keys]: + if new_keys in seen_keys: + continue + expected.add(new_keys) + assert len(visits) == 3 + + +def test_breadth_first_modify(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + return [1, 2, 3] + if isinstance(obj, dict): + assert "b" not in obj + return obj + + _itertree.breadth_first_modify(tree, callback) + assert tree["a"] == [1, 2, 3] + + +def test_depth_first_modify(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, dict) and "d" in obj: + obj["d"] = [42] + if isinstance(obj, list) and 42 in obj: + assert len(obj) == 1 + return obj + + _itertree.depth_first_modify(tree, callback) + assert tree["c"]["d"] == [42] + + +def test_leaf_first_modify(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + _itertree.leaf_first_modify(tree, callback) + assert tree["a"] == [1, 2, 42] + + +def test_breadth_first_modify_and_copy(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 not in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + # copy the tree to make sure it's not modified + copied_tree = copy.deepcopy(tree) + result = _itertree.breadth_first_modify_and_copy(copied_tree, callback) + assert result["a"] == [1, 2, 42] + assert copied_tree == tree + + +def test_depth_first_modify_and_copy(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 
not in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + # copy the tree to make sure it's not modified + copied_tree = copy.deepcopy(tree) + result = _itertree.depth_first_modify_and_copy(copied_tree, callback) + assert result["a"] == [1, 2, 42] + assert copied_tree == tree + + +def test_leaf_first_modify_and_copy(): + tree = { + "a": [1, 2, {"b": 3}], + "c": { + "d": [4, 5, 6], + }, + } + + def callback(obj, keys): + if isinstance(obj, list) and 1 in obj: + assert 42 in obj + if isinstance(obj, dict) and "b" in obj: + return 42 + return obj + + # copy the tree to make sure it's not modified + copied_tree = copy.deepcopy(tree) + result = _itertree.leaf_first_modify_and_copy(copied_tree, callback) + assert result["a"] == [1, 2, 42] + assert copied_tree == tree + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_node_removal(traversal): + tree = { + "a": [1, 2, 3], + "b": 4, + } + + def callback(obj, edge): + if obj in (1, 3, 4): + return _itertree.RemoveNode + return obj + + result = traversal(tree, callback) + if result is not None: # this is a copy + original = tree + else: + original = None + result = tree + assert result["a"] == [ + 2, + ] + assert "b" not in result + if original is not None: + assert original["a"] == [1, 2, 3] + assert original["b"] == 4 + assert set(original.keys()) == {"a", "b"} + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_key_order(traversal): + """ + All traversal and modification functions should preserve + the order of keys in a dictionary 
+ """ + tree = {} + tree["a"] = [1, 2] + tree["b"] = [3, 4] + tree["c"] = {} + tree["c"]["d"] = [5, 6] + tree["c"]["e"] = [7, 8] + + result = traversal(tree, lambda obj, edge: obj) + if result is None: + result = tree + + assert list(result.keys()) == ["a", "b", "c"] + assert result["a"] == [1, 2] + assert result["b"] == [3, 4] + assert list(result["c"].keys()) == ["d", "e"] + assert result["c"]["d"] == [5, 6] + assert result["c"]["e"] == [7, 8] + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_cache_callback(traversal): + class Foo: + pass + + obj = Foo() + obj.count = 0 + + tree = {} + tree["a"] = obj + tree["b"] = obj + tree["c"] = {"d": obj} + + def callback(obj, edge): + if isinstance(obj, Foo): + obj.count += 1 + return obj + + result = traversal(tree, callback) + if result is None: + result = tree + + assert result["a"].count == 1 + assert result["b"].count == 1 + assert result["c"]["d"].count == 1 + + +@pytest.mark.parametrize( + "traversal", + [ + _itertree.breadth_first_modify, + _itertree.depth_first_modify, + _itertree.leaf_first_modify, + _itertree.breadth_first_modify_and_copy, + _itertree.depth_first_modify_and_copy, + _itertree.leaf_first_modify_and_copy, + ], +) +def test_recursive_object(traversal): + tree = {} + tree["a"] = {"count": 0} + tree["b"] = {"a": tree["a"]} + tree["a"]["b"] = tree["b"] + + def callback(obj, edge): + if isinstance(obj, dict) and "count" in obj: + obj["count"] += 1 + return obj + + result = traversal(tree, callback) + if result is None: + result = tree + + assert result["a"]["count"] == 1 + assert result["b"]["a"]["count"] == 1 + assert result["a"] is result["b"]["a"] diff --git a/asdf/_tests/test_schema.py b/asdf/_tests/test_schema.py index 82cea9238..996ed8fbe 100644 --- 
a/asdf/_tests/test_schema.py +++ b/asdf/_tests/test_schema.py @@ -1216,6 +1216,8 @@ def test_validator_visit_repeat_nodes(): node = asdf.tags.core.Software(name="Minesweeper") tree = yamlutil.custom_tree_to_tagged_tree({"node": node, "other_node": node, "nested": {"node": node}}, ctx) + assert tree["node"] is tree["other_node"] + assert tree["node"] is tree["nested"]["node"] visited_nodes = [] def _test_validator(validator, value, instance, schema): diff --git a/asdf/_tests/test_treeutil.py b/asdf/_tests/test_treeutil.py index 5e9139ee8..147a12da8 100644 --- a/asdf/_tests/test_treeutil.py +++ b/asdf/_tests/test_treeutil.py @@ -3,27 +3,27 @@ def test_get_children(): parent = ["foo", "bar"] - assert treeutil.get_children(parent) == [(0, "foo"), (1, "bar")] + assert treeutil._get_children(parent) == [(0, "foo"), (1, "bar")] parent = ("foo", "bar") - assert treeutil.get_children(parent) == [(0, "foo"), (1, "bar")] + assert treeutil._get_children(parent) == [(0, "foo"), (1, "bar")] parent = {"foo": "bar", "ding": "dong"} - assert sorted(treeutil.get_children(parent)) == sorted([("foo", "bar"), ("ding", "dong")]) + assert sorted(treeutil._get_children(parent)) == sorted([("foo", "bar"), ("ding", "dong")]) parent = "foo" - assert treeutil.get_children(parent) == [] + assert treeutil._get_children(parent) == [] parent = None - assert treeutil.get_children(parent) == [] + assert treeutil._get_children(parent) == [] def test_is_container(): for value in [[], {}, ()]: - assert treeutil.is_container(value) is True + assert treeutil._is_container(value) is True for value in ["foo", 12, 13.9827]: - assert treeutil.is_container(value) is False + assert treeutil._is_container(value) is False def test_walk_and_modify_shared_references(): diff --git a/asdf/reference.py b/asdf/reference.py index b408653e0..73ace1998 100644 --- a/asdf/reference.py +++ b/asdf/reference.py @@ -112,11 +112,11 @@ def find_references(tree, ctx, _warning_msg=False): `Reference` objects. 
""" - def do_find(tree, json_id): + def do_find(tree): if isinstance(tree, dict) and "$ref" in tree: if _warning_msg: warnings.warn(_warning_msg, AsdfDeprecationWarning) - return Reference(tree["$ref"], json_id, asdffile=ctx) + return Reference(tree["$ref"], asdffile=ctx) return tree return treeutil.walk_and_modify(tree, do_find, ignore_implicit_conversion=ctx._ignore_implicit_conversion) diff --git a/asdf/schema.py b/asdf/schema.py index c09bad792..061d0e8f4 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -14,7 +14,7 @@ from asdf._jsonschema import validators as mvalidators from asdf._jsonschema.exceptions import RefResolutionError, ValidationError -from . import constants, generic_io, reference, tagged, treeutil, util, versioning, yamlutil +from . import _itertree, constants, generic_io, reference, tagged, treeutil, util, versioning, yamlutil from .config import get_config from .exceptions import AsdfDeprecationWarning, AsdfWarning from .util import _patched_urllib_parse @@ -471,11 +471,10 @@ def _load_schema_cached(url, resolver, resolve_references): if resolve_references: - def resolve_refs(node, json_id): - if json_id is None: - json_id = url - + def resolve_refs(node, edge): if isinstance(node, dict) and "$ref" in node: + json_id = treeutil._get_json_id(schema, edge) or url + suburl_base, suburl_fragment = _safe_resolve(resolver, json_id, node["$ref"]) if suburl_base == url or suburl_base == schema.get("id"): @@ -485,10 +484,13 @@ def resolve_refs(node, json_id): subschema = load_schema(suburl_base, resolver, True) return reference.resolve_fragment(subschema, suburl_fragment) - return node - schema = treeutil.walk_and_modify(schema, resolve_refs) + # We need to copy here so that we don't end up with a recursive tree. + # When full resolution results in a recursive tree the returned recursive + # schema would cause check_schema to fail. This means some local $ref + # instances are not resolved. 
+ schema = _itertree.leaf_first_modify_and_copy(schema, resolve_refs) return schema diff --git a/asdf/search.py b/asdf/search.py index 08f0950a6..531183d45 100644 --- a/asdf/search.py +++ b/asdf/search.py @@ -6,14 +6,21 @@ import re import typing +from . import _itertree from ._display import DEFAULT_MAX_COLS, DEFAULT_MAX_ROWS, DEFAULT_SHOW_VALUES, format_faint, format_italic, render_tree from ._node_info import NodeSchemaInfo, collect_schema_info -from .treeutil import get_children, is_container +from .treeutil import _is_container from .util import NotSet __all__ = ["AsdfSearchResult"] +def _get_children(obj): + if hasattr(obj, "__asdf_traverse__"): + obj = obj.__asdf_traverse__() + return _itertree._default_get_children(obj) + + class AsdfSearchResult: """ Result of a call to AsdfFile.search. @@ -184,7 +191,7 @@ def _filter(node, identifier): return False if isinstance(value, typing.Pattern): - if is_container(node): + if _is_container(node): # The string representation of a container object tends to # include the child object values, but that's probably not # what searchers want. 
@@ -222,11 +229,9 @@ def replace(self, value): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): - results.append((identifiers[-1], parent)) - - _walk_tree_breadth_first(self._identifiers, self._node, _callback) + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): + results.append((edge.key, edge.parent.node)) for identifier, parent in results: parent[identifier] = value @@ -284,11 +289,10 @@ def nodes(self): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): results.append(node) - _walk_tree_breadth_first(self._identifiers, self._node, _callback) return results @property @@ -303,11 +307,10 @@ def paths(self): """ results = [] - def _callback(identifiers, parent, node, children): - if all(f(node, identifiers[-1]) for f in self._filters): - results.append(_build_path(identifiers)) + for node, edge in _itertree.breadth_first(self._node, get_children=_get_children): + if all(f(node, edge.key) for f in self._filters): + results.append(_build_path(self._identifiers + list(_itertree.edge_to_keys(edge)))) - _walk_tree_breadth_first(self._identifiers, self._node, _callback) return results def __repr__(self): @@ -369,31 +372,6 @@ def __getitem__(self, key): ) -def _walk_tree_breadth_first(root_identifiers, root_node, callback): - """ - Walk the tree in breadth-first order (useful for prioritizing - lower-depth nodes). 
- """ - current_nodes = [(root_identifiers, None, root_node)] - seen = set() - while True: - next_nodes = [] - - for identifiers, parent, node in current_nodes: - if (isinstance(node, (dict, list, tuple)) or NodeSchemaInfo.traversable(node)) and id(node) in seen: - continue - tnode = node.__asdf_traverse__() if NodeSchemaInfo.traversable(node) else node - children = get_children(tnode) - callback(identifiers, parent, node, [c for _, c in children]) - next_nodes.extend([([*identifiers, i], node, c) for i, c in children]) - seen.add(id(node)) - - if len(next_nodes) == 0: - break - - current_nodes = next_nodes - - def _build_path(identifiers): """ Generate the Python code needed to extract the identified node. diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 27b2989cb..5e0ef153e 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -2,12 +2,10 @@ Utility functions for managing tree-like data structures. """ -import types import warnings -from contextlib import contextmanager -from . import tagged -from .exceptions import AsdfWarning +from . import _itertree, tagged +from .exceptions import AsdfDeprecationWarning, AsdfWarning __all__ = ["walk", "iter_tree", "walk_and_modify", "get_children", "is_container", "PendingValue", "RemoveNode"] @@ -57,139 +55,8 @@ def iter_tree(top): tree : object The modified tree. """ - seen = set() - - def recurse(tree): - tree_id = id(tree) - - if tree_id in seen: - return - - if isinstance(tree, (list, tuple)): - seen.add(tree_id) - for val in tree: - yield from recurse(val) - seen.remove(tree_id) - elif isinstance(tree, dict): - seen.add(tree_id) - for val in tree.values(): - yield from recurse(val) - seen.remove(tree_id) - - yield tree - - return recurse(top) - - -class _TreeModificationContext: - """ - Context of a call to walk_and_modify, which includes a map - of already modified nodes, a list of generators to drain - before exiting the call, and a set of node object ids that - are currently pending modification. 
- - Instances of this class are context managers that track - how many times they have been entered, and only drain - generators and reset themselves when exiting the outermost - context. They are also collections that map unmodified - nodes to the corresponding modified result. - """ - - def __init__(self): - self._map = {} - self._generators = [] - self._depth = 0 - self._pending = set() - - def add_generator(self, generator): - """ - Add a generator that should be drained before exiting - the outermost call to walk_and_modify. - """ - self._generators.append(generator) - - def is_pending(self, node): - """ - Return True if the node is already being modified. - This will not be the case unless the node contains a - reference to itself somewhere among its descendents. - """ - return id(node) in self._pending - - @contextmanager - def pending(self, node): - """ - Context manager that marks a node as pending for the - duration of the context. - """ - if id(node) in self._pending: - msg = ( - "Unhandled cycle in tree. This is possibly a bug " - "in extension code, which should be yielding " - "nodes that may contain reference cycles." - ) - raise RuntimeError(msg) - - self._pending.add(id(node)) - try: - yield self - finally: - self._pending.remove(id(node)) - - def __enter__(self): - self._depth += 1 - return self - - def __exit__(self, exc_type, exc_value, traceback): - self._depth -= 1 - - if self._depth == 0: - # If we're back to 0 depth, then we're exiting - # the outermost context, so it's time to drain - # the generators and reset this object for next - # time. - if exc_type is None: - self._drain_generators() - self._generators = [] - self._map = {} - self._pending = set() - - def _drain_generators(self): - """ - Drain each generator we've accumulated during this - call to walk_and_modify. - """ - # Generator code may add yet more generators - # to the list, so we need to loop until the - # list is empty. 
- while len(self._generators) > 0: - generators = self._generators - self._generators = [] - for generator in generators: - for _ in generator: - # Subsequent yields of the generator should - # always return the same value. What we're - # really doing here is executing the generator's - # remaining code, to further modify that first - # yielded object. - pass - - def __contains__(self, node): - return id(node) in self._map - - def __getitem__(self, node): - return self._map[id(node)][1] - - def __setitem__(self, node, result): - if id(node) in self._map: - # This indicates that an already defined - # modified node is being replaced, which is an - # error because it breaks references within the - # tree. - msg = "Node already has an associated result" - raise RuntimeError(msg) - - self._map[id(node)] = (node, result) + for node, edge in _itertree.depth_first(top): + yield node class _PendingValue: @@ -206,23 +73,48 @@ def __repr__(self): PendingValue = _PendingValue() -class _RemoveNode: - """ - Class of the RemoveNode singleton instance. This instance is used - as a signal for `asdf.treeutil.walk_and_modify` to remove the - node received by the callback. 
- """ - - def __repr__(self): - return "RemoveNode" - - -RemoveNode = _RemoveNode() - - -def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True, _context=None): +RemoveNode = _itertree.RemoveNode + + +def _get_json_id(top, edge): + keys = [] + while edge and edge.key is not None: + keys.append(edge.key) + edge = edge.parent + node = top + if hasattr(node, "get") and isinstance(node.get("id", None), str): + json_id = node["id"] + else: + json_id = None + for key in keys[::-1]: + if hasattr(node, "get") and isinstance(node.get("id", None), str): + json_id = node["id"] + node = node[key] + return json_id + + +def _container_factory(obj): + if isinstance(obj, tagged.TaggedDict): + result = tagged.TaggedDict({k: None for k in obj}) + result._tag = obj._tag + elif isinstance(obj, tagged.TaggedList): + result = tagged.TaggedList([None] * len(obj)) + result._tag = obj._tag + elif isinstance(obj, dict): + result = obj.__class__({k: None for k in obj}) + elif isinstance(obj, list): + result = obj.__class__([None] * len(obj)) + elif isinstance(obj, tuple): + result = [None] * len(obj) + else: + raise NotImplementedError() + return result + + +def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=True): """Modify a tree by walking it with a callback function. It also has - the effect of doing a deep copy. + the effect of doing a copy. Only "containers" (dict, list, etc) will be + copied, all "leaf" nodes will not be copied. Parameters ---------- @@ -234,14 +126,16 @@ def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=T one or two arguments: - an instance from the tree - - a json id (optional) + - a json id (optional) DEPRECATED It may return a different instance in order to modify the tree. If the singleton instance `~asdf.treeutil._RemoveNode` is returned, the node will be removed from the tree. - The json id is the context under which any relative URLs - should be resolved. 
It may be `None` if no ids are in the file + The json id optional argument is deprecated. This function + will no longer track ids. The json id is the context under which + any relative URLs should be resolved. It may be `None` if no + ids are in the file The tree is traversed depth-first, with order specified by the ``postorder`` argument. @@ -265,162 +159,37 @@ def walk_and_modify(top, callback, ignore_implicit_conversion=False, postorder=T The modified tree. """ - callback_arity = callback.__code__.co_argcount - if callback_arity < 1 or callback_arity > 2: - msg = "Expected callback to accept one or two arguments" - raise ValueError(msg) - - def _handle_generator(result): - # If the result is a generator, generate one value to - # extract the true result, then register the generator - # to be drained later. - if isinstance(result, types.GeneratorType): - generator = result - result = next(generator) - _context.add_generator(generator) - - return result - - def _handle_callback(node, json_id): - result = callback(node) if callback_arity == 1 else callback(node, json_id) - - return _handle_generator(result) - - def _handle_mapping(node, json_id): - result = node.__class__() - if isinstance(node, tagged.Tagged): - result._tag = node._tag - - pending_items = {} - for key, value in node.items(): - if _context.is_pending(value): - # The child node is pending modification, which means - # it must be its own ancestor. Assign the special - # PendingValue instance for now, and note that we'll - # need to fill in the real value later. - pending_items[key] = value - result[key] = PendingValue - - elif (val := _recurse(value, json_id)) is not RemoveNode: - result[key] = val - - yield result - - if len(pending_items) > 0: - # Now that we've yielded, the pending children should - # be available. 
- for key, value in pending_items.items(): - if (val := _recurse(value, json_id)) is not RemoveNode: - result[key] = val - else: - # The callback may have decided to delete - # this node after all. - del result[key] - - def _handle_mutable_sequence(node, json_id): - result = node.__class__() - if isinstance(node, tagged.Tagged): - result._tag = node._tag - - pending_items = {} - for i, value in enumerate(node): - if _context.is_pending(value): - # The child node is pending modification, which means - # it must be its own ancestor. Assign the special - # PendingValue instance for now, and note that we'll - # need to fill in the real value later. - pending_items[i] = value - result.append(PendingValue) - else: - result.append(_recurse(value, json_id)) - - yield result - - for i, value in pending_items.items(): - # Now that we've yielded, the pending children should - # be available. - result[i] = _recurse(value, json_id) - - def _handle_immutable_sequence(node, json_id): - # Immutable sequences containing themselves are impossible - # to construct (well, maybe possible in a C extension, but - # we're not going to worry about that), so we don't need - # to yield here. - contents = [_recurse(value, json_id) for value in node] - - try: - result = node.__class__(contents) - if isinstance(node, tagged.Tagged): - result._tag = node._tag - except TypeError: - # The derived class signature is different, so simply store the - # list representing the contents. Currently this is primarily - # intended to handle namedtuple and NamedTuple instances. 
- if not ignore_implicit_conversion: - warnings.warn(f"Failed to serialize instance of {type(node)}, converting to list instead", AsdfWarning) - result = contents - - return result - - def _handle_children(node, json_id): - if isinstance(node, dict): - result = _handle_mapping(node, json_id) - elif isinstance(node, tuple): - result = _handle_immutable_sequence(node, json_id) - elif isinstance(node, list): - result = _handle_mutable_sequence(node, json_id) - else: - result = node - - return _handle_generator(result) - - def _recurse(node, json_id=None): - if node in _context: - # The node's modified result has already been - # created, all we need to do is return it. This - # occurs when the tree contains multiple references - # to the same object id. - return _context[node] - - # Inform the context that we're going to start modifying - # this node. - with _context.pending(node): - # Take note of the "id" field, in case we're modifying - # a schema and need to know the namespace for resolving - # URIs. Ignore an id that is not a string, since it may - # be an object defining an id property and not an id - # itself (this is common in metaschemas). - if isinstance(node, dict) and "id" in node and isinstance(node["id"], str): - json_id = node["id"] - - if postorder: - # If this is a postorder modification, invoke the - # callback on this node's children first. - result = _handle_children(node, json_id) - result = _handle_callback(result, json_id) - else: - # Otherwise, invoke the callback on the node first, - # then its children. - result = _handle_callback(node, json_id) - result = _handle_children(result, json_id) - - # Store the result in the context, in case there are - # additional references to the same node elsewhere in - # the tree. 
- _context[node] = result - - return result - - if _context is None: - _context = _TreeModificationContext() - - with _context: - return _recurse(top) - # Generators will be drained here, if this is the outermost - # call to walk_and_modify. + if postorder: + modify = _itertree.leaf_first_modify_and_copy + else: + modify = _itertree.depth_first_modify_and_copy -def get_children(node): + if callback.__code__.co_argcount == 2: + warnings.warn("the json_id callback argument is deprecated", AsdfDeprecationWarning) + + def wrapped_callback(obj, edge): + json_id = _get_json_id(top, edge) + return callback(obj, json_id) + + else: + + def wrapped_callback(obj, edge): + return callback(obj) + + if ignore_implicit_conversion: + container_factory = _container_factory + else: + + def container_factory(obj): + if isinstance(obj, tuple) and type(obj) != tuple: + warnings.warn(f"Failed to serialize instance of {type(obj)}, converting to list instead", AsdfWarning) + return _container_factory(obj) + + return modify(top, wrapped_callback, container_factory=container_factory) + + +def _get_children(node): """ Retrieve the children (and their dict keys or list/tuple indices) of an ASDF tree node. @@ -446,7 +215,15 @@ def get_children(node): return [] -def is_container(node): +def get_children(node): + warnings.warn("asdf.treeutil.get_children is deprecated", AsdfDeprecationWarning) + return _get_children(node) + + +get_children.__doc__ = _get_children.__doc__ + + +def _is_container(node): """ Determine if an ASDF tree node is an instance of a "container" type (i.e., value may contain child nodes). 
@@ -462,3 +239,11 @@ def is_container(node):
     True if node is a container, False otherwise
     """
     return isinstance(node, (dict, list, tuple))
+
+
+def is_container(node):
+    warnings.warn("asdf.treeutil.is_container is deprecated", AsdfDeprecationWarning)
+    return _is_container(node)
+
+
+is_container.__doc__ = _is_container.__doc__
diff --git a/asdf/yamlutil.py b/asdf/yamlutil.py
index e4b1d5049..2262de692 100644
--- a/asdf/yamlutil.py
+++ b/asdf/yamlutil.py
@@ -1,5 +1,5 @@
 import warnings
-from collections import OrderedDict
+from collections import OrderedDict, deque
 from types import GeneratorType
 
 import numpy as np
@@ -221,6 +221,8 @@ def custom_tree_to_tagged_tree(tree, ctx, _serialization_context=None):
 
     extension_manager = _serialization_context.extension_manager
 
+    generators = deque()
+
     def _convert_obj(obj, converter):
         tag = converter.select_tag(obj, _serialization_context)
         # if select_tag returns None, converter.to_yaml_tree should return a new
@@ -233,8 +235,7 @@ def _convert_obj(obj, converter):
             converter = extension_manager.get_converter_for_type(type(obj))
         except KeyError:
             # no converter supports this type, return it as-is
-            yield obj
-            return
+            return obj
         if converter in converters_used:
             msg = "Conversion cycle detected"
             raise TypeError(msg)
@@ -244,10 +245,8 @@ def _convert_obj(obj, converter):
         _serialization_context.assign_blocks()
 
         if isinstance(node, GeneratorType):
-            generator = node
-            node = next(generator)
-        else:
-            generator = None
+            generators.append(node)
+            node = next(node)
 
         if isinstance(node, dict):
             tagged_node = tagged.TaggedDict(node, tag)
@@ -262,9 +261,7 @@ def _convert_obj(obj, converter):
 
         _serialization_context._mark_extension_used(converter.extension)
 
-        yield tagged_node
-        if generator is not None:
-            yield from generator
+        return tagged_node
 
     cfg = config.get_config()
     convert_ndarray_subclasses = cfg.convert_unknown_ndarray_subclasses
@@ -292,15 +289,18 @@ def _walker(obj):
             converters_cache[typ] = lambda obj: obj
             return obj
 
-    return treeutil.walk_and_modify(
+    new_tree = treeutil.walk_and_modify(
         tree,
         _walker,
         ignore_implicit_conversion=ctx._ignore_implicit_conversion,
         # Walk the tree in preorder, so that extensions can return
         # container nodes with unserialized children.
         postorder=False,
-        _context=ctx._tree_modification_context,
     )
+    for generator in generators:
+        for _ in generator:
+            pass
+    return new_tree
 
 
 def tagged_tree_to_custom_tree(tree, ctx, force_raw_types=False, _serialization_context=None):
@@ -312,6 +312,7 @@ def tagged_tree_to_custom_tree(tree, ctx, force_raw_types=False, _serialization_
         _serialization_context = ctx._create_serialization_context(BlockAccess.READ)
 
     extension_manager = _serialization_context.extension_manager
+    generators = deque()
 
     def _walker(node):
         if force_raw_types:
@@ -327,6 +328,9 @@ def _walker(node):
             _serialization_context.assign_object(obj)
             _serialization_context.assign_blocks()
             _serialization_context._mark_extension_used(converter.extension)
+            if isinstance(obj, GeneratorType):
+                generators.append(obj)
+                obj = next(obj)
             return obj
 
         if not ctx._ignore_unrecognized_tag:
@@ -336,15 +340,18 @@ def _walker(node):
             )
             return node
 
-    return treeutil.walk_and_modify(
+    new_tree = treeutil.walk_and_modify(
         tree,
         _walker,
         ignore_implicit_conversion=ctx._ignore_implicit_conversion,
         # Walk the tree in postorder, so that extensions receive
         # container nodes with children already deserialized.
         postorder=True,
-        _context=ctx._tree_modification_context,
     )
+    for generator in generators:
+        for _ in generator:
+            pass
+    return new_tree
 
 
 def load_tree(stream):
diff --git a/docs/asdf/deprecations.rst b/docs/asdf/deprecations.rst
index 547ac9296..f975b2cec 100644
--- a/docs/asdf/deprecations.rst
+++ b/docs/asdf/deprecations.rst
@@ -20,6 +20,13 @@ Automatic calling of ``AsdfFile.find_references`` during calls to
 ``AsdfFile.__init__`` and ``asdf.open``. Call ``AsdfFile.find_references`` to
 find references.
 
+``asdf.treeutil.get_children`` and ``asdf.treeutil.is_container`` are deprecated.
+These were never intended to be public.
+
+Support for callbacks that accept a ``json_id`` argument passed to
+``asdf.treeutil.walk_and_modify`` is deprecated. Please use a different library
+or your own code to track ids during tree modifications.
+
 Version 3.0
 ===========