Skip to content

Commit

Permalink
Custom group resolution for depends_on
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Nov 19, 2024
1 parent 349c8df commit a4ab7d1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
12 changes: 11 additions & 1 deletion src/scippnexus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def nx_class(self) -> type | None:
return NXroot

@cached_property
def attrs(self) -> dict[str, Any]:
def attrs(self) -> MappingProxyType[str, Any]:
"""The attributes of the group.
Cannot be used for writing attributes, since they are cached for performance."""
Expand Down Expand Up @@ -479,6 +479,16 @@ def dims(self) -> tuple[str, ...]:
def shape(self) -> tuple[int, ...]:
    """Return the values of ``self.sizes`` as a tuple, in mapping order."""
    extents = self.sizes.values()
    return tuple(extents)

@property
def definitions(self) -> MappingProxyType[str, str | type] | None:
return (
None if self._definitions is None else MappingProxyType(self._definitions)
)

@property
def underlying(self) -> H5Group:
    """Return the low-level group object stored in ``self._group``."""
    h5_group = self._group
    return h5_group


def _create_field_params_numpy(data: np.ndarray):
return data, None, {}
Expand Down
60 changes: 58 additions & 2 deletions src/scippnexus/nxtransformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@

import warnings
from dataclasses import dataclass, field, replace
from pathlib import PurePosixPath
from typing import Literal

import numpy as np
import scipp as sc
from scipp.scipy import interpolate

from .base import Group, NexusStructureError, NXobject, base_definitions_dict
from .base import (
Group,
NexusStructureError,
NXobject,
base_definitions_dict,
is_dataset,
)
from .field import DependsOn, Field


Expand Down Expand Up @@ -266,6 +273,55 @@ def compute(self) -> sc.Variable | sc.DataArray:
return transform


def _resolve_path(cwd: str, path: str) -> str:
"""Resolve a path as if in a working directory, based only on strings.
``base`` must be absolute.
``path`` is resolved as if ``base`` is the current working directory.
Returns an absolute path.
"""
p = PurePosixPath(path)
if p.is_absolute():
return p.as_posix()

base_parts = list(PurePosixPath(cwd).parts)
path_parts = [segment for segment in p.parts if segment != '.']
while path_parts[0] == '..':
if not path_parts:
raise ValueError(f"Relative path beyond root: '{p}'")
base_parts.pop(-1)
path_parts.pop(0)
return PurePosixPath(*base_parts, *path_parts).as_posix()


def _locate_depends_on_target(parent: Field | Group, depends_on: str) -> Field | Group:
    """Look up the object that a ``depends_on`` string points to.

    ``depends_on`` is resolved relative to the name of ``parent``'s underlying
    group and then looked up directly in that group's file. The result is
    wrapped as a ``Field`` when the target is a dataset and as a ``Group``
    otherwise; in both cases the definitions of ``parent`` are passed on.

    In the context of transformations this is meant to be equivalent to
    ``parent[depends_on]``. It is not a general replacement because NXdata
    attributes are not handled. Going through the underlying groups avoids
    constructing expensive intermediate snx.Group objects.
    """
    h5_parent = parent.underlying
    absolute_path = _resolve_path(h5_parent.name, depends_on)
    target = h5_parent.file[absolute_path]

    if not is_dataset(target):
        return Group(target, definitions=parent.definitions)

    # Imported locally, as in the original implementation.
    from .base import _dtype_fromdataset, _squeezed_field_sizes

    enclosing_group = Group(target.parent, definitions=parent.definitions)
    return Field(
        target,
        parent=enclosing_group,
        sizes=_squeezed_field_sizes(target),
        dtype=_dtype_fromdataset(target),
    )


def parse_depends_on_chain(
parent: Field | Group, depends_on: DependsOn
) -> TransformationChain | None:
Expand All @@ -274,7 +330,7 @@ def parse_depends_on_chain(
depends_on = depends_on.value
try:
while depends_on != '.':
transform = parent[depends_on]
transform = _locate_depends_on_target(parent, depends_on)
parent = transform.parent
depends_on = transform.attrs['depends_on']
chain.transformations[transform.name] = transform[()]
Expand Down

0 comments on commit a4ab7d1

Please sign in to comment.