Skip to content

Commit

Permalink
Include full path in warning
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonHeybrock committed Nov 8, 2023
1 parent 9ce7743 commit 5bc936e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 30 deletions.
31 changes: 25 additions & 6 deletions src/scippnexus/nxtransformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -211,9 +212,24 @@ class ChainError(KeyError):

pass

def __init__(self, stack: List[sc.DataGroup]):
@dataclass
class Entry:
name: str
value: sc.DataGroup

def __init__(self, stack: List[TransformationChainResolver.Entry]):
self._stack = stack

@staticmethod
def from_root(dg: sc.DataGroup) -> TransformationChainResolver:
return TransformationChainResolver(
[TransformationChainResolver.Entry(name='', value=dg)]
)

@property
def name(self) -> str:
return '/'.join([e.name for e in self._stack])

@property
def root(self) -> TransformationChainResolver:
return TransformationChainResolver(self._stack[0:1])
Expand All @@ -228,7 +244,7 @@ def parent(self) -> TransformationChainResolver:

@property
def value(self) -> sc.DataGroup:
return self._stack[-1]
return self._stack[-1].value

def __getitem__(self, path: str) -> TransformationChainResolver:
base, *remainder = path.split('/', maxsplit=1)
Expand All @@ -240,12 +256,15 @@ def __getitem__(self, path: str) -> TransformationChainResolver:
node = self.parent
else:
try:
child = self._stack[-1][base]
child = self._stack[-1].value[base]
except KeyError:
raise TransformationChainResolver.ChainError(
f"Transformation depends on non-existing node '{base}'"
f"{base} not found in {self.name}"
)
node = TransformationChainResolver(self._stack + [child])
node = TransformationChainResolver(
self._stack
+ [TransformationChainResolver.Entry(name=base, value=child)]
)
return node if len(remainder) == 0 else node[remainder[0]]

def resolve_depends_on(self) -> Optional[Union[sc.DataArray, sc.Variable]]:
Expand Down Expand Up @@ -328,7 +347,7 @@ def compute_positions(
"""
# Create resolver at root level, since any depends_on chain may lead to a parent,
# i.e., we cannot use a resolver at the level of each chain's entry point.
resolver = TransformationChainResolver([dg])
resolver = TransformationChainResolver.from_root(dg)
return _with_positions(
dg,
store_position=store_position,
Expand Down
56 changes: 32 additions & 24 deletions tests/nxtransformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def test_slice_transformations(h5root):


def test_TransformationChainResolver_path_handling():
tree = TransformationChainResolver([{'a': {'b': {'c': 1}}}])
tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}})
assert tree['a']['b']['c'].value == 1
assert tree['a/b/c'].value == 1
assert tree['/a/b/c'].value == 1
Expand All @@ -444,14 +444,24 @@ def test_TransformationChainResolver_path_handling():
assert tree['a/b']['./c'].value == 1


def test_TransformationChainResolver_name():
tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}})
assert tree['a']['b']['c'].name == '/a/b/c'
assert tree['a/b/c'].name == '/a/b/c'
assert tree['/a/b/c'].name == '/a/b/c'
assert tree['a']['../a/b/c'].name == '/a/b/c'
assert tree['a/b']['../../a/b/c'].name == '/a/b/c'
assert tree['a/b']['./c'].name == '/a/b/c'


def test_TransformationChainResolver_raises_ChainError_if_child_does_not_exists():
tree = TransformationChainResolver([{'a': {'b': {'c': 1}}}])
tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}})
with pytest.raises(TransformationChainResolver.ChainError):
tree['a']['b']['d']


def test_TransformationChainResolver_raises_ChainError_if_path_leads_beyond_root():
tree = TransformationChainResolver([{'a': {'b': {'c': 1}}}])
tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}})
with pytest.raises(TransformationChainResolver.ChainError):
tree['..']
with pytest.raises(TransformationChainResolver.ChainError):
Expand All @@ -466,21 +476,23 @@ def test_TransformationChainResolver_raises_ChainError_if_path_leads_beyond_root


def test_resolve_depends_on_dot():
tree = TransformationChainResolver([{'depends_on': '.'}])
tree = TransformationChainResolver.from_root({'depends_on': '.'})
assert sc.identical(tree.resolve_depends_on() * origin, origin)


def test_resolve_depends_on_child():
transform = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('.')})
tree = TransformationChainResolver([{'depends_on': 'child', 'child': transform}])
tree = TransformationChainResolver.from_root(
{'depends_on': 'child', 'child': transform}
)
expected = sc.vector([1, 0, 0], unit='m')
assert sc.identical(tree.resolve_depends_on() * origin, expected)


def test_resolve_depends_on_grandchild():
transform = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('.')})
tree = TransformationChainResolver(
[{'depends_on': 'child/grandchild', 'child': {'grandchild': transform}}]
tree = TransformationChainResolver.from_root(
{'depends_on': 'child/grandchild', 'child': {'grandchild': transform}}
)
expected = sc.vector([1, 0, 0], unit='m')
assert sc.identical(tree.resolve_depends_on() * origin, expected)
Expand All @@ -489,8 +501,8 @@ def test_resolve_depends_on_grandchild():
def test_resolve_depends_on_child1_depends_on_child2():
transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('child2')})
transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')})
tree = TransformationChainResolver(
[{'depends_on': 'child1', 'child1': transform1, 'child2': transform2}]
tree = TransformationChainResolver.from_root(
{'depends_on': 'child1', 'child1': transform1, 'child2': transform2}
)
# Note order
expected = transform2.data * transform1.data
Expand All @@ -500,13 +512,11 @@ def test_resolve_depends_on_child1_depends_on_child2():
def test_resolve_depends_on_grandchild1_depends_on_grandchild2():
transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('grandchild2')})
transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')})
tree = TransformationChainResolver(
[
{
'depends_on': 'child/grandchild1',
'child': {'grandchild1': transform1, 'grandchild2': transform2},
}
]
tree = TransformationChainResolver.from_root(
{
'depends_on': 'child/grandchild1',
'child': {'grandchild1': transform1, 'grandchild2': transform2},
}
)
expected = transform2.data * transform1.data
assert sc.identical(tree.resolve_depends_on(), expected)
Expand All @@ -515,14 +525,12 @@ def test_resolve_depends_on_grandchild1_depends_on_grandchild2():
def test_resolve_depends_on_grandchild1_depends_on_child2():
transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('../child2')})
transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')})
tree = TransformationChainResolver(
[
{
'depends_on': 'child1/grandchild1',
'child1': {'grandchild1': transform1},
'child2': transform2,
}
]
tree = TransformationChainResolver.from_root(
{
'depends_on': 'child1/grandchild1',
'child1': {'grandchild1': transform1},
'child2': transform2,
}
)
expected = transform2.data * transform1.data
assert sc.identical(tree.resolve_depends_on(), expected)
Expand Down

0 comments on commit 5bc936e

Please sign in to comment.