Merge pull request google#31 from alxmrs:pangeo-fp
PiperOrigin-RevId: 398295725
Xarray-Beam authors committed Sep 22, 2021
2 parents c5ab737 + 0ee8980 commit ebfdbf0
Showing 3 changed files with 359 additions and 0 deletions.
3 changes: 3 additions & 0 deletions setup.py
@@ -35,6 +35,9 @@
'absl-py',
'pandas',
'pytest',
'pangeo-forge-recipes',
'scipy',
'h5netcdf'
]

setuptools.setup(
151 changes: 151 additions & 0 deletions xarray_beam/_src/pangeo_forge.py
@@ -0,0 +1,151 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""IO with Pangeo-Forge."""
import contextlib
import tempfile
from typing import (
Dict,
Iterator,
Optional,
Mapping,
Tuple,
)

import apache_beam as beam
import fsspec
import xarray
from apache_beam.io.filesystems import FileSystems

from xarray_beam._src import core, rechunk


def _zero_dimensions(dataset: xarray.Dataset) -> Mapping[str, int]:
return {dim: 0 for dim in dataset.dims.keys()}


def _expand_dimensions_by_key(
dataset: xarray.Dataset,
index: 'FilePatternIndex',
pattern: 'FilePattern'
) -> xarray.Dataset:
"""Expand the dimensions of the `Dataset` by offsets found in the `Key`."""
combine_dims_by_name = {
combine_dim.name: combine_dim for combine_dim in pattern.combine_dims
}
index_by_name = {
idx.name: idx for idx in index
}

if not combine_dims_by_name:
return dataset

for dim_key in index_by_name.keys():
# skip expanding dimensions if they already exist
if dim_key in dataset.dims:
continue

try:
combine_dim = combine_dims_by_name[dim_key]
except KeyError:
raise ValueError(
f"could not find CombineDim named {dim_key!r} in pattern {pattern!r}."
)

dim_val = combine_dim.keys[index_by_name[dim_key].index]
dataset = dataset.expand_dims(**{dim_key: [dim_val]})

return dataset


class FilePatternToChunks(beam.PTransform):
"""Open data described by a Pangeo-Forge `FilePattern` into keyed chunks."""

from pangeo_forge_recipes.patterns import FilePattern, FilePatternIndex

def __init__(
self,
pattern: 'FilePattern',
chunks: Optional[Mapping[str, int]] = None,
local_copy: bool = False,
xarray_open_kwargs: Optional[Dict] = None
):
"""Initialize FilePatternToChunks.
TODO(#29): Currently, `MergeDim`s are not supported.
Args:
pattern: a `FilePattern` describing a dataset.
chunks: split each open dataset into smaller chunks. If not set, the
transform will return one file per chunk.
local_copy: Open files from the pattern with local copies instead of a
buffered reader.
xarray_open_kwargs: keyword arguments to pass to `xarray.open_dataset()`.
"""
self.pattern = pattern
self.chunks = chunks
self.local_copy = local_copy
self.xarray_open_kwargs = xarray_open_kwargs or {}
self._max_size_idx = {}

if pattern.merge_dims:
raise ValueError("patterns with `MergeDim`s are not supported.")

@contextlib.contextmanager
  def _open_dataset(self, path: str) -> Iterator[xarray.Dataset]:
"""Open as an XArray Dataset, sometimes with local caching."""
if self.local_copy:
with tempfile.TemporaryDirectory() as tmpdir:
local_file = fsspec.open_local(
f"simplecache::{path}",
simplecache={'cache_storage': tmpdir}
)
yield xarray.open_dataset(local_file, **self.xarray_open_kwargs)
else:
with FileSystems().open(path) as file:
yield xarray.open_dataset(file, **self.xarray_open_kwargs)

def _open_chunks(
self,
index: 'FilePatternIndex',
path: str
) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
"""Open datasets into chunks with XArray."""
with self._open_dataset(path) as dataset:

dataset = _expand_dimensions_by_key(dataset, index, self.pattern)

if not self._max_size_idx:
self._max_size_idx = dataset.sizes

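      # Offset the key along each pattern dimension by this file's index times
      # the per-file size recorded from the first opened dataset, so chunks
      # from different files do not collide.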
base_key = core.Key(_zero_dimensions(dataset)).with_offsets(
**{dim.name: self._max_size_idx[dim.name] * dim.index for dim in index}
)

num_threads = len(dataset.data_vars)

# If chunks is not set by the user, treat the dataset as a single chunk.
if self.chunks is None:
yield base_key, dataset.compute(num_workers=num_threads)
return

for new_key, chunk in rechunk.split_chunks(base_key, dataset,
self.chunks):
yield new_key, chunk.compute(num_workers=num_threads)

def expand(self, pcoll):
return (
pcoll
| beam.Create(list(self.pattern.items()))
| beam.FlatMapTuple(self._open_chunks)
)
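
For orientation, here is a minimal usage sketch of the new transform (not part of this commit). It assumes a hypothetical set of daily NetCDF files on GCS concatenated along a `time` dimension; the bucket path, file layout, and chunk size are made up, and the import uses the private module path added here since no public re-export appears in this diff.

import apache_beam as beam
from pangeo_forge_recipes.patterns import ConcatDim, FilePattern

from xarray_beam._src.pangeo_forge import FilePatternToChunks

# Hypothetical pattern: one NetCDF file per day, concatenated along 'time'.
days = list(range(31))
pattern = FilePattern(
    lambda time: f'gs://my-bucket/era5/day-{time:02d}.nc',  # hypothetical path
    ConcatDim('time', days),
)

with beam.Pipeline() as p:
  _ = (
      p
      | FilePatternToChunks(pattern, chunks={'time': 1})
      # Elements are (xarray_beam.Key, xarray.Dataset) pairs, one per chunk.
      | beam.MapTuple(lambda key, chunk: print(key, dict(chunk.sizes)))
  )

In practice the keyed chunks would typically feed a downstream transform such as xarray_beam.ChunksToZarr rather than a print.
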
205 changes: 205 additions & 0 deletions xarray_beam/_src/pangeo_forge_test.py
@@ -0,0 +1,205 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_beam._src.pangeo."""

import contextlib
import itertools
import tempfile
from typing import Dict

import numpy as np
from absl.testing import parameterized
from pangeo_forge_recipes.patterns import (
FilePattern,
ConcatDim,
DimIndex,
CombineOp
)

from xarray_beam import split_chunks
from xarray_beam._src import core
from xarray_beam._src import test_util
from xarray_beam._src.pangeo_forge import (
FilePatternToChunks,
_expand_dimensions_by_key
)


class ExpandDimensionsByKeyTest(test_util.TestCase):

def setUp(self):
self.test_data = test_util.dummy_era5_surface_dataset()
self.level = ConcatDim("level", list(range(91, 100)))
self.pattern = FilePattern(lambda level: f"gs://dir/{level}.nc", self.level)

def test_expands_dimensions(self):
for i, (index, _) in enumerate(self.pattern.items()):
actual = _expand_dimensions_by_key(
self.test_data, index, self.pattern
)

expected_dims = dict(self.test_data.dims)
expected_dims.update({"level": 1})

self.assertEqual(expected_dims, dict(actual.dims))
self.assertEqual(np.array([self.level.keys[i]]), actual["level"])

def test_raises_error_when_dataset_is_not_found(self):
index = (DimIndex('boat', 0, 1, CombineOp.CONCAT),)
with self.assertRaisesRegex(ValueError, "boat"):
_expand_dimensions_by_key(
self.test_data, index, self.pattern
)


class FilePatternToChunksTest(test_util.TestCase):

def setUp(self):
self.test_data = test_util.dummy_era5_surface_dataset()

@contextlib.contextmanager
def pattern_from_testdata(self) -> FilePattern:
"""Produces a FilePattern for a temporary NetCDF file of test data."""
with tempfile.TemporaryDirectory() as tmpdir:
target = f'{tmpdir}/era5.nc'
self.test_data.to_netcdf(target)
yield FilePattern(lambda: target)

@contextlib.contextmanager
def multifile_pattern(
self,
time_step: int = 479,
longitude_step: int = 47
) -> FilePattern:
"""Produces a FilePattern for a temporary NetCDF file of test data."""
time_dim = ConcatDim('time', list(range(0, 360 * 4, time_step)))
longitude_dim = ConcatDim('longitude', list(range(0, 144, longitude_step)))

with tempfile.TemporaryDirectory() as tmpdir:
def make_path(time: int, longitude: int) -> str:
return f'{tmpdir}/era5-{time}-{longitude}.nc'

for time in time_dim.keys:
for long in longitude_dim.keys:
chunk = self.test_data.isel(
time=slice(time, time + time_step),
longitude=slice(long, long + longitude_step)
)
chunk.to_netcdf(make_path(time, long))
yield FilePattern(make_path, time_dim, longitude_dim)

def test_returns_single_dataset(self):
expected = [
(core.Key({"time": 0, "latitude": 0, "longitude": 0}), self.test_data)
]
with self.pattern_from_testdata() as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

self.assertAllCloseChunks(actual, expected)

def test_single_subchunks_returns_multiple_datasets(self):
with self.pattern_from_testdata() as pattern:
result = (
test_util.EagerPipeline()
| FilePatternToChunks(pattern, chunks={"longitude": 48})
)

expected = [
(
core.Key({"time": 0, "latitude": 0, "longitude": i}),
self.test_data.isel(longitude=slice(i, i + 48))
)
for i in range(0, 144, 48)
]
self.assertAllCloseChunks(result, expected)

def test_multiple_subchunks_returns_multiple_datasets(self):
with self.pattern_from_testdata() as pattern:
result = (
test_util.EagerPipeline()
| FilePatternToChunks(pattern,
chunks={"longitude": 48, "latitude": 24})
)

expected = [
(
core.Key({"time": 0, "longitude": o, "latitude": a}),
self.test_data.isel(longitude=slice(o, o + 48),
latitude=slice(a, a + 24))
)
for o, a in itertools.product(range(0, 144, 48), range(0, 73, 24))
]

self.assertAllCloseChunks(result, expected)

@parameterized.parameters(
dict(time_step=479, longitude_step=47),
dict(time_step=365, longitude_step=72),
dict(time_step=292, longitude_step=71),
dict(time_step=291, longitude_step=48),
)
def test_multiple_datasets_returns_multiple_datasets(
self,
time_step: int,
longitude_step: int
):
expected = [
(
core.Key({"time": t, "latitude": 0, "longitude": o}),
self.test_data.isel(
time=slice(t, t + time_step),
longitude=slice(o, o + longitude_step)
)
) for t, o in itertools.product(
range(0, 360 * 4, time_step),
range(0, 144, longitude_step)
)
]
with self.multifile_pattern(time_step, longitude_step) as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

self.assertAllCloseChunks(actual, expected)

@parameterized.parameters(
dict(time_step=365, longitude_step=72, chunks={"latitude": 36}),
dict(time_step=365, longitude_step=72, chunks={"longitude": 36}),
dict(time_step=365, longitude_step=72,
chunks={"longitude": 36, "latitude": 66}),
)
def test_multiple_datasets_with_subchunks_returns_multiple_datasets(
self,
time_step: int,
longitude_step: int,
chunks: Dict[str, int],
):

expected = []
for t, o in itertools.product(range(0, 360 * 4, time_step),
range(0, 144, longitude_step)):
expected.extend(
split_chunks(
core.Key({"latitude": 0, "longitude": o, "time": t}),
self.test_data.isel(
time=slice(t, t + time_step),
longitude=slice(o, o + longitude_step)
),
chunks)
)
with self.multifile_pattern(time_step, longitude_step) as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(
pattern,
chunks=chunks
)

self.assertAllCloseChunks(actual, expected)
