Skip to content

Commit

Permalink
Merge pull request #110 from roocs/i106-enable-read-kerchunk
Browse files Browse the repository at this point in the history
I106 enable read kerchunk
  • Loading branch information
cehbrecht authored Nov 29, 2023
2 parents 1af1554 + bd57eef commit 7f616f9
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 6 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Version History
===============

Next
----
* Updated `xarray_utils` module to support reading `kerchunk` files (#106).

v0.6.5 (2023-11-09)
-------------------

Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiohttp
cftime
numpy>=1.16
xarray>=0.15
Expand All @@ -6,3 +7,6 @@ netCDF4>=1.4
python-dateutil>=2.8.1
cf-xarray>=0.3.1,<=0.8.4; python_version == '3.8'
cf-xarray>=0.3.1; python_version >= '3.9'
fsspec
zarr
zstandard
54 changes: 51 additions & 3 deletions roocs_utils/xarray_utils/xarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import cftime
import numpy as np
import xarray as xr
import fsspec

from roocs_utils.project_utils import dset_to_filepaths

known_coord_types = ["time", "level", "latitude", "longitude", "realization"]

KERCHUNK_EXTS = [".json", ".zst", ".zstd"]


def _patch_time_encoding(ds, file_list, **kwargs):
"""
Expand Down Expand Up @@ -66,6 +69,45 @@ def _get_kwargs_for_opener(otype, **kwargs):
return args


def is_kerchunk_file(dset):
"""
Returns a boolean based on reading the file extension.
"""
if not isinstance(dset, str):
return False

return os.path.splitext(dset)[-1] in KERCHUNK_EXTS


def _open_as_kerchunk(dset, **kwargs):
"""
Open the dataset `dset` as a Kerchunk file. Return an Xarray Dataset.
"""
compression = (
"zstd"
if dset.split(".")[-1].startswith("zst")
else kwargs.get("compression", None)
)
remote_options = kwargs.get("remote_options", {})
remote_protocol = kwargs.get("remote_protocol", None)

mapper = fsspec.get_mapper(
"reference://",
fo=dset,
target_options={"compression": compression},
remote_options=remote_options,
remote_protocol=remote_protocol,
)

# Create a copy of kwargs and remove mapper-specific values
kw = kwargs.copy()
for key in ("compression", "remote_options", "remote_protocol"):
if key in kw:
del kw[key]

return xr.open_zarr(mapper, consolidated=False, **kw)


def open_xr_dataset(dset, **kwargs):
"""
Opens an xarray dataset from a dataset input.
Expand All @@ -80,10 +122,16 @@ def open_xr_dataset(dset, **kwargs):
single_file_kwargs = _get_kwargs_for_opener("single", **kwargs)
multi_file_kwargs = _get_kwargs_for_opener("multi", **kwargs)

# Force the value of dset to be a list if not a list or tuple
# Assume that a JSON or ZST/ZSTD file is kerchunk
if type(dset) not in (list, tuple):
# use force=True to allow all file paths to pass through DatasetMapper
dset = dset_to_filepaths(dset, force=True)
# Assume that a JSON or ZST/ZSTD file is kerchunk
if is_kerchunk_file(dset):
return _open_as_kerchunk(dset, **single_file_kwargs)

else:
# Force the value of dset to be a list if not a list or tuple
# use force=True to allow all file paths to pass through DatasetMapper
dset = dset_to_filepaths(dset, force=True)

# If an empty sequence, then raise an Exception
if len(dset) == 0:
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
"master/test_data/pool/data/CORDEX/data/cordex/output/AFR-22/GERICS/MPI-M-MPI-ESM-LR/historical/r1i1p1/GERICS-REMO2015/v1/day/tas/v20201015/*.nc",
).as_posix()

CMIP6_KERCHUNK_HTTPS_OPEN_JSON = (
"https://gws-access.jasmin.ac.uk/public/cmip6_prep/eodh-eocis/kc-indexes-cmip6-http-v1/"
"CMIP6.CMIP.MOHC.UKESM1-1-LL.1pctCO2.r1i1p1f2.Amon.tasmax.gn.v20220513.json"
)
CMIP6_KERCHUNK_HTTPS_OPEN_ZST = CMIP6_KERCHUNK_HTTPS_OPEN_JSON + ".zst"


@pytest.fixture
def load_test_data():
Expand Down
4 changes: 3 additions & 1 deletion tests/test_xarray_utils/test_get_main_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from tests.conftest import CMIP5_TAS


@pytest.mark.skipif(os.path.isdir("/gws") is False, reason="data not available")
@pytest.mark.skipif(
os.path.isdir("/gws/nopw/j04/cp4cds1_vol1") is False, reason="data not available"
)
def test_get_main_var():
data = (
"/gws/nopw/j04/cp4cds1_vol1/data"
Expand Down
36 changes: 34 additions & 2 deletions tests/test_xarray_utils/test_open_xr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import xarray as xr

from roocs_utils.xarray_utils.xarray_utils import open_xr_dataset
from tests.conftest import C3S_CMIP5_TAS
from tests.conftest import CMIP5_TAS_EC_EARTH
from tests.conftest import (
C3S_CMIP5_TAS,
CMIP5_TAS_EC_EARTH,
CMIP6_KERCHUNK_HTTPS_OPEN_JSON,
CMIP6_KERCHUNK_HTTPS_OPEN_ZST,
)


def test_open_xr_dataset(load_test_data):
Expand All @@ -27,3 +31,31 @@ def test_open_xr_dataset_retains_time_encoding(load_test_data):
kwargs = {"use_cftime": True, "decode_timedelta": False, "combine": "by_coords"}
ds = xr.open_mfdataset(glob.glob(dset), **kwargs)
assert ds.time.encoding == {}


def _common_test_open_xr_dataset_kerchunk(uri):
ds = open_xr_dataset(uri)
assert isinstance(ds, xr.Dataset)
assert "tasmax" in ds

# Also test time encoding is retained
assert hasattr(ds, "time")
assert ds.time.encoding.get("units") == "days since 1850-01-01"

return ds


def test_open_xr_dataset_kerchunk_json(load_test_data):
ds = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_JSON)


def test_open_xr_dataset_kerchunk_zst(load_test_data):
ds = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_ZST)


def test_open_xr_dataset_kerchunk_compare_json_vs_zst(load_test_data):
ds1 = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_JSON)
ds2 = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_ZST)

diff = ds1.isel(time=slice(0, 2)) - ds2.isel(time=slice(0, 2))
assert diff.max() == diff.min() == 0.0

0 comments on commit 7f616f9

Please sign in to comment.