diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 44b0be38..1628ba05 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -3,6 +3,8 @@ # @author Simon Heybrock from __future__ import annotations +import warnings +from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -210,9 +212,24 @@ class ChainError(KeyError): pass - def __init__(self, stack: List[sc.DataGroup]): + @dataclass + class Entry: + name: str + value: sc.DataGroup + + def __init__(self, stack: List[TransformationChainResolver.Entry]): self._stack = stack + @staticmethod + def from_root(dg: sc.DataGroup) -> TransformationChainResolver: + return TransformationChainResolver( + [TransformationChainResolver.Entry(name='', value=dg)] + ) + + @property + def name(self) -> str: + return '/'.join([e.name for e in self._stack]) + @property def root(self) -> TransformationChainResolver: return TransformationChainResolver(self._stack[0:1]) @@ -227,7 +244,7 @@ def parent(self) -> TransformationChainResolver: @property def value(self) -> sc.DataGroup: - return self._stack[-1] + return self._stack[-1].value def __getitem__(self, path: str) -> TransformationChainResolver: base, *remainder = path.split('/', maxsplit=1) @@ -238,7 +255,16 @@ def __getitem__(self, path: str) -> TransformationChainResolver: elif base == '..': node = self.parent else: - node = TransformationChainResolver(self._stack + [self._stack[-1][base]]) + try: + child = self._stack[-1].value[base] + except KeyError: + raise TransformationChainResolver.ChainError( + f"{base} not found in {self.name}" + ) + node = TransformationChainResolver( + self._stack + + [TransformationChainResolver.Entry(name=base, value=child)] + ) return node if len(remainder) == 0 else node[remainder[0]] def resolve_depends_on(self) -> Optional[Union[sc.DataArray, sc.Variable]]: @@ -321,7 +347,7 @@ def compute_positions( """ # Create resolver at root level, since any depends_on chain may lead to a parent, # i.e., we cannot use a resolver at the level of each chain's entry point. - resolver = TransformationChainResolver([dg]) + resolver = TransformationChainResolver.from_root(dg) return _with_positions( dg, store_position=store_position, @@ -375,8 +401,10 @@ def _with_positions( out[store_position] = transform * sc.vector([0, 0, 0], unit='m') if store_transform is not None: out[store_transform] = transform - except TransformationChainResolver.ChainError: - pass + except TransformationChainResolver.ChainError as e: + warnings.warn( + UserWarning(f'depends_on chain references missing node:\n{e}') + ) for name, value in dg.items(): if isinstance(value, sc.DataGroup): value = _with_positions( diff --git a/tests/nxtransformations_test.py b/tests/nxtransformations_test.py index 52ceb69a..0f4406d4 100644 --- a/tests/nxtransformations_test.py +++ b/tests/nxtransformations_test.py @@ -435,7 +435,7 @@ def test_slice_transformations(h5root): def test_TransformationChainResolver_path_handling(): - tree = TransformationChainResolver([{'a': {'b': {'c': 1}}}]) + tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) assert tree['a']['b']['c'].value == 1 assert tree['a/b/c'].value == 1 assert tree['/a/b/c'].value == 1 @@ -444,19 +444,29 @@ def test_TransformationChainResolver_path_handling(): assert tree['a/b']['./c'].value == 1 -def test_TransformationChainResolver_raises_if_child_does_not_exists(): - tree = TransformationChainResolver([{'a': {'b': {'c': 1}}}]) - with pytest.raises(KeyError): +def test_TransformationChainResolver_name(): + tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) + assert tree['a']['b']['c'].name == '/a/b/c' + assert tree['a/b/c'].name == '/a/b/c' + assert tree['/a/b/c'].name == '/a/b/c' + assert tree['a']['../a/b/c'].name == '/a/b/c' + assert tree['a/b']['../../a/b/c'].name == '/a/b/c' + assert tree['a/b']['./c'].name == '/a/b/c' + + +def test_TransformationChainResolver_raises_ChainError_if_child_does_not_exists(): + tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) + with pytest.raises(TransformationChainResolver.ChainError): tree['a']['b']['d'] -def test_TransformationChainResolver_raises_if_path_leads_beyond_root(): - tree = TransformationChainResolver([{'a': {'b': {'c': 1}}}]) - with pytest.raises(KeyError): +def test_TransformationChainResolver_raises_ChainError_if_path_leads_beyond_root(): + tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) + with pytest.raises(TransformationChainResolver.ChainError): tree['..'] - with pytest.raises(KeyError): + with pytest.raises(TransformationChainResolver.ChainError): tree['a']['../..'] - with pytest.raises(KeyError): + with pytest.raises(TransformationChainResolver.ChainError): tree['../a'] @@ -466,21 +476,23 @@ def test_TransformationChainResolver_raises_if_path_leads_beyond_root(): def test_resolve_depends_on_dot(): - tree = TransformationChainResolver([{'depends_on': '.'}]) + tree = TransformationChainResolver.from_root({'depends_on': '.'}) assert sc.identical(tree.resolve_depends_on() * origin, origin) def test_resolve_depends_on_child(): transform = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver([{'depends_on': 'child', 'child': transform}]) + tree = TransformationChainResolver.from_root( + {'depends_on': 'child', 'child': transform} + ) expected = sc.vector([1, 0, 0], unit='m') assert sc.identical(tree.resolve_depends_on() * origin, expected) def test_resolve_depends_on_grandchild(): transform = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver( - [{'depends_on': 'child/grandchild', 'child': {'grandchild': transform}}] + tree = TransformationChainResolver.from_root( + {'depends_on': 'child/grandchild', 'child': {'grandchild': transform}} ) expected = sc.vector([1, 0, 0], unit='m') assert sc.identical(tree.resolve_depends_on() * origin, expected) @@ -489,8 +501,8 @@ def test_resolve_depends_on_grandchild(): def test_resolve_depends_on_child1_depends_on_child2(): transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('child2')}) transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver( - [{'depends_on': 'child1', 'child1': transform1, 'child2': transform2}] + tree = TransformationChainResolver.from_root( + {'depends_on': 'child1', 'child1': transform1, 'child2': transform2} ) # Note order expected = transform2.data * transform1.data @@ -500,13 +512,11 @@ def test_resolve_depends_on_child1_depends_on_child2(): def test_resolve_depends_on_grandchild1_depends_on_grandchild2(): transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('grandchild2')}) transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver( - [ - { - 'depends_on': 'child/grandchild1', - 'child': {'grandchild1': transform1, 'grandchild2': transform2}, - } - ] + tree = TransformationChainResolver.from_root( + { + 'depends_on': 'child/grandchild1', + 'child': {'grandchild1': transform1, 'grandchild2': transform2}, + } ) expected = transform2.data * transform1.data assert sc.identical(tree.resolve_depends_on(), expected) @@ -515,14 +525,12 @@ def test_resolve_depends_on_grandchild1_depends_on_grandchild2(): def test_resolve_depends_on_grandchild1_depends_on_child2(): transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('../child2')}) transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver( - [ - { - 'depends_on': 'child1/grandchild1', - 'child1': {'grandchild1': transform1}, - 'child2': transform2, - } - ] + tree = TransformationChainResolver.from_root( + { + 'depends_on': 'child1/grandchild1', + 'child1': {'grandchild1': transform1}, + 'child2': transform2, + } ) expected = transform2.data * transform1.data assert sc.identical(tree.resolve_depends_on(), expected) @@ -615,6 +623,7 @@ def test_compute_positions_with_rotation(h5root): ) +@pytest.mark.filterwarnings("ignore:depends_on chain references missing node") def test_compute_positions_skips_for_path_beyond_root(h5root): instrument = snx.create_class(h5root, 'instrument', snx.NXinstrument) value = sc.scalar(6.5, unit='m') @@ -712,3 +721,13 @@ def test_compute_positions_does_not_apply_time_dependent_transform_to_pixel_offs assert 'position' not in result['detector_0']['data'].coords result = snx.compute_positions(loaded, store_transform='transform') assert_identical(result['detector_0']['transform'], t * offset) + + +def test_compute_positions_warns_if_depends_on_is_dead_link(h5root): + instrument = snx.create_class(h5root, 'instrument', snx.NXinstrument) + detector = create_detector(instrument) + snx.create_field(detector, 'depends_on', sc.scalar('transform')) + root = make_group(h5root) + loaded = root[()] + with pytest.warns(UserWarning, match='depends_on chain references missing node'): + snx.compute_positions(loaded)