From 4e7f949e62c2fd65c3451011c601af8831925b34 Mon Sep 17 00:00:00 2001 From: johannes Date: Fri, 27 Oct 2023 11:04:01 +0200 Subject: [PATCH] test: masks are added to all signals in dataset --- src/scippnexus/nxdata.py | 2 +- tests/nxdetector_test.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/scippnexus/nxdata.py b/src/scippnexus/nxdata.py index d52bd468..21fa2ce1 100644 --- a/src/scippnexus/nxdata.py +++ b/src/scippnexus/nxdata.py @@ -592,7 +592,7 @@ def assemble(self, dg: sc.DataGroup) -> Union[sc.DataArray, sc.Dataset]: for suffix, bitmask in bitmasks.items(): masks = self.transform_bitmask_to_dict_of_masks(bitmask, suffix) for da in ( - array_or_dataset.items() + array_or_dataset.values() if isinstance(array_or_dataset, sc.Dataset) else [array_or_dataset] ): diff --git a/tests/nxdetector_test.py b/tests/nxdetector_test.py index 8da21c53..817e4abb 100644 --- a/tests/nxdetector_test.py +++ b/tests/nxdetector_test.py @@ -758,3 +758,24 @@ def test_pixel_masks_reads_expected_fields(h5root): assert np.allclose(da.masks.get('dead').values, np.array([[0, 1], [0, 0]])) assert np.allclose(da.masks.get('gap_2').values, np.array([[1, 0], [0, 0]])) assert np.allclose(da.masks.get('dead_2').values, np.array([[0, 1], [0, 0]])) + + +def test_pixel_masks_adds_mask_to_all_dataarrays_of_dataset(h5root): + bitmask = 1 << np.array([[0, 1], [-1, -1]]) + da = sc.DataArray( + sc.array(dims=['xx', 'yy'], unit='K', values=[[1.1, 2.2], [3.3, 4.4]]) + ) + da.coords['detector_numbers'] = detector_numbers_xx_yy_1234() + da.coords['xx'] = sc.array(dims=['xx'], unit='m', values=[0.1, 0.2]) + detector = snx.create_class(h5root, 'detector0', NXdetector) + snx.create_field(detector, 'detector_numbers', da.coords['detector_numbers']) + snx.create_field(detector, 'xx', da.coords['xx']) + snx.create_field(detector, 'data', da.data) + snx.create_field(detector, 'data_2', da.data) + detector.attrs['auxiliary_signals'] = ['data_2'] + snx.create_field(detector, 'pixel_mask', bitmask) + detector.attrs['axes'] = ['xx', '.'] + detector = make_group(detector) + ds = detector[...] + assert set(ds['data'].masks.keys()) == set(('gap', 'dead')) + assert set(ds['data_2'].masks.keys()) == set(('gap', 'dead'))