diff --git a/HISTORY.rst b/HISTORY.rst index 2203057..673527b 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -1,6 +1,10 @@ Version History =============== +Next +---- +* Updated `xarray_utils` module to support reading `kerchunk` files (#106). + v0.6.5 (2023-11-09) ------------------- diff --git a/requirements.txt b/requirements.txt index f51cbb4..2488887 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiohttp cftime numpy>=1.16 xarray>=0.15 @@ -6,3 +7,6 @@ netCDF4>=1.4 python-dateutil>=2.8.1 cf-xarray>=0.3.1,<=0.8.4; python_version == '3.8' cf-xarray>=0.3.1; python_version >= '3.9' +fsspec +zarr +zstandard diff --git a/roocs_utils/xarray_utils/xarray_utils.py b/roocs_utils/xarray_utils/xarray_utils.py index 231026c..d3fb6c2 100644 --- a/roocs_utils/xarray_utils/xarray_utils.py +++ b/roocs_utils/xarray_utils/xarray_utils.py @@ -7,11 +7,14 @@ import cftime import numpy as np import xarray as xr +import fsspec from roocs_utils.project_utils import dset_to_filepaths known_coord_types = ["time", "level", "latitude", "longitude", "realization"] +KERCHUNK_EXTS = [".json", ".zst", ".zstd"] + def _patch_time_encoding(ds, file_list, **kwargs): """ @@ -66,6 +69,45 @@ def _get_kwargs_for_opener(otype, **kwargs): return args +def is_kerchunk_file(dset): + """ + Returns a boolean based on reading the file extension. + """ + if not isinstance(dset, str): + return False + + return os.path.splitext(dset)[-1] in KERCHUNK_EXTS + + +def _open_as_kerchunk(dset, **kwargs): + """ + Open the dataset `dset` as a Kerchunk file. Return an Xarray Dataset. 
+ """ + compression = ( + "zstd" + if dset.split(".")[-1].startswith("zst") + else kwargs.get("compression", None) + ) + remote_options = kwargs.get("remote_options", {}) + remote_protocol = kwargs.get("remote_protocol", None) + + mapper = fsspec.get_mapper( + "reference://", + fo=dset, + target_options={"compression": compression}, + remote_options=remote_options, + remote_protocol=remote_protocol, + ) + + # Create a copy of kwargs and remove mapper-specific values + kw = kwargs.copy() + for key in ("compression", "remote_options", "remote_protocol"): + if key in kw: + del kw[key] + + return xr.open_zarr(mapper, consolidated=False, **kw) + + def open_xr_dataset(dset, **kwargs): """ Opens an xarray dataset from a dataset input. @@ -80,10 +122,16 @@ def open_xr_dataset(dset, **kwargs): single_file_kwargs = _get_kwargs_for_opener("single", **kwargs) multi_file_kwargs = _get_kwargs_for_opener("multi", **kwargs) - # Force the value of dset to be a list if not a list or tuple + # A single (non-sequence) input is either a kerchunk file or must be resolved to file paths if type(dset) not in (list, tuple): - # use force=True to allow all file paths to pass through DatasetMapper - dset = dset_to_filepaths(dset, force=True) + # Assume that a JSON or ZST/ZSTD file is kerchunk + if is_kerchunk_file(dset): + return _open_as_kerchunk(dset, **single_file_kwargs) + + else: + # Force the value of dset to be a list if not a list or tuple + # use force=True to allow all file paths to pass through DatasetMapper + dset = dset_to_filepaths(dset, force=True) # If an empty sequence, then raise an Exception if len(dset) == 0: diff --git a/tests/conftest.py b/tests/conftest.py index 2b95d44..3affeef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,6 +42,12 @@ "master/test_data/pool/data/CORDEX/data/cordex/output/AFR-22/GERICS/MPI-M-MPI-ESM-LR/historical/r1i1p1/GERICS-REMO2015/v1/day/tas/v20201015/*.nc", ).as_posix() +CMIP6_KERCHUNK_HTTPS_OPEN_JSON = ( + 
"https://gws-access.jasmin.ac.uk/public/cmip6_prep/eodh-eocis/kc-indexes-cmip6-http-v1/" + "CMIP6.CMIP.MOHC.UKESM1-1-LL.1pctCO2.r1i1p1f2.Amon.tasmax.gn.v20220513.json" +) +CMIP6_KERCHUNK_HTTPS_OPEN_ZST = CMIP6_KERCHUNK_HTTPS_OPEN_JSON + ".zst" + @pytest.fixture def load_test_data(): diff --git a/tests/test_xarray_utils/test_get_main_var.py b/tests/test_xarray_utils/test_get_main_var.py index 809c144..3b9b6c5 100644 --- a/tests/test_xarray_utils/test_get_main_var.py +++ b/tests/test_xarray_utils/test_get_main_var.py @@ -7,7 +7,9 @@ from tests.conftest import CMIP5_TAS -@pytest.mark.skipif(os.path.isdir("/gws") is False, reason="data not available") +@pytest.mark.skipif( + os.path.isdir("/gws/nopw/j04/cp4cds1_vol1") is False, reason="data not available" +) def test_get_main_var(): data = ( "/gws/nopw/j04/cp4cds1_vol1/data" diff --git a/tests/test_xarray_utils/test_open_xr_dataset.py b/tests/test_xarray_utils/test_open_xr_dataset.py index 79099aa..540e9a0 100644 --- a/tests/test_xarray_utils/test_open_xr_dataset.py +++ b/tests/test_xarray_utils/test_open_xr_dataset.py @@ -5,8 +5,12 @@ import xarray as xr from roocs_utils.xarray_utils.xarray_utils import open_xr_dataset -from tests.conftest import C3S_CMIP5_TAS -from tests.conftest import CMIP5_TAS_EC_EARTH +from tests.conftest import ( + C3S_CMIP5_TAS, + CMIP5_TAS_EC_EARTH, + CMIP6_KERCHUNK_HTTPS_OPEN_JSON, + CMIP6_KERCHUNK_HTTPS_OPEN_ZST, +) def test_open_xr_dataset(load_test_data): @@ -27,3 +31,31 @@ def test_open_xr_dataset_retains_time_encoding(load_test_data): kwargs = {"use_cftime": True, "decode_timedelta": False, "combine": "by_coords"} ds = xr.open_mfdataset(glob.glob(dset), **kwargs) assert ds.time.encoding == {} + + +def _common_test_open_xr_dataset_kerchunk(uri): + ds = open_xr_dataset(uri) + assert isinstance(ds, xr.Dataset) + assert "tasmax" in ds + + # Also test time encoding is retained + assert hasattr(ds, "time") + assert ds.time.encoding.get("units") == "days since 1850-01-01" + + return ds + + +def 
test_open_xr_dataset_kerchunk_json(load_test_data): + ds = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_JSON) + + +def test_open_xr_dataset_kerchunk_zst(load_test_data): + ds = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_ZST) + + +def test_open_xr_dataset_kerchunk_compare_json_vs_zst(load_test_data): + ds1 = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_JSON) + ds2 = _common_test_open_xr_dataset_kerchunk(CMIP6_KERCHUNK_HTTPS_OPEN_ZST) + + diff = ds1.isel(time=slice(0, 2)) - ds2.isel(time=slice(0, 2)) + assert diff.max() == diff.min() == 0.0