Support datasets with differently chunked variables in DatasetToChunks
There are two major internal changes:
1. Key objects from DatasetToChunks can now include different dimensions for different variables when using split_vars=True. This makes it easier to handle large datasets with many variables and different chunking per variable.
2. Inputs inside the DatasetToChunks pipeline can now be sharded across many tasks. This is important for scalability to large datasets, especially with this change, because the above refactor multiplies the number of inputs by up to the number of variables when split_vars=True. Otherwise, we can run into performance issues on the machine launching the pipeline when the number of inputs goes into the millions (e.g., slow speed, out of memory).

See the new integration test for a concrete use-case, resembling real model output.
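
For illustration, here is a minimal sketch of the kind of pipeline this change is aimed at. The dataset, variable names, and chunk sizes below are hypothetical and not taken from the integration test:

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# Hypothetical dataset resembling model output: variables with different
# dimensions (3D, 2D, and 1D) in a single Dataset.
dataset = xarray.Dataset({
    'temperature': (('time', 'lat', 'lon'), np.zeros((24, 73, 144))),
    'orography': (('lat', 'lon'), np.zeros((73, 144))),
    'station_count': ('time', np.zeros(24)),
})

with beam.Pipeline() as p:
    _ = (
        p
        # With split_vars=True, each emitted Key only carries offsets for the
        # dimensions of its own variable, so 'orography' chunks have no 'time'
        # offset. shard_keys_threshold is shown only to illustrate the new
        # argument; the default (200_000) is fine for a dataset this small.
        | xbeam.DatasetToChunks(
            dataset,
            chunks={'time': 6},
            split_vars=True,
            shard_keys_threshold=200_000,
        )
        | beam.MapTuple(lambda key, chunk: print(key, dict(chunk.sizes)))
    )
```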

Also revise the warning message in the README to be a bit friendlier.

Fixes #43

PiperOrigin-RevId: 471347485
shoyer authored and Xarray-Beam authors committed Sep 1, 2022
1 parent 4dced4e commit b607f99
Showing 8 changed files with 281 additions and 48 deletions.
22 changes: 12 additions & 10 deletions README.md
@@ -20,20 +20,22 @@ multi-dimensional labeled arrays, such as:
For more about our approach and how to get started,
**[read the documentation](https://xarray-beam.readthedocs.io/)**!

**🚨 Warning: Xarray-Beam is new and unpolished 🚨**
**Warning: Xarray-Beam is a sharp tool 🔪**

Expect sharp edges 🔪 and performance cliffs 🧗, particularly related to the
management of lazy data with Dask and reading/writing data with Zarr. We have
used it to efficiently process ~25 TB datasets. We _expect_ it to scale to PB
size datasets, but that's easier said than done. We welcome feedback and
contributions from early adopters, and hope to have it ready for wider audience
soon.
Xarray-Beam is relatively new, and focused on expert users:

- We use it extensively at Google for processing large-scale weather datasets,
but there is not yet a vibrant external community. It is relatively stable,
but we are still refining parts of its API.
- It provides low-level abstractions that facilitate writing very large
scale data pipelines (e.g., 100+ TB), but by design it requires explicitly
thinking about how every operation is parallelized.

## Installation

Xarray-Beam requires recent versions of immutabledict, xarray, dask, rechunker
and zarr, and the *latest* release of Apache Beam (2.31.0 or later). For best
performance when writing Zarr files, use Xarray 0.19.0 or later.
Xarray-Beam requires recent versions of immutabledict, Xarray, Dask, Rechunker,
Zarr, and Apache Beam. For best performance when writing Zarr files, use Xarray
0.19.0 or later.

## Disclaimer

2 changes: 1 addition & 1 deletion setup.py
@@ -42,7 +42,7 @@

setuptools.setup(
name='xarray-beam',
version='0.3.1',
version='0.4.0',
license='Apache 2.0',
author='Google LLC',
author_email='[email protected]',
2 changes: 1 addition & 1 deletion xarray_beam/__init__.py
@@ -43,4 +43,4 @@
)
from xarray_beam import Mean

__version__ = '0.3.1'
__version__ = '0.4.0'
137 changes: 116 additions & 21 deletions xarray_beam/_src/core.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Core data model for xarray-beam."""
import itertools
import math
from typing import (
AbstractSet,
Container,
@@ -196,16 +197,16 @@ def _chunks_to_offsets(


def iter_chunk_keys(
chunks: Mapping[str, Tuple[int, ...]],
offsets: Mapping[str, Sequence[int]],
vars: Optional[AbstractSet[str]] = None, # pylint: disable=redefined-builtin
) -> Iterator[Key]:
"""Iterate over the Key objects corresponding to the given chunks."""
all_offsets = _chunks_to_offsets(chunks)
chunk_indices = [range(len(sizes)) for sizes in chunks.values()]
chunk_indices = [range(len(sizes)) for sizes in offsets.values()]
for indices in itertools.product(*chunk_indices):
offsets = {
dim: all_offsets[dim][index] for dim, index in zip(chunks, indices)
key_offsets = {
dim: offsets[dim][index] for dim, index in zip(offsets, indices)
}
yield Key(offsets)
yield Key(key_offsets, vars)


def compute_offset_index(
@@ -262,6 +263,7 @@ def __init__(
chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None,
split_vars: bool = False,
num_threads: Optional[int] = None,
shard_keys_threshold: int = 200_000,
):
"""Initialize DatasetToChunks.
@@ -271,44 +273,137 @@ def __init__(
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
precedence over the existing chunks.
split_vars: whether to split the dataset into separate records for each
data variables or to keep all data variables together.
data variable or to keep all data variables together.
num_threads: optional number of Dataset chunks to load in parallel per
worker. More threads can increase throughput, but also increases memory
usage and makes it harder for Beam runners to shard work. Note that each
variable in a Dataset is already loaded in parallel, so this is most
useful for Datasets with a small number of variables.
shard_keys_threshold: threshold at which to compute keys on Beam workers,
rather than only on the host process. This is important for scaling
pipelines to datasets with a very large number of chunks.
"""
if chunks is None:
chunks = dataset.chunks
if chunks is None:
raise ValueError('dataset must be chunked or chunks must be set')
chunks = normalize_expanded_chunks(chunks, dataset.sizes)
raise ValueError('dataset must be chunked or chunks must be provided')
expanded_chunks = normalize_expanded_chunks(chunks, dataset.sizes)
self.dataset = dataset
self.chunks = chunks
self.expanded_chunks = expanded_chunks
self.split_vars = split_vars
self.num_threads = num_threads
self.offset_index = compute_offset_index(_chunks_to_offsets(chunks))
self.shard_keys_threshold = shard_keys_threshold
# TODO(shoyer): consider recalculating these potentially large properties on
# each worker, rather than only once on the host.
self.offsets = _chunks_to_offsets(expanded_chunks)
self.offset_index = compute_offset_index(self.offsets)
# We use the simple heuristic of only sharding inputs along the
# dimension with the most chunks.
lengths = {k: len(v) for k, v in self.offsets.items()}
self.sharded_dim = max(lengths, key=lengths.get) if lengths else None
self.shard_count = self._shard_count()

def _task_count(self) -> int:
"""Count the number of tasks emitted by this transform."""
counts = {k: len(v) for k, v in self.expanded_chunks.items()}
if not self.split_vars:
return int(np.prod(list(counts.values())))
total = 0
for variable in self.dataset.values():
count_list = [v for k, v in counts.items() if k in variable.dims]
total += int(np.prod(count_list))
return total

def _shard_count(self) -> Optional[int]:
"""Determine the number of times to shard input keys."""
task_count = self._task_count()
if task_count <= self.shard_keys_threshold:
return None # no sharding

if not self.split_vars:
return math.ceil(task_count / self.shard_keys_threshold)

var_count = sum(
self.sharded_dim in var.dims for var in self.dataset.values()
)
return math.ceil(task_count / (var_count * self.shard_keys_threshold))

def _iter_all_keys(self) -> Iterator[Key]:
"""Iterate over all Key objects."""
if not self.split_vars:
yield from iter_chunk_keys(self.offsets)
else:
for name, variable in self.dataset.items():
relevant_offsets = {
k: v for k, v in self.offsets.items() if k in variable.dims
}
yield from iter_chunk_keys(relevant_offsets, vars={name})

def _iter_shard_keys(
self, shard_id: Optional[int], var_name: Optional[str]
) -> Iterator[Key]:
"""Iterate over Key objects for a specific shard and variable."""
if var_name is None:
offsets = self.offsets
else:
offsets = {
dim: self.offsets[dim] for dim in self.dataset[var_name].dims
}

if shard_id is None:
assert self.split_vars
yield from iter_chunk_keys(offsets, vars={var_name})
else:
assert self.split_vars == (var_name is not None)
dim = self.sharded_dim
count = math.ceil(len(self.offsets[dim]) / self.shard_count)
dim_slice = slice(shard_id * count, (shard_id + 1) * count)
offsets = {**offsets, dim: offsets[dim][dim_slice]}
vars_ = {var_name} if self.split_vars else None
yield from iter_chunk_keys(offsets, vars=vars_)

def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
"""Create inputs for sharded key iterators."""
if not self.split_vars:
return [(i, None) for i in range(self.shard_count)]

inputs = []
for name, variable in self.dataset.items():
if self.sharded_dim in variable.dims:
inputs.extend([(i, name) for i in range(self.shard_count)])
else:
inputs.append((None, name))
return inputs

def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, xarray.Dataset]]:
"""Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
sizes = {
dim: self.chunks[dim][self.offset_index[dim][offset]]
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
for dim, offset in key.offsets.items()
}
slices = offsets_to_slices(key.offsets, sizes)
chunk = self.dataset.isel(slices)
dataset = self.dataset if key.vars is None else self.dataset[list(key.vars)]
chunk = dataset.isel(slices)
# Load the data, using a separate thread for each variable
num_threads = len(self.dataset.data_vars)
num_threads = len(self.dataset)
result = chunk.chunk().compute(num_workers=num_threads)
if self.split_vars:
for k in result:
yield key.replace(vars={k}), result[[k]]
else:
yield key, result
yield key, result

def expand(self, pcoll):
if self.shard_count is None:
# Create all keys on the machine launching the Beam pipeline. This is
# faster if the number of keys is small.
key_pcoll = pcoll | beam.Create(self._iter_all_keys())
else:
# Create keys in separate shards on Beam workers. This is more scalable.
key_pcoll = (
pcoll
| beam.Create(self._shard_inputs())
| beam.FlatMapTuple(self._iter_shard_keys)
| beam.Reshuffle()
)

return (
pcoll
| beam.Create(iter_chunk_keys(self.chunks))
key_pcoll
| threadmap.FlatThreadMap(
self._key_to_chunks, num_threads=self.num_threads
)
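As a rough illustration of the sharding heuristic above (the numbers are made up, not taken from this commit): without split_vars, the task count is the product of the per-dimension chunk counts, and once it exceeds shard_keys_threshold, key generation is split into shards along the dimension with the most chunks:

```python
import math

# Hypothetical chunk counts per dimension: 5000 time chunks x 10 x 10.
chunk_counts = {'time': 5000, 'lat': 10, 'lon': 10}
shard_keys_threshold = 200_000  # the default in DatasetToChunks

# Without split_vars, every combination of per-dimension chunks is one task.
task_count = math.prod(chunk_counts.values())  # 500_000

if task_count <= shard_keys_threshold:
    shard_count = None  # small enough: create all keys on the host
else:
    # Keys are generated in shards along the dimension with the most chunks
    # ('time' here), so the launching machine never holds every key at once.
    shard_count = math.ceil(task_count / shard_keys_threshold)  # 3

print(task_count, shard_count)  # 500000 3
```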
87 changes: 77 additions & 10 deletions xarray_beam/_src/core_test.py
@@ -14,6 +14,7 @@
"""Tests for xarray_beam._src.core."""

from absl.testing import absltest
from absl.testing import parameterized
import apache_beam as beam
import immutabledict
import numpy as np
@@ -32,7 +33,7 @@ def test_constructor(self):
key = xbeam.Key({'x': 0, 'y': 10})
self.assertIsInstance(key.offsets, immutabledict.immutabledict)
self.assertEqual(dict(key.offsets), {'x': 0, 'y': 10})
self.assertEqual(key.vars, None)
self.assertIsNone(key.vars)

key = xbeam.Key(vars={'foo'})
self.assertEqual(dict(key.offsets), {})
@@ -181,7 +182,7 @@ def test_offsets_to_slices_base(self):
class DatasetToChunksTest(test_util.TestCase):

def test_iter_chunk_keys(self):
actual = list(core.iter_chunk_keys({'x': (3, 3), 'y': (2, 2, 2)}))
actual = list(core.iter_chunk_keys({'x': (0, 3), 'y': (0, 2, 4)}))
expected = [
xbeam.Key({'x': 0, 'y': 0}),
xbeam.Key({'x': 0, 'y': 2}),
@@ -248,6 +249,12 @@ def test_dataset_to_chunks_multiple(self):
)
self.assertIdenticalChunks(actual, expected)

actual = (
test_util.EagerPipeline()
| xbeam.DatasetToChunks(dataset.chunk({'x': 3}), shard_keys_threshold=1)
)
self.assertIdenticalChunks(actual, expected)

def test_dataset_to_chunks_whole(self):
dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
expected = [(xbeam.Key({'x': 0}), dataset)]
@@ -261,6 +268,7 @@ def test_dataset_to_chunks_whole(self):
test_util.EagerPipeline()
| xbeam.DatasetToChunks(dataset, chunks={})
)
expected = [(xbeam.Key({'x': 0}), dataset)]
self.assertIdenticalChunks(actual, expected)

def test_dataset_to_chunks_vars(self):
@@ -280,20 +288,79 @@ def test_dataset_to_chunks_vars(self):
)
self.assertIdenticalChunks(actual, expected)

@parameterized.parameters(
{'shard_keys_threshold': 1},
{'shard_keys_threshold': 2},
{'shard_keys_threshold': 10},
)
def test_dataset_to_chunks_split_with_different_dims(
self, shard_keys_threshold
):
dataset = xarray.Dataset({
'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
'bar': ('x', np.array([1, 2])),
'baz': ('z', np.array([1, 2, 3])),
})
expected = [
(xbeam.Key({'x': 0, 'y': 0}, {'foo'}), dataset[['foo']].head(x=1)),
(xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']].head(x=1)),
(xbeam.Key({'x': 1, 'y': 0}, {'foo'}), dataset[['foo']].tail(x=1)),
(xbeam.Key({'x': 1}, {'bar'}), dataset[['bar']].tail(x=1)),
(xbeam.Key({'z': 0}, {'baz'}), dataset[['baz']]),
]
actual = (
test_util.EagerPipeline()
| xbeam.DatasetToChunks(
dataset,
chunks={'x': 1},
split_vars=True,
shard_keys_threshold=shard_keys_threshold,
)
)
self.assertIdenticalChunks(actual, expected)

def test_dataset_to_chunks_empty(self):
dataset = xarray.Dataset()
expected = [(xbeam.Key({}), dataset)]
actual = (
test_util.EagerPipeline()
| xbeam.DatasetToChunks(dataset)
)
self.assertIdenticalChunks(actual, expected)

def test_task_count(self):
dataset = xarray.Dataset({
'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])),
'bar': ('x', np.array([1, 2])),
'baz': ('z', np.array([1, 2, 3, 4, 5])),
})

to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1})
self.assertEqual(to_chunks._task_count(), 2)

to_chunks = xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True)
self.assertEqual(to_chunks._task_count(), 5)

to_chunks = xbeam.DatasetToChunks(dataset, chunks={'y': 1}, split_vars=True)
self.assertEqual(to_chunks._task_count(), 5)

to_chunks = xbeam.DatasetToChunks(dataset, chunks={'z': 1}, split_vars=True)
self.assertEqual(to_chunks._task_count(), 7)


class ValidateEachChunkTest(test_util.TestCase):

def test_unmatched_dimension_raises_error(self):
dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
with self.assertRaises(ValueError) as e:
(
[(xbeam.Key({'x': 0, 'y': 0}), dataset)]
| xbeam.ValidateEachChunk()
[(xbeam.Key({'x': 0, 'y': 0}), dataset)]
| xbeam.ValidateEachChunk()
)
self.assertIn(
"Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not found in "
"Dataset dimensions",
e.exception.args[0]
"Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not "
"found in Dataset dimensions",
e.exception.args[0]
)

def test_unmatched_variables_raises_error(self):
Expand All @@ -304,9 +371,9 @@ def test_unmatched_variables_raises_error(self):
| xbeam.ValidateEachChunk()
)
self.assertIn(
"Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found in Dataset "
"data variables",
e.exception.args[0]
"Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found in "
"Dataset data variables",
e.exception.args[0]
)

def test_validate_chunks_compose_in_pipeline(self):
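One more downstream sketch (not part of this commit): because each Key records which variable it belongs to when split_vars=True, later stages can branch on key.vars without inspecting the chunk itself. The variable names here are hypothetical:

```python
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

dataset = xarray.Dataset({
    'foo': (('x', 'y'), np.arange(6).reshape(2, 3)),
    'bar': ('x', np.arange(2)),
})

with beam.Pipeline() as p:
    _ = (
        p
        | xbeam.DatasetToChunks(dataset, chunks={'x': 1}, split_vars=True)
        # Keep only the chunks for 'foo'; key.vars is a set like {'foo'}.
        | beam.Filter(lambda key_and_chunk: 'foo' in key_and_chunk[0].vars)
        | beam.MapTuple(lambda key, chunk: print(key, chunk))
    )
```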