Skip to content

Commit

Permalink
feat: check transformation units (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
jokasimr authored Nov 6, 2024
1 parent 760a6f7 commit 3a05222
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 4 deletions.
55 changes: 51 additions & 4 deletions src/chexus/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,59 @@ def __init__(self) -> None:
)

def applies_to(self, node: Dataset | Group) -> bool:
return (
return is_transformation(node) and "offset" in node.attrs

def validate(self, node: Dataset | Group) -> Violation | None:
if "offset_units" not in node.attrs:
return Violation(node.name)


class transformation_offset_units_invalid(Validator):
def __init__(self) -> None:
super().__init__(
"transformation_offset_units_invalid",
"Transformation offset_units attr. should be a length unit",
)

def applies_to(self, node: Dataset | Group) -> bool:
return is_transformation(node) and "offset_units" in node.attrs

def validate(self, node: Dataset | Group) -> Violation | None:
import scipp as sc

try:
sc.scalar(1, unit=node.attrs["offset_units"]).to(unit="m")
except sc.UnitError:
return Violation(node.name)


class transformation_units_invalid(Validator):
def __init__(self) -> None:
super().__init__(
"transformation_value_units_invalid",
"Transformation value units should be a length unit "
"if transformation type is translation and "
"a rotation unit if transformation type is rotation",
)

def applies_to(self, node: Dataset | Group) -> bool:
return is_transformation(node) and (
isinstance(node, Dataset)
and is_transformation(node)
and "offset" in node.attrs
or (isinstance(node, Group) and 'value' in node.children)
)

def validate(self, node: Dataset | Group) -> Violation | None:
if "offset_units" not in node.attrs:
import scipp as sc

unit = (node.children['value'] if isinstance(node, Group) else node).attrs.get(
'units'
)
expected_unit = (
"m" if node.attrs["transformation_type"] == "translation" else "rad"
)
try:
sc.scalar(1, unit=unit).to(unit=expected_unit)
except sc.UnitError:
return Violation(node.name)


Expand Down Expand Up @@ -443,5 +488,7 @@ def base_validators(*, has_scipp=True):
validators += [
chopper_frequency_units_invalid(),
dataset_units_check(),
transformation_offset_units_invalid(),
transformation_units_invalid(),
]
return validators
108 changes: 108 additions & 0 deletions tests/validators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,114 @@ def test_transformation_offset_units_missing():
assert result.name == "x"


@pytest.mark.parametrize(
("unit"),
["m", "mm"],
)
def test_transformation_offset_units_invalid_good(unit):
good = chexus.Dataset(
name="x",
value=1,
shape=None,
dtype=float,
parent=None,
attrs={
"transformation_type": "translation",
"vector": [1.0, 0.0, 0.0],
"offset": 1.0,
"offset_units": unit,
},
)
assert chexus.validators.transformation_offset_units_invalid().applies_to(good)
assert (
chexus.validators.transformation_offset_units_invalid().validate(good) is None
)


@pytest.mark.parametrize(
("unit"),
["Hz", "rad"],
)
def test_transformation_offset_units_invalid_bad(unit):
bad = chexus.Dataset(
name="x",
value=1,
shape=None,
dtype=float,
parent=None,
attrs={
"transformation_type": "translation",
"vector": [1.0, 0.0, 0.0],
"offset": 1.0,
"offset_units": unit,
},
)
assert chexus.validators.transformation_offset_units_invalid().applies_to(bad)
assert isinstance(
chexus.validators.transformation_offset_units_invalid().validate(bad),
chexus.Violation,
)


@pytest.mark.parametrize(
("transformation_type", "good_units", "bad_units"),
[
("rotation", ("deg", "rad"), ("m", "")),
("translation", ("m", "mm"), ("deg", "")),
],
)
@pytest.mark.parametrize("is_log", [True, False])
def test_transformation_units_invalid(
transformation_type, good_units, bad_units, is_log
):
def create_transformation(transformation_type, unit, is_log):
common = {
"name": "x",
"parent": None,
"attrs": {
"transformation_type": transformation_type,
"vector": [1.0, 0.0, 0.0],
},
}
if is_log:
return chexus.Group(
**common,
children={
'value': chexus.Dataset(
name="value",
value=[1.0, 2.0],
shape=(2,),
dtype=float,
parent=None,
attrs={
"units": unit,
},
)
},
)
ds = chexus.Dataset(
**common,
value=1.0,
shape=None,
dtype=float,
)
ds.attrs['units'] = unit
return ds

for unit in good_units:
good = create_transformation(transformation_type, unit, is_log)
assert chexus.validators.transformation_units_invalid().applies_to(good)
assert chexus.validators.transformation_units_invalid().validate(good) is None

for unit in bad_units:
bad = create_transformation(transformation_type, unit, is_log)
assert chexus.validators.transformation_units_invalid().applies_to(bad)
assert isinstance(
chexus.validators.transformation_units_invalid().validate(bad),
chexus.Violation,
)


@pytest.mark.parametrize("units", ["NX_LENGTH", "NX_DIMENSIONLESS", "hz", ["m"]])
def test_units_invalid(units: str):
good = chexus.Dataset(
Expand Down

0 comments on commit 3a05222

Please sign in to comment.