Omit unchunked dimensions from Key objects created with DatasetToChunks
This allows splitting datasets across variables even when those variables have different dimensions. See the new integration test for a concrete use case resembling real model output.

Fixes #43

PiperOrigin-RevId: 471347485
shoyer authored and Xarray-Beam authors committed Aug 31, 2022
1 parent 4dced4e commit a5a1dd5
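
To illustrate the new behavior, here is a minimal usage sketch mirroring the new integration test: two variables share the chunked 't' dimension but differ in their other dimensions, and with split_vars=True they are now emitted as separate per-variable chunks. The output path 'example.zarr' and the default in-memory Beam runner are illustrative assumptions, not part of this commit.

import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# Two variables with different dimensions, as in the new integration test.
dataset = xarray.Dataset({
    'volumetric': (('t', 'x', 'y', 'z'), np.arange(240).reshape(10, 2, 3, 4)),
    'surface': (('t', 'x', 'y'), np.arange(60).reshape(10, 2, 3)),
})

with beam.Pipeline() as p:
    (
        p
        | xbeam.DatasetToChunks(dataset, chunks={'t': 1}, split_vars=True)
        # Keys now include only the chunked 't' dimension, so per-variable
        # splitting works even though 'volumetric' and 'surface' have
        # different dimensions.
        | xbeam.ChunksToZarr('example.zarr', dataset.chunk(), {'t': 1})
    )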
Showing 6 changed files with 47 additions and 5 deletions.
setup.py: 2 changes (1 addition & 1 deletion)
@@ -42,7 +42,7 @@

setuptools.setup(
name='xarray-beam',
version='0.3.1',
version='0.3.2',
license='Apache 2.0',
author='Google LLC',
author_email='[email protected]',
xarray_beam/__init__.py: 2 changes (1 addition & 1 deletion)
@@ -43,4 +43,4 @@
)
from xarray_beam import Mean

__version__ = '0.3.1'
__version__ = '0.3.2'
xarray_beam/_src/core.py: 3 changes (2 additions & 1 deletion)
@@ -282,7 +282,8 @@ def __init__(
chunks = dataset.chunks
if chunks is None:
raise ValueError('dataset must be chunked or chunks must be set')
chunks = normalize_expanded_chunks(chunks, dataset.sizes)
sizes = {k: v for k, v in dataset.sizes.items() if k in chunks}
chunks = normalize_expanded_chunks(chunks, sizes)
self.dataset = dataset
self.chunks = chunks
self.split_vars = split_vars
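For context, a minimal sketch of what the filtered sizes mapping above changes. The dimension names and values here are illustrative, not part of the diff, and normalize_expanded_chunks is internal to xarray_beam._src.core.

# Hypothetical illustration of the change above, not part of the diff.
sizes = {'x': 2, 'y': 3}   # dataset.sizes
chunks = {'x': 1}          # user-requested chunks; 'y' is unchunked
filtered = {k: v for k, v in sizes.items() if k in chunks}
assert filtered == {'x': 2}
# Previously, normalize_expanded_chunks received the full sizes mapping, so
# unchunked dimensions ended up in the resulting Key objects (e.g.
# Key({'x': 0, 'y': 0})). With the filtered sizes, unchunked dimensions are
# omitted and keys look like Key({'x': 0}).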
xarray_beam/_src/core_test.py: 18 changes (18 additions & 0 deletions)
@@ -261,6 +261,7 @@ def test_dataset_to_chunks_whole(self):
test_util.EagerPipeline()
| xbeam.DatasetToChunks(dataset, chunks={})
)
expected = [(xbeam.Key(), dataset)]
self.assertIdenticalChunks(actual, expected)

def test_dataset_to_chunks_vars(self):
@@ -280,6 +281,23 @@ def test_dataset_to_chunks_vars(self):
)
self.assertIdenticalChunks(actual, expected)

def test_dataset_to_chunks_split_with_different_dims(self):
dataset = xarray.Dataset({
'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
'bar': ('x', np.array([1, 2])),
})
expected = [
(xbeam.Key({'x': 0}, {'foo'}), dataset[['foo']].head(x=1)),
(xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']].head(x=1)),
(xbeam.Key({'x': 1}, {'foo'}), dataset[['foo']].tail(x=1)),
(xbeam.Key({'x': 1}, {'bar'}), dataset[['bar']].tail(x=1)),
]
actual = (
test_util.EagerPipeline()
| xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True)
)
self.assertIdenticalChunks(actual, expected)


class ValidateEachChunkTest(test_util.TestCase):

xarray_beam/_src/integration_test.py: 20 changes (20 additions & 0 deletions)
@@ -95,6 +95,26 @@ def test_rechunk_zarr_to_zarr(self, template_method, split_vars):

xarray.testing.assert_identical(roundtripped, dataset)

def test_dataset_to_zarr_with_split_vars(self):
dataset = xarray.Dataset(
{
'volumetric': (
('t', 'x', 'y', 'z'), np.arange(240).reshape(10, 2, 3, 4)
),
'surface': (('t', 'x', 'y'), np.arange(60).reshape(10, 2, 3)),
}
)
temp_dir = self.create_tempdir().full_path
template = dataset.chunk()
chunks = {'t': 1}
(
test_util.EagerPipeline()
| xbeam.DatasetToChunks(dataset, chunks, split_vars=True)
| xbeam.ChunksToZarr(temp_dir, template, chunks)
)
actual = xarray.open_zarr(temp_dir, consolidated=True)
xarray.testing.assert_identical(actual, dataset)


if __name__ == '__main__':
absltest.main()
xarray_beam/_src/zarr_test.py: 7 changes (5 additions & 2 deletions)
@@ -140,14 +140,13 @@ def test_2d_chunks_to_zarr(self, coords):
result = xarray.open_zarr(temp_dir, consolidated=True)
xarray.testing.assert_identical(dataset, result)

def test_dataset_to_zarr(self):
def test_dataset_to_zarr_simple(self):
dataset = xarray.Dataset(
{'foo': ('x', np.arange(0, 60, 10))},
coords={'x': np.arange(6)},
attrs={'meta': 'data'},
)
chunked = dataset.chunk({'x': 3})

temp_dir = self.create_tempdir().full_path
(
test_util.EagerPipeline()
@@ -156,6 +155,10 @@ def test_dataset_to_zarr(self):
actual = xarray.open_zarr(temp_dir, consolidated=True)
xarray.testing.assert_identical(actual, dataset)

def test_dataset_to_zarr_unchunked(self):
dataset = xarray.Dataset(
{'foo': ('x', np.arange(0, 60, 10))},
)
temp_dir = self.create_tempdir().full_path
with self.assertRaisesRegex(
ValueError,
