From f09989b51a0941386e1cb98bd27240b0a26eccaf Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Fri, 2 Sep 2022 21:39:41 -0700
Subject: [PATCH] Support datasets with differently chunked variables in
 DatasetToChunks

There are two major internal changes:

1. Key objects from DatasetToChunks can now include different dimensions for
   different variables when using split_vars=True. This makes it easier to
   handle large datasets with many variables and different chunking per
   variable.
2. Inputs inside the DatasetToChunks pipeline can now be sharded across many
   tasks. This is important for scalability to large datasets, especially with
   this change, because the above refactor increases the number of inputs by
   the number of variables when split_vars=True. Otherwise, we can run into
   performance issues on the machine launching the pipeline when the number of
   inputs goes into the millions (e.g., slow speed, out of memory).

See the new integration test for a concrete use case, resembling real model
output.

Also revise the warning message in the README to be a bit friendlier.

Fixes https://github.com/google/xarray-beam/issues/43

PiperOrigin-RevId: 471948735
---
 .github/workflows/ci-build.yml       |   2 +-
 README.md                            |  21 ++--
 setup.py                             |   4 +-
 xarray_beam/__init__.py              |   8 +-
 xarray_beam/_src/core.py             | 138 +++++++++++++++++++++++----
 xarray_beam/_src/core_test.py        |  87 +++++++++++++++--
 xarray_beam/_src/integration_test.py |  75 ++++++++++++++-
 xarray_beam/_src/test_util.py        |   3 +-
 xarray_beam/_src/zarr_test.py        |   7 +-
 9 files changed, 292 insertions(+), 53 deletions(-)

diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index 0d5ab63..8b71462 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -15,7 +15,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.7", "3.8", "3.9"]
+        python-version: ["3.7", "3.8", "3.9", "3.10"]
     steps:
     - name: Cancel previous
       uses: styfle/cancel-workflow-action@0.7.0
diff --git a/README.md b/README.md
index 9b13c2d..a00bb8b 100644
--- a/README.md
+++ b/README.md
@@ -20,20 +20,21 @@ multi-dimensional labeled arrays, such as:
 For more about our approach and how to get started,
 **[read the documentation](https://xarray-beam.readthedocs.io/)**!
 
-**🚨 Warning: Xarray-Beam is new and unpolished 🚨**
+**Warning: Xarray-Beam is a sharp tool 🔪**
 
-Expect sharp edges 🔪 and performance cliffs 🧗, particularly related to the
-management of lazy data with Dask and reading/writing data with Zarr. We have
-used it to efficiently process ~25 TB datasets. We _expect_ it to scale to PB
-size datasets, but that's easier said than done. We welcome feedback and
-contributions from early adopters, and hope to have it ready for wider audience
-soon.
+Xarray-Beam is relatively new, and focused on expert users:
+
+- We use it extensively at Google for processing large-scale weather datasets,
+  but there is not yet a vibrant external community.
+- It provides low-level abstractions that facilitate writing very large
+  scale data pipelines (e.g., 100+ TB), but by design it requires explicitly
+  thinking about how every operation is parallelized.
 
 ## Installation
 
-Xarray-Beam requires recent versions of immutabledict, xarray, dask, rechunker
-and zarr, and the *latest* release of Apache Beam (2.31.0 or later). For best
-performance when writing Zarr files, use Xarray 0.19.0 or later.
+Xarray-Beam requires recent versions of immutabledict, Xarray, Dask, Rechunker,
+Zarr, and Apache Beam.
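To make the revised README framing concrete, a complete Xarray-Beam pipeline can be as small as the sketch below. This is an illustrative aside, not part of this patch: the dataset, the `{'time': 100}` chunking, and the `example.zarr` output path are invented, and the transforms are used the same way as in the integration test added later in this patch.

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# A small synthetic dataset standing in for real model output.
dataset = xarray.Dataset(
    {'temperature': (('time', 'x'), np.random.rand(365, 10))}
)
chunks = {'time': 100}  # chunking determines how downstream work is parallelized
template = dataset.chunk()  # lazy template describing the on-disk structure

with beam.Pipeline() as pipeline:
  (
      pipeline
      | xbeam.DatasetToChunks(dataset, chunks)  # emits (xbeam.Key, xarray.Dataset) pairs
      | xbeam.ChunksToZarr('example.zarr', template, chunks)
  )
```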
For best performance when writing Zarr files, use Xarray +0.19.0 or later. ## Disclaimer diff --git a/setup.py b/setup.py index 5a50809..df6b9e9 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ setuptools.setup( name='xarray-beam', - version='0.3.1', + version='0.4.0', license='Apache 2.0', author='Google LLC', author_email='noreply@google.com', @@ -52,6 +52,6 @@ 'docs': docs_requires, }, url='https://github.com/google/xarray-beam', - packages=setuptools.find_packages(exclude=["examples"]), + packages=setuptools.find_packages(exclude=['examples']), python_requires='>=3', ) diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py index 4d230e4..db9287d 100644 --- a/xarray_beam/__init__.py +++ b/xarray_beam/__init__.py @@ -14,6 +14,9 @@ """Public API for Xarray-Beam.""" # pylint: disable=g-multiple-import +from xarray_beam._src.combiners import ( + MeanCombineFn, +) from xarray_beam._src.core import ( Key, DatasetToChunks, @@ -21,9 +24,6 @@ offsets_to_slices, validate_chunk ) -from xarray_beam._src.combiners import ( - MeanCombineFn, -) from xarray_beam._src.rechunk import ( ConsolidateChunks, ConsolidateVariables, @@ -43,4 +43,4 @@ ) from xarray_beam import Mean -__version__ = '0.3.1' +__version__ = '0.4.0' diff --git a/xarray_beam/_src/core.py b/xarray_beam/_src/core.py index a740b97..68b040d 100644 --- a/xarray_beam/_src/core.py +++ b/xarray_beam/_src/core.py @@ -13,6 +13,7 @@ # limitations under the License. """Core data model for xarray-beam.""" import itertools +import math from typing import ( AbstractSet, Container, @@ -196,16 +197,16 @@ def _chunks_to_offsets( def iter_chunk_keys( - chunks: Mapping[str, Tuple[int, ...]], + offsets: Mapping[str, Sequence[int]], + vars: Optional[AbstractSet[str]] = None, # pylint: disable=redefined-builtin ) -> Iterator[Key]: """Iterate over the Key objects corresponding to the given chunks.""" - all_offsets = _chunks_to_offsets(chunks) - chunk_indices = [range(len(sizes)) for sizes in chunks.values()] + chunk_indices = [range(len(sizes)) for sizes in offsets.values()] for indices in itertools.product(*chunk_indices): - offsets = { - dim: all_offsets[dim][index] for dim, index in zip(chunks, indices) + key_offsets = { + dim: offsets[dim][index] for dim, index in zip(offsets, indices) } - yield Key(offsets) + yield Key(key_offsets, vars) def compute_offset_index( @@ -262,6 +263,7 @@ def __init__( chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None, split_vars: bool = False, num_threads: Optional[int] = None, + shard_keys_threshold: int = 200_000, ): """Initialize DatasetToChunks. @@ -271,44 +273,138 @@ def __init__( chunked. If the dataset *is* already chunked with Dask, `chunks` takes precedence over the existing chunks. split_vars: whether to split the dataset into separate records for each - data variables or to keep all data variables together. + data variable or to keep all data variables together. num_threads: optional number of Dataset chunks to load in parallel per worker. More threads can increase throughput, but also increases memory usage and makes it harder for Beam runners to shard work. Note that each variable in a Dataset is already loaded in parallel, so this is most useful for Datasets with a small number of variables. + shard_keys_threshold: threshold at which to compute keys on Beam workers, + rather than only on the host process. This is important for scaling + pipelines to millions of tasks. 
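As an illustrative aside (not part of the patch), here is roughly how the arguments described in this docstring fit together; the dataset, chunk sizes, and the artificially low threshold below are invented for the example.

```python
import numpy as np
import xarray
import xarray_beam as xbeam

dataset = xarray.Dataset({
    'foo': (('x', 'y'), np.zeros((1000, 500))),  # dims ('x', 'y')
    'bar': ('x', np.zeros(1000)),                # shares only the 'x' dim
})

# With split_vars=True, each variable is emitted under its own keys, carrying
# offsets only for the dimensions that variable actually has. chunks={'x': 10}
# yields 100 chunk keys per variable (200 tasks total), so a threshold of 100
# forces key generation into shards on Beam workers instead of the machine
# launching the pipeline; the default threshold of 200,000 would not.
to_chunks = xbeam.DatasetToChunks(
    dataset,
    chunks={'x': 10},
    split_vars=True,
    shard_keys_threshold=100,
)
```

Per the implementation below, inputs are sharded only along the dimension with the most chunks; the threshold simply decides whether key generation stays on the launching machine or moves onto Beam workers.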
""" if chunks is None: chunks = dataset.chunks if chunks is None: - raise ValueError('dataset must be chunked or chunks must be set') - chunks = normalize_expanded_chunks(chunks, dataset.sizes) + raise ValueError('dataset must be chunked or chunks must be provided') + expanded_chunks = normalize_expanded_chunks(chunks, dataset.sizes) self.dataset = dataset - self.chunks = chunks + self.expanded_chunks = expanded_chunks self.split_vars = split_vars self.num_threads = num_threads - self.offset_index = compute_offset_index(_chunks_to_offsets(chunks)) + self.shard_keys_threshold = shard_keys_threshold + # TODO(shoyer): consider recalculating these potentially large properties on + # each worker, rather than only once on the host. + self.offsets = _chunks_to_offsets(expanded_chunks) + self.offset_index = compute_offset_index(self.offsets) + # We use the simple heuristic of only sharding inputs along the dimension + # with the most chunks. + lengths = {k: len(v) for k, v in self.offsets.items()} + self.sharded_dim = max(lengths, key=lengths.get) if lengths else None + self.shard_count = self._shard_count() + + def _task_count(self) -> int: + """Count the number of tasks emitted by this transform.""" + counts = {k: len(v) for k, v in self.expanded_chunks.items()} + if not self.split_vars: + return int(np.prod(list(counts.values()))) + total = 0 + for variable in self.dataset.values(): + count_list = [v for k, v in counts.items() if k in variable.dims] + total += int(np.prod(count_list)) + return total + + def _shard_count(self) -> Optional[int]: + """Determine the number of times to shard input keys.""" + task_count = self._task_count() + if task_count <= self.shard_keys_threshold: + return None # no sharding + + if not self.split_vars: + return math.ceil(task_count / self.shard_keys_threshold) + + var_count = sum( + self.sharded_dim in var.dims for var in self.dataset.values() + ) + return math.ceil(task_count / (var_count * self.shard_keys_threshold)) + + def _iter_all_keys(self) -> Iterator[Key]: + """Iterate over all Key objects.""" + if not self.split_vars: + yield from iter_chunk_keys(self.offsets) + else: + for name, variable in self.dataset.items(): + relevant_offsets = { + k: v for k, v in self.offsets.items() if k in variable.dims + } + yield from iter_chunk_keys(relevant_offsets, vars={name}) + + def _iter_shard_keys( + self, shard_id: Optional[int], var_name: Optional[str] + ) -> Iterator[Key]: + """Iterate over Key objects for a specific shard and variable.""" + if var_name is None: + offsets = self.offsets + else: + offsets = { + dim: self.offsets[dim] for dim in self.dataset[var_name].dims + } + + if shard_id is None: + assert self.split_vars + yield from iter_chunk_keys(offsets, vars={var_name}) + else: + assert self.split_vars == (var_name is not None) + dim = self.sharded_dim + count = math.ceil(len(self.offsets[dim]) / self.shard_count) + dim_slice = slice(shard_id * count, (shard_id + 1) * count) + offsets = {**offsets, dim: offsets[dim][dim_slice]} + vars_ = {var_name} if self.split_vars else None + yield from iter_chunk_keys(offsets, vars=vars_) + + def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]: + """Create inputs for sharded key iterators.""" + if not self.split_vars: + return [(i, None) for i in range(self.shard_count)] + + inputs = [] + for name, variable in self.dataset.items(): + if self.sharded_dim in variable.dims: + inputs.extend([(i, name) for i in range(self.shard_count)]) + else: + inputs.append((None, name)) + return inputs def 
_key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, xarray.Dataset]]: + """Convert a Key into an in-memory (Key, xarray.Dataset) pair.""" sizes = { - dim: self.chunks[dim][self.offset_index[dim][offset]] + dim: self.expanded_chunks[dim][self.offset_index[dim][offset]] for dim, offset in key.offsets.items() } slices = offsets_to_slices(key.offsets, sizes) - chunk = self.dataset.isel(slices) + dataset = self.dataset if key.vars is None else self.dataset[list(key.vars)] + chunk = dataset.isel(slices) # Load the data, using a separate thread for each variable - num_threads = len(self.dataset.data_vars) + num_threads = len(self.dataset) result = chunk.chunk().compute(num_workers=num_threads) - if self.split_vars: - for k in result: - yield key.replace(vars={k}), result[[k]] - else: - yield key, result + yield key, result def expand(self, pcoll): + if self.shard_count is None: + # Create all keys on the machine launching the Beam pipeline. This is + # faster if the number of keys is small. + key_pcoll = pcoll | beam.Create(self._iter_all_keys()) + else: + # Create keys in separate shards on Beam workers. This is more scalable. + key_pcoll = ( + pcoll + | beam.Create(self._shard_inputs()) + | beam.FlatMapTuple(self._iter_shard_keys) + | beam.Reshuffle() + ) + return ( - pcoll - | beam.Create(iter_chunk_keys(self.chunks)) + key_pcoll | threadmap.FlatThreadMap( self._key_to_chunks, num_threads=self.num_threads ) diff --git a/xarray_beam/_src/core_test.py b/xarray_beam/_src/core_test.py index 7cdbf89..1d30358 100644 --- a/xarray_beam/_src/core_test.py +++ b/xarray_beam/_src/core_test.py @@ -14,6 +14,7 @@ """Tests for xarray_beam._src.core.""" from absl.testing import absltest +from absl.testing import parameterized import apache_beam as beam import immutabledict import numpy as np @@ -32,7 +33,7 @@ def test_constructor(self): key = xbeam.Key({'x': 0, 'y': 10}) self.assertIsInstance(key.offsets, immutabledict.immutabledict) self.assertEqual(dict(key.offsets), {'x': 0, 'y': 10}) - self.assertEqual(key.vars, None) + self.assertIsNone(key.vars) key = xbeam.Key(vars={'foo'}) self.assertEqual(dict(key.offsets), {}) @@ -181,7 +182,7 @@ def test_offsets_to_slices_base(self): class DatasetToChunksTest(test_util.TestCase): def test_iter_chunk_keys(self): - actual = list(core.iter_chunk_keys({'x': (3, 3), 'y': (2, 2, 2)})) + actual = list(core.iter_chunk_keys({'x': (0, 3), 'y': (0, 2, 4)})) expected = [ xbeam.Key({'x': 0, 'y': 0}), xbeam.Key({'x': 0, 'y': 2}), @@ -248,6 +249,12 @@ def test_dataset_to_chunks_multiple(self): ) self.assertIdenticalChunks(actual, expected) + actual = ( + test_util.EagerPipeline() + | xbeam.DatasetToChunks(dataset.chunk({'x': 3}), shard_keys_threshold=1) + ) + self.assertIdenticalChunks(actual, expected) + def test_dataset_to_chunks_whole(self): dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) expected = [(xbeam.Key({'x': 0}), dataset)] @@ -261,6 +268,7 @@ def test_dataset_to_chunks_whole(self): test_util.EagerPipeline() | xbeam.DatasetToChunks(dataset, chunks={}) ) + expected = [(xbeam.Key({'x': 0}), dataset)] self.assertIdenticalChunks(actual, expected) def test_dataset_to_chunks_vars(self): @@ -280,6 +288,65 @@ def test_dataset_to_chunks_vars(self): ) self.assertIdenticalChunks(actual, expected) + @parameterized.parameters( + {'shard_keys_threshold': 1}, + {'shard_keys_threshold': 2}, + {'shard_keys_threshold': 10}, + ) + def test_dataset_to_chunks_split_with_different_dims( + self, shard_keys_threshold + ): + dataset = xarray.Dataset({ + 'foo': (('x', 'y'), 
np.array([[1, 2, 3], [4, 5, 6]])), + 'bar': ('x', np.array([1, 2])), + 'baz': ('z', np.array([1, 2, 3])), + }) + expected = [ + (xbeam.Key({'x': 0, 'y': 0}, {'foo'}), dataset[['foo']].head(x=1)), + (xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']].head(x=1)), + (xbeam.Key({'x': 1, 'y': 0}, {'foo'}), dataset[['foo']].tail(x=1)), + (xbeam.Key({'x': 1}, {'bar'}), dataset[['bar']].tail(x=1)), + (xbeam.Key({'z': 0}, {'baz'}), dataset[['baz']]), + ] + actual = ( + test_util.EagerPipeline() + | xbeam.DatasetToChunks( + dataset, + chunks={'x': 1}, + split_vars=True, + shard_keys_threshold=shard_keys_threshold, + ) + ) + self.assertIdenticalChunks(actual, expected) + + def test_dataset_to_chunks_empty(self): + dataset = xarray.Dataset() + expected = [(xbeam.Key({}), dataset)] + actual = ( + test_util.EagerPipeline() + | xbeam.DatasetToChunks(dataset) + ) + self.assertIdenticalChunks(actual, expected) + + def test_task_count(self): + dataset = xarray.Dataset({ + 'foo': (('x', 'y'), np.zeros((3, 6))), + 'bar': ('x', np.zeros(3)), + 'baz': ('z', np.zeros(10)), + }) + + to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1}) + self.assertEqual(to_chunks._task_count(), 3) + + to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True) + self.assertEqual(to_chunks._task_count(), 7) + + to_chunks = xbeam.DatasetToChunks(dataset, chunks={'y': 1}, split_vars=True) + self.assertEqual(to_chunks._task_count(), 8) + + to_chunks = xbeam.DatasetToChunks(dataset, chunks={'z': 1}, split_vars=True) + self.assertEqual(to_chunks._task_count(), 12) + class ValidateEachChunkTest(test_util.TestCase): @@ -287,13 +354,13 @@ def test_unmatched_dimension_raises_error(self): dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) with self.assertRaises(ValueError) as e: ( - [(xbeam.Key({'x': 0, 'y': 0}), dataset)] - | xbeam.ValidateEachChunk() + [(xbeam.Key({'x': 0, 'y': 0}), dataset)] + | xbeam.ValidateEachChunk() ) self.assertIn( - "Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not found in " - "Dataset dimensions", - e.exception.args[0] + "Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not " + "found in Dataset dimensions", + e.exception.args[0] ) def test_unmatched_variables_raises_error(self): @@ -304,9 +371,9 @@ def test_unmatched_variables_raises_error(self): | xbeam.ValidateEachChunk() ) self.assertIn( - "Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found in Dataset " - "data variables", - e.exception.args[0] + "Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found in " + "Dataset data variables", + e.exception.args[0] ) def test_validate_chunks_compose_in_pipeline(self): diff --git a/xarray_beam/_src/integration_test.py b/xarray_beam/_src/integration_test.py index e83b76b..4eb900c 100644 --- a/xarray_beam/_src/integration_test.py +++ b/xarray_beam/_src/integration_test.py @@ -37,6 +37,18 @@ class IntegrationTest(test_util.TestCase): 'template_method': 'eager', 'split_vars': True, }, + { + 'testcase_name': 'eager_unified_sharded', + 'template_method': 'eager', + 'split_vars': False, + 'shard_keys_threshold': 20, + }, + { + 'testcase_name': 'eager_split_sharded', + 'template_method': 'eager', + 'split_vars': True, + 'shard_keys_threshold': 20, + }, { 'testcase_name': 'lazy_unified', 'template_method': 'lazy', @@ -53,7 +65,9 @@ class IntegrationTest(test_util.TestCase): 'split_vars': True, }, ) - def test_rechunk_zarr_to_zarr(self, template_method, split_vars): + def test_rechunk_zarr_to_zarr( + self, template_method, split_vars, 
shard_keys_threshold=1_000_000, + ): src_dir = self.create_tempdir('source').full_path dest_dir = self.create_tempdir('destination').full_path @@ -83,7 +97,11 @@ def test_rechunk_zarr_to_zarr(self, template_method, split_vars): # run pipeline ( pipeline - | xbeam.DatasetToChunks(on_disk, split_vars=split_vars) + | xbeam.DatasetToChunks( + on_disk, + split_vars=split_vars, + shard_keys_threshold=shard_keys_threshold, + ) | xbeam.Rechunk( on_disk.sizes, source_chunks, target_chunks, itemsize=8, @@ -95,6 +113,59 @@ def test_rechunk_zarr_to_zarr(self, template_method, split_vars): xarray.testing.assert_identical(roundtripped, dataset) + @parameterized.named_parameters( + { + 'testcase_name': 'unified_unsharded', + 'split_vars': False, + 'shard_keys_threshold': 1_000_000, + }, + { + 'testcase_name': 'split_unsharded', + 'split_vars': True, + 'shard_keys_threshold': 1_000_000, + }, + { + 'testcase_name': 'unified_sharded', + 'split_vars': False, + 'shard_keys_threshold': 3, + }, + { + 'testcase_name': 'split_sharded', + 'split_vars': True, + 'shard_keys_threshold': 3, + }, + ) + def test_dataset_to_zarr_with_irregular_variables( + self, split_vars, shard_keys_threshold, + ): + dataset = xarray.Dataset( + { + 'volume1': ( + ('t', 'x', 'y', 'z1'), np.arange(240).reshape(10, 2, 3, 4) + ), + 'volume2': ( + ('t', 'x', 'y', 'z2'), np.arange(300).reshape(10, 2, 3, 5) + ), + 'surface': (('t', 'x', 'y'), np.arange(60).reshape(10, 2, 3)), + 'static': (('x', 'y'), np.arange(6).reshape(2, 3)), + } + ) + temp_dir = self.create_tempdir().full_path + template = dataset.chunk() + chunks = {'t': 1, 'z1': 2, 'z2': 3} + ( + test_util.EagerPipeline() + | xbeam.DatasetToChunks( + dataset, + chunks, + split_vars=split_vars, + shard_keys_threshold=shard_keys_threshold, + ) + | xbeam.ChunksToZarr(temp_dir, template, chunks) + ) + actual = xarray.open_zarr(temp_dir, consolidated=True) + xarray.testing.assert_identical(actual, dataset) + if __name__ == '__main__': absltest.main() diff --git a/xarray_beam/_src/test_util.py b/xarray_beam/_src/test_util.py index 7d45f55..dda730c 100644 --- a/xarray_beam/_src/test_util.py +++ b/xarray_beam/_src/test_util.py @@ -52,11 +52,12 @@ def __or__(self, ptransform): class TestCase(parameterized.TestCase): + """TestCase for use in internal Xarray-Beam tests.""" def _assert_chunks(self, array_assert_func, actual, expected): actual = dict(actual) expected = dict(expected) - self.assertEqual(list(expected), list(actual), msg='inconsistent keys') + self.assertCountEqual(expected, actual, msg='inconsistent keys') for key in expected: array_assert_func(actual[key], expected[key]) diff --git a/xarray_beam/_src/zarr_test.py b/xarray_beam/_src/zarr_test.py index 1090c2a..da0748b 100644 --- a/xarray_beam/_src/zarr_test.py +++ b/xarray_beam/_src/zarr_test.py @@ -140,14 +140,13 @@ def test_2d_chunks_to_zarr(self, coords): result = xarray.open_zarr(temp_dir, consolidated=True) xarray.testing.assert_identical(dataset, result) - def test_dataset_to_zarr(self): + def test_dataset_to_zarr_simple(self): dataset = xarray.Dataset( {'foo': ('x', np.arange(0, 60, 10))}, coords={'x': np.arange(6)}, attrs={'meta': 'data'}, ) chunked = dataset.chunk({'x': 3}) - temp_dir = self.create_tempdir().full_path ( test_util.EagerPipeline() @@ -156,6 +155,10 @@ def test_dataset_to_zarr(self): actual = xarray.open_zarr(temp_dir, consolidated=True) xarray.testing.assert_identical(actual, dataset) + def test_dataset_to_zarr_unchunked(self): + dataset = xarray.Dataset( + {'foo': ('x', np.arange(0, 60, 10))}, + ) temp_dir = 
self.create_tempdir().full_path with self.assertRaisesRegex( ValueError,
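The patch excerpt ends mid-assertion above. As a closing illustration (not part of the patch), the sketch below mirrors the new `test_dataset_to_chunks_split_with_different_dims` unit test using a plain Beam pipeline; the `print` step is invented for the example.

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

dataset = xarray.Dataset({
    'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
    'bar': ('x', np.array([1, 2])),
    'baz': ('z', np.array([1, 2, 3])),
})

with beam.Pipeline() as pipeline:
  (
      pipeline
      | xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True)
      | beam.MapTuple(lambda key, chunk: print(key, dict(chunk.sizes)))
  )

# Each emitted Key now carries only the dimensions present on its variable,
# e.g. Key(offsets={'x': 0, 'y': 0}, vars={'foo'}), Key(offsets={'x': 0},
# vars={'bar'}), and Key(offsets={'z': 0}, vars={'baz'}).
```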