Skip to content

Commit

Permalink
test: masks are added to all signals in dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
johannes committed Oct 27, 2023
1 parent 04370d4 commit 4e7f949
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/scippnexus/nxdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def assemble(self, dg: sc.DataGroup) -> Union[sc.DataArray, sc.Dataset]:
for suffix, bitmask in bitmasks.items():
masks = self.transform_bitmask_to_dict_of_masks(bitmask, suffix)
for da in (
array_or_dataset.items()
array_or_dataset.values()
if isinstance(array_or_dataset, sc.Dataset)
else [array_or_dataset]
):
Expand Down
21 changes: 21 additions & 0 deletions tests/nxdetector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,3 +758,24 @@ def test_pixel_masks_reads_expected_fields(h5root):
assert np.allclose(da.masks.get('dead').values, np.array([[0, 1], [0, 0]]))
assert np.allclose(da.masks.get('gap_2').values, np.array([[1, 0], [0, 0]]))
assert np.allclose(da.masks.get('dead_2').values, np.array([[0, 1], [0, 0]]))


def test_pixel_masks_adds_mask_to_all_dataarrays_of_dataset(h5root):
bitmask = 1 << np.array([[0, 1], [-1, -1]])
da = sc.DataArray(
sc.array(dims=['xx', 'yy'], unit='K', values=[[1.1, 2.2], [3.3, 4.4]])
)
da.coords['detector_numbers'] = detector_numbers_xx_yy_1234()
da.coords['xx'] = sc.array(dims=['xx'], unit='m', values=[0.1, 0.2])
detector = snx.create_class(h5root, 'detector0', NXdetector)
snx.create_field(detector, 'detector_numbers', da.coords['detector_numbers'])
snx.create_field(detector, 'xx', da.coords['xx'])
snx.create_field(detector, 'data', da.data)
snx.create_field(detector, 'data_2', da.data)
detector.attrs['auxiliary_signals'] = ['data_2']
snx.create_field(detector, 'pixel_mask', bitmask)
detector.attrs['axes'] = ['xx', '.']
detector = make_group(detector)
ds = detector[...]
assert set(ds['data'].masks.keys()) == set(('gap', 'dead'))
assert set(ds['data_2'].masks.keys()) == set(('gap', 'dead'))

0 comments on commit 4e7f949

Please sign in to comment.