Skip to content

Commit

Permalink
Merge pull request #173 from scipp/parse-depends-on
Browse files Browse the repository at this point in the history
Add helper to compute positions from depends_on
  • Loading branch information
SimonHeybrock authored Nov 7, 2023
2 parents d5241dc + 0a4cab2 commit 80affde
Show file tree
Hide file tree
Showing 5 changed files with 523 additions and 73 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
:toctree: ../generated/functions
:recursive:
compute_positions
create_field
create_class
```
Expand Down
1 change: 1 addition & 0 deletions src/scippnexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .field import Attrs, Field
from .file import File
from .nexus_classes import *
from .nxtransformations import compute_positions, zip_pixel_offsets
15 changes: 9 additions & 6 deletions src/scippnexus/nxdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import uuid
from functools import cached_property
from itertools import chain
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -373,12 +374,14 @@ def _assemble_as_physical_component(
"""
data_items = sc.DataGroup()
result = sc.DataGroup()
forward = (
[self._signal_name]
+ list(self._group_dims or [])
+ self._aux_signals
+ self._explicit_coords
+ allow_in_coords
forward = list(
chain(
[self._signal_name],
self._group_dims or [],
self._aux_signals,
self._explicit_coords,
allow_in_coords,
)
)

for name, value in dg.items():
Expand Down
271 changes: 218 additions & 53 deletions src/scippnexus/nxtransformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @author Simon Heybrock
from __future__ import annotations

from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import scipp as sc
Expand All @@ -17,20 +17,12 @@ class TransformationError(NexusStructureError):
pass


def make_transformation(obj, /, path) -> Optional[Transformation]:
if path.startswith('/'):
return obj.file[path]
elif path != '.':
return obj.parent[path]
return None # end of chain


class NXtransformations(NXobject):
"""Group of transformations."""


class Transformation:
def __init__(self, obj: Union[Field, NXobject]): # could be an NXlog
def __init__(self, obj: Union[Field, Group]): # could be an NXlog
self._obj = obj

@property
Expand All @@ -53,12 +45,6 @@ def attrs(self):
def name(self):
return self._obj.name

@property
def depends_on(self):
if (path := self.attrs.get('depends_on')) is not None:
return make_transformation(self._obj, path)
return None

@property
def offset(self):
if (offset := self.attrs.get('offset')) is None:
Expand Down Expand Up @@ -117,7 +103,7 @@ def make_transformation(
if (depends_on := self.attrs.get('depends_on')) is not None:
if not isinstance(transform, sc.DataArray):
transform = sc.DataArray(transform)
transform.attrs['depends_on'] = sc.scalar(
transform.coords['depends_on'] = sc.scalar(
depends_on_to_relative_path(depends_on, self._obj.parent.name)
)
return transform
Expand Down Expand Up @@ -146,26 +132,17 @@ def _smaller_unit(a, b):
return b.unit


def get_full_transformation(
depends_on: Field,
) -> Union[None, sc.DataArray, sc.Variable]:
"""
Get the 4x4 transformation matrix for a component, resulting
from the full chain of transformations linked by "depends_on"
attributes
def combine_transformations(
chain: List[Union[sc.DataArray, sc.Variable]]
) -> Union[sc.DataArray, sc.Variable]:
"""
if (t0 := make_transformation(depends_on, depends_on[()])) is None:
return None
return get_full_transformation_starting_at(t0)


def get_full_transformation_starting_at(
t0: Transformation, *, index: ScippIndex = None
) -> Union[None, sc.DataArray, sc.Variable]:
transformations = _get_transformations(t0, index=() if index is None else index)
Take the product of a chain of transformations, handling potentially mismatching
time-dependence.
Time-dependent transformations are interpolated to a common time-coordinate.
"""
total_transform = None
for transform in transformations:
for transform in chain:
if total_transform is None:
total_transform = transform
elif isinstance(total_transform, sc.DataArray) and isinstance(
Expand All @@ -190,26 +167,11 @@ def get_full_transformation_starting_at(
else:
total_transform = transform * total_transform
if isinstance(total_transform, sc.DataArray):
time_dependent = [t for t in transformations if isinstance(t, sc.DataArray)]
time_dependent = [t for t in chain if isinstance(t, sc.DataArray)]
times = [da.coords['time'][0] for da in time_dependent]
latest_log_start = sc.reduce(times).max()
return total_transform['time', latest_log_start:].copy()
return total_transform


def _get_transformations(
transform: Transformation, *, index: ScippIndex
) -> List[Union[sc.DataArray, sc.Variable]]:
"""Get all transformations in the depends_on chain."""
transformations = []
t = transform
while t is not None:
transformations.append(t[index])
t = t.depends_on
# TODO: this list of transformation should probably be cached in the future
# to deal with changing beamline components (e.g. pixel positions) during a
# live data stream (see https://github.com/scipp/scippneutron/issues/76).
return transformations
return sc.scalar(1) if total_transform is None else total_transform


def maybe_transformation(
Expand All @@ -228,7 +190,210 @@ def maybe_transformation(
transformation fields.
"""
if (transformation_type := obj.attrs.get('transformation_type')) is not None:
from .nxtransformations import Transformation

return Transformation(obj).make_transformation(value, transformation_type, sel)
return value


class TransformationChainResolver:
"""
Resolve a chain of transformations, given depends_on attributes with absolute or
relative paths.
A `depends_on` field serves as an entry point into a chain of transformations.
It points to another entry, based on an absolute or relative path. The target
entry may have a `depends_on` attribute pointing to the next transform. This
class follows the paths and resolves the chain of transformations.
"""

class ChainError(KeyError):
"""Raised when a transformation chain cannot be resolved."""

pass

def __init__(self, stack: List[sc.DataGroup]):
self._stack = stack

@property
def root(self) -> TransformationChainResolver:
return TransformationChainResolver(self._stack[0:1])

@property
def parent(self) -> TransformationChainResolver:
if len(self._stack) == 1:
raise TransformationChainResolver.ChainError(
"Transformation depends on node beyond root"
)
return TransformationChainResolver(self._stack[:-1])

@property
def value(self) -> sc.DataGroup:
return self._stack[-1]

def __getitem__(self, path: str) -> TransformationChainResolver:
base, *remainder = path.split('/', maxsplit=1)
if base == '':
node = self.root
elif base == '.':
node = self
elif base == '..':
node = self.parent
else:
node = TransformationChainResolver(self._stack + [self._stack[-1][base]])
return node if len(remainder) == 0 else node[remainder[0]]

def resolve_depends_on(self) -> Optional[Union[sc.DataArray, sc.Variable]]:
"""
Resolve the depends_on attribute of a transformation chain.
Returns
-------
:
The resolved position in meter, or None if no depends_on was found.
"""
depends_on = self.value.get('depends_on')
if depends_on is None:
return None
# Note that transformations have to be applied in "reverse" order, i.e.,
# simply taking math.prod(chain) would be wrong, even if we could
# ignore potential time-dependence.
return combine_transformations(self.get_chain(depends_on))

def get_chain(self, depends_on: str) -> List[Union[sc.DataArray, sc.Variable]]:
if depends_on == '.':
return []
node = self[depends_on]
transform = node.value.copy(deep=False)
depends_on = '.'
if isinstance(transform, sc.DataArray):
if (attr := transform.coords.pop('depends_on', None)) is not None:
depends_on = attr.value
# If transform is time-dependent then we keep it is a DataArray, otherwise
# we convert it to a Variable.
transform = transform if transform.coords else transform.data
if transform.dtype in (sc.DType.translation3, sc.DType.affine_transform3):
transform = transform.to(unit='m', copy=False)
return [transform] + node.parent.get_chain(depends_on)


def compute_positions(
dg: sc.DataGroup,
*,
store_position: str = 'position',
store_transform: Optional[str] = None,
) -> sc.DataGroup:
"""
Recursively compute positions from depends_on attributes as well as the
[xyz]_pixel_offset fields of NXdetector groups.
This function does not operate directly on a NeXus file but on the result of
loading a NeXus file or sub-group into a scipp.DataGroup. NeXus puts no
limitations on the structure of the depends_on chains, i.e., they may reference
parent groups. If this is the case, a call to this function will fail if only the
subgroup is passed as input.
Note that this does not consider "legacy" ways of storing positions. In particular,
``NXmonitor.distance``, ``NXdetector.distance``, ``NXdetector.polar_angle``, and
``NXdetector.azimuthal_angle`` are ignored.
Note that transformation chains may be time-dependent. In this case it will not
be applied to the pixel offsets, since the result may consume too much memory and
the shape is in general incompatible with the shape of the data. Use the
``store_transform`` argument to store the resolved transformation chain in this
case.
If a transformation chain has an invalid 'depends_on' value, e.g., a path beyond
the root data group, then the chain is ignored and no position is computed. This
does not affect other chains.
Parameters
----------
dg:
Data group with depends_on entry points into transformation chains.
store_position:
Name used to store result of resolving each depends_on chain.
store_transform:
If not None, store the resolved transformation chain in this field.
Returns
-------
:
New data group with added positions.
"""
# Create resolver at root level, since any depends_on chain may lead to a parent,
# i.e., we cannot use a resolver at the level of each chain's entry point.
resolver = TransformationChainResolver([dg])
return _with_positions(
dg,
store_position=store_position,
store_transform=store_transform,
resolver=resolver,
)


def zip_pixel_offsets(x: Dict[str, sc.Variable], /) -> sc.Variable:
"""
Zip the x_pixel_offset, y_pixel_offset, and z_pixel_offset fields into a vector.
These fields originate from NXdetector groups. All but x_pixel_offset are optional,
e.g., for 2D detectors. Zero values for missing fields are assumed.
Parameters
----------
mapping:
Mapping (typically a data group, or data array coords) containing
x_pixel_offset, y_pixel_offset, and z_pixel_offset.
Returns
-------
:
Vectors with pixel offsets.
See Also
--------
compute_positions
"""
zero = sc.scalar(0.0, unit=x['x_pixel_offset'].unit)
return sc.spatial.as_vectors(
x['x_pixel_offset'],
x.get('y_pixel_offset', zero),
x.get('z_pixel_offset', zero),
)


def _with_positions(
dg: sc.DataGroup,
*,
store_position: str,
store_transform: Optional[str] = None,
resolver: TransformationChainResolver,
) -> sc.DataGroup:
out = sc.DataGroup()
transform = None
if 'depends_on' in dg:
try:
transform = resolver.resolve_depends_on()
out[store_position] = transform * sc.vector([0, 0, 0], unit='m')
if store_transform is not None:
out[store_transform] = transform
except TransformationChainResolver.ChainError:
pass
for name, value in dg.items():
if isinstance(value, sc.DataGroup):
value = _with_positions(
value,
store_position=store_position,
store_transform=store_transform,
resolver=resolver[name],
)
elif (
isinstance(value, sc.DataArray)
and 'x_pixel_offset' in value.coords
# Transform can be time-dependent, do not apply it to offsets since
# result can be massive and is in general not compatible with the shape
# of the data.
and (transform is not None and transform.dims == ())
):
offset = zip_pixel_offsets(value.coords).to(unit='m', copy=False)
value = value.assign_coords({store_position: transform * offset})
out[name] = value
return out
Loading

0 comments on commit 80affde

Please sign in to comment.