forked from google/xarray-beam
Commit
Merge pull request google#31 from alxmrs:pangeo-fp
PiperOrigin-RevId: 398295725
Showing 3 changed files with 359 additions and 0 deletions.
setup.py
@@ -35,6 +35,9 @@
    'absl-py',
    'pandas',
    'pytest',
    'pangeo-forge-recipes',
    'scipy',
    'h5netcdf'
]

setuptools.setup(
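For context, here is a minimal sketch of how a test-dependency list like the one above is typically wired into `setuptools.setup`. The variable name and the `extras_require` hookup are assumptions for illustration; the hunk only shows the tail of the list and the start of the `setup()` call.

import setuptools

# Assumption: the hunk above extends a test-requirements list; the name
# `tests_require` and the extras wiring below are illustrative only.
tests_require = [
    'absl-py',
    'pandas',
    'pytest',
    'pangeo-forge-recipes',  # added by this commit
    'scipy',                 # added by this commit
    'h5netcdf',              # added by this commit
]

setuptools.setup(
    name='xarray-beam',
    extras_require={'tests': tests_require},
)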
xarray_beam/_src/pangeo_forge.py
@@ -0,0 +1,151 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""IO with Pangeo-Forge."""
import contextlib
import tempfile
from typing import (
    Dict,
    Iterator,
    Optional,
    Mapping,
    Tuple,
)

import apache_beam as beam
import fsspec
import xarray
from apache_beam.io.filesystems import FileSystems

from xarray_beam._src import core, rechunk


def _zero_dimensions(dataset: xarray.Dataset) -> Mapping[str, int]:
  return {dim: 0 for dim in dataset.dims.keys()}


def _expand_dimensions_by_key(
    dataset: xarray.Dataset,
    index: 'FilePatternIndex',
    pattern: 'FilePattern'
) -> xarray.Dataset:
  """Expand `Dataset` dimensions by key values found in the `FilePatternIndex`."""
  combine_dims_by_name = {
      combine_dim.name: combine_dim for combine_dim in pattern.combine_dims
  }
  index_by_name = {
      idx.name: idx for idx in index
  }

  if not combine_dims_by_name:
    return dataset

  for dim_key in index_by_name.keys():
    # Skip expanding dimensions if they already exist.
    if dim_key in dataset.dims:
      continue

    try:
      combine_dim = combine_dims_by_name[dim_key]
    except KeyError:
      raise ValueError(
          f"could not find CombineDim named {dim_key!r} in pattern {pattern!r}."
      )

    dim_val = combine_dim.keys[index_by_name[dim_key].index]
    dataset = dataset.expand_dims(**{dim_key: [dim_val]})

  return dataset


class FilePatternToChunks(beam.PTransform):
  """Open data described by a Pangeo-Forge `FilePattern` into keyed chunks."""

  from pangeo_forge_recipes.patterns import FilePattern, FilePatternIndex

  def __init__(
      self,
      pattern: 'FilePattern',
      chunks: Optional[Mapping[str, int]] = None,
      local_copy: bool = False,
      xarray_open_kwargs: Optional[Dict] = None
  ):
    """Initialize FilePatternToChunks.

    TODO(#29): Currently, `MergeDim`s are not supported.

    Args:
      pattern: a `FilePattern` describing a dataset.
      chunks: split each open dataset into smaller chunks. If not set, the
        transform will return one chunk per file.
      local_copy: open files from the pattern with local copies instead of a
        buffered reader.
      xarray_open_kwargs: keyword arguments to pass to `xarray.open_dataset()`.
    """
    self.pattern = pattern
    self.chunks = chunks
    self.local_copy = local_copy
    self.xarray_open_kwargs = xarray_open_kwargs or {}
    self._max_size_idx = {}

    if pattern.merge_dims:
      raise ValueError("patterns with `MergeDim`s are not supported.")

  @contextlib.contextmanager
  def _open_dataset(self, path: str) -> Iterator[xarray.Dataset]:
    """Open as an xarray Dataset, sometimes with local caching."""
    if self.local_copy:
      with tempfile.TemporaryDirectory() as tmpdir:
        local_file = fsspec.open_local(
            f"simplecache::{path}",
            simplecache={'cache_storage': tmpdir}
        )
        yield xarray.open_dataset(local_file, **self.xarray_open_kwargs)
    else:
      with FileSystems().open(path) as file:
        yield xarray.open_dataset(file, **self.xarray_open_kwargs)

  def _open_chunks(
      self,
      index: 'FilePatternIndex',
      path: str
  ) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
    """Open datasets into chunks with xarray."""
    with self._open_dataset(path) as dataset:
      dataset = _expand_dimensions_by_key(dataset, index, self.pattern)

      if not self._max_size_idx:
        self._max_size_idx = dataset.sizes

      base_key = core.Key(_zero_dimensions(dataset)).with_offsets(
          **{dim.name: self._max_size_idx[dim.name] * dim.index for dim in index}
      )

      num_threads = len(dataset.data_vars)

      # If chunks is not set by the user, treat the dataset as a single chunk.
      if self.chunks is None:
        yield base_key, dataset.compute(num_workers=num_threads)
        return

      for new_key, chunk in rechunk.split_chunks(base_key, dataset,
                                                 self.chunks):
        yield new_key, chunk.compute(num_workers=num_threads)

  def expand(self, pcoll):
    return (
        pcoll
        | beam.Create(list(self.pattern.items()))
        | beam.FlatMapTuple(self._open_chunks)
    )
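To make the new transform concrete, here is a minimal usage sketch. The bucket paths and time layout are hypothetical, and the final write step assumes xarray_beam's ChunksToZarr transform accepts a store path; none of this is part of the commit itself.

import apache_beam as beam
import xarray_beam as xbeam
from pangeo_forge_recipes.patterns import ConcatDim, FilePattern

from xarray_beam._src.pangeo_forge import FilePatternToChunks


def make_path(time: int) -> str:
  # Hypothetical layout: one NetCDF file per 30-day offset.
  return f'gs://my-bucket/era5/{time}.nc'


pattern = FilePattern(make_path, ConcatDim('time', list(range(0, 360, 30))))

with beam.Pipeline() as p:
  (
      p
      | FilePatternToChunks(pattern, chunks={'time': 10})  # re-chunk on read
      | xbeam.ChunksToZarr('gs://my-bucket/era5.zarr')
  )

Because `expand` starts from `beam.Create(self.pattern.items())`, the transform can be applied directly to the pipeline root, as the tests below do with `EagerPipeline`.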
xarray_beam/_src/pangeo_forge_test.py
@@ -0,0 +1,205 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_beam._src.pangeo_forge."""

import contextlib
import itertools
import tempfile
from typing import Dict, Iterator

import numpy as np
from absl.testing import parameterized
from pangeo_forge_recipes.patterns import (
    FilePattern,
    ConcatDim,
    DimIndex,
    CombineOp
)

from xarray_beam import split_chunks
from xarray_beam._src import core
from xarray_beam._src import test_util
from xarray_beam._src.pangeo_forge import (
    FilePatternToChunks,
    _expand_dimensions_by_key
)


class ExpandDimensionsByKeyTest(test_util.TestCase):

  def setUp(self):
    self.test_data = test_util.dummy_era5_surface_dataset()
    self.level = ConcatDim("level", list(range(91, 100)))
    self.pattern = FilePattern(lambda level: f"gs://dir/{level}.nc", self.level)

  def test_expands_dimensions(self):
    for i, (index, _) in enumerate(self.pattern.items()):
      actual = _expand_dimensions_by_key(
          self.test_data, index, self.pattern
      )

      expected_dims = dict(self.test_data.dims)
      expected_dims.update({"level": 1})

      self.assertEqual(expected_dims, dict(actual.dims))
      self.assertEqual(np.array([self.level.keys[i]]), actual["level"])

  def test_raises_error_when_dimension_is_not_found(self):
    index = (DimIndex('boat', 0, 1, CombineOp.CONCAT),)
    with self.assertRaisesRegex(ValueError, "boat"):
      _expand_dimensions_by_key(
          self.test_data, index, self.pattern
      )
||
class FilePatternToChunksTest(test_util.TestCase): | ||
|
||
def setUp(self): | ||
self.test_data = test_util.dummy_era5_surface_dataset() | ||
|
||
@contextlib.contextmanager | ||
def pattern_from_testdata(self) -> FilePattern: | ||
"""Produces a FilePattern for a temporary NetCDF file of test data.""" | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
target = f'{tmpdir}/era5.nc' | ||
self.test_data.to_netcdf(target) | ||
yield FilePattern(lambda: target) | ||
|
||
@contextlib.contextmanager | ||
def multifile_pattern( | ||
self, | ||
time_step: int = 479, | ||
longitude_step: int = 47 | ||
) -> FilePattern: | ||
"""Produces a FilePattern for a temporary NetCDF file of test data.""" | ||
time_dim = ConcatDim('time', list(range(0, 360 * 4, time_step))) | ||
longitude_dim = ConcatDim('longitude', list(range(0, 144, longitude_step))) | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
def make_path(time: int, longitude: int) -> str: | ||
return f'{tmpdir}/era5-{time}-{longitude}.nc' | ||
|
||
for time in time_dim.keys: | ||
for long in longitude_dim.keys: | ||
chunk = self.test_data.isel( | ||
time=slice(time, time + time_step), | ||
longitude=slice(long, long + longitude_step) | ||
) | ||
chunk.to_netcdf(make_path(time, long)) | ||
yield FilePattern(make_path, time_dim, longitude_dim) | ||
|
||

  def test_returns_single_dataset(self):
    expected = [
        (core.Key({"time": 0, "latitude": 0, "longitude": 0}), self.test_data)
    ]
    with self.pattern_from_testdata() as pattern:
      actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

    self.assertAllCloseChunks(actual, expected)

  def test_single_subchunks_returns_multiple_datasets(self):
    with self.pattern_from_testdata() as pattern:
      result = (
          test_util.EagerPipeline()
          | FilePatternToChunks(pattern, chunks={"longitude": 48})
      )

    expected = [
        (
            core.Key({"time": 0, "latitude": 0, "longitude": i}),
            self.test_data.isel(longitude=slice(i, i + 48))
        )
        for i in range(0, 144, 48)
    ]
    self.assertAllCloseChunks(result, expected)

  def test_multiple_subchunks_returns_multiple_datasets(self):
    with self.pattern_from_testdata() as pattern:
      result = (
          test_util.EagerPipeline()
          | FilePatternToChunks(pattern,
                                chunks={"longitude": 48, "latitude": 24})
      )

    expected = [
        (
            core.Key({"time": 0, "longitude": o, "latitude": a}),
            self.test_data.isel(longitude=slice(o, o + 48),
                                latitude=slice(a, a + 24))
        )
        for o, a in itertools.product(range(0, 144, 48), range(0, 73, 24))
    ]

    self.assertAllCloseChunks(result, expected)

  @parameterized.parameters(
      dict(time_step=479, longitude_step=47),
      dict(time_step=365, longitude_step=72),
      dict(time_step=292, longitude_step=71),
      dict(time_step=291, longitude_step=48),
  )
  def test_multiple_datasets_returns_multiple_datasets(
      self,
      time_step: int,
      longitude_step: int
  ):
    expected = [
        (
            core.Key({"time": t, "latitude": 0, "longitude": o}),
            self.test_data.isel(
                time=slice(t, t + time_step),
                longitude=slice(o, o + longitude_step)
            )
        ) for t, o in itertools.product(
            range(0, 360 * 4, time_step),
            range(0, 144, longitude_step)
        )
    ]
    with self.multifile_pattern(time_step, longitude_step) as pattern:
      actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

    self.assertAllCloseChunks(actual, expected)

  @parameterized.parameters(
      dict(time_step=365, longitude_step=72, chunks={"latitude": 36}),
      dict(time_step=365, longitude_step=72, chunks={"longitude": 36}),
      dict(time_step=365, longitude_step=72,
           chunks={"longitude": 36, "latitude": 66}),
  )
  def test_multiple_datasets_with_subchunks_returns_multiple_datasets(
      self,
      time_step: int,
      longitude_step: int,
      chunks: Dict[str, int],
  ):
    expected = []
    for t, o in itertools.product(range(0, 360 * 4, time_step),
                                  range(0, 144, longitude_step)):
      expected.extend(
          split_chunks(
              core.Key({"latitude": 0, "longitude": o, "time": t}),
              self.test_data.isel(
                  time=slice(t, t + time_step),
                  longitude=slice(o, o + longitude_step)
              ),
              chunks)
      )
    with self.multifile_pattern(time_step, longitude_step) as pattern:
      actual = test_util.EagerPipeline() | FilePatternToChunks(
          pattern,
          chunks=chunks
      )

    self.assertAllCloseChunks(actual, expected)
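For intuition about the dimension expansion exercised by `ExpandDimensionsByKeyTest`, here is a small standalone sketch of the `xarray.Dataset.expand_dims` call that `_expand_dimensions_by_key` performs for each `ConcatDim`. The tiny two-dimensional dataset is hypothetical; the real tests use dummy ERA5 data.

import numpy as np
import xarray

# A dataset that lacks a 'level' dimension.
ds = xarray.Dataset({'t2m': (('x', 'y'), np.zeros((2, 2)))})

# For each ConcatDim in a FilePatternIndex, _expand_dimensions_by_key adds a
# length-1 dimension whose single coordinate value is the dimension's key
# (here, 91), exactly as expand_dims does:
expanded = ds.expand_dims(level=[91])
assert dict(expanded.sizes) == {'level': 1, 'x': 2, 'y': 2}
assert expanded['level'].values.tolist() == [91]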