Merge pull request #1037 from xcube-dev/konstntokas-1030-add_data_type_open_data_store_framework

Add `data_type` to `open_data` method in `DataStore` class
konstntokas authored Jul 10, 2024
2 parents 775c18e + 43c1f3d commit 7809dd8
Showing 9 changed files with 189 additions and 41 deletions.
6 changes: 6 additions & 0 deletions CHANGES.md
@@ -17,6 +17,12 @@
`pyproject.toml.` (related to #992)
* Normalisation with `xcube.core.normalize.normalize_dataset` fails when chunk encoding
must be updated (#1033)
* The `open_data` method of xcube's default `xcube.core.store.DataStore`
implementations now supports a keyword argument `data_type`, which determines the
data type of the return value. Note that `opener_id` includes the `data_type`
at its first position and will override the `data_type` argument.
To preserve backward compatibility, the keyword argument `data_type`
has not yet been added to the explicit `open_data()` method signature. (#1030)
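
A minimal usage sketch of the behaviour described in the entry above; the store
configuration and data id are hypothetical, assuming the usual `new_data_store` factory:

```python
from xcube.core.store import new_data_store

# hypothetical local filesystem store and data id, for illustration only
store = new_data_store("file", root="/path/to/data")

# request the return type explicitly via the new keyword argument
dataset = store.open_data("cube.zarr", data_type="dataset")

# opener_id carries the data type at its first position and, as noted
# above, takes precedence over the data_type argument
dataset = store.open_data(
    "cube.zarr",
    opener_id="dataset:zarr:file",
    data_type="mldataset",  # effectively overridden by "dataset" in opener_id
)
```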

## Changes in 1.6.0

1 change: 1 addition & 0 deletions environment.yml
@@ -44,6 +44,7 @@ dependencies:
- zarr >=2.11
# Testing
- flake8 >=3.7
- kerchunk
- moto >=4
- pytest >=4.4
- pytest-cov >=2.6
1 change: 1 addition & 0 deletions rtd-environment.yml
@@ -51,6 +51,7 @@ dependencies:
- python-blosc
# Testing
- flake8 >=3.7
- kerchunk
- moto >=4
- pytest >=4.4
- pytest-cov >=2.6
2 changes: 1 addition & 1 deletion test/core/gen2/local/test_usercode.py
@@ -104,7 +104,7 @@ def test_class_bad_params(self):
"\n"
"Failed validating 'minimum'"
" in schema['properties']['value']:\n"
" {'minimum': 1, 'type': 'integer'}\n"
" {'type': 'integer', 'minimum': 1}\n"
"\n"
"On instance['value']:\n"
" 0",
80 changes: 74 additions & 6 deletions test/core/store/fs/test_registry.py
@@ -152,6 +152,39 @@ def test_dataset_zarr(self):
expected_descriptor_type=DatasetDescriptor,
assert_data_ok=self._assert_zarr_store_direct_ok,
)
self._assert_dataset_supported(
data_store,
filename_ext=".zarr",
requested_dtype_alias="dataset",
expected_dtype_aliases={"dataset"},
expected_return_type=xr.Dataset,
expected_descriptor_type=DatasetDescriptor,
assert_data_ok=self._assert_zarr_store_direct_ok,
)
self._assert_dataset_supported(
data_store,
filename_ext=".zarr",
requested_dtype_alias="mldataset",
expected_dtype_aliases={"dataset"},
expected_return_type=xr.Dataset,
expected_descriptor_type=DatasetDescriptor,
assert_data_ok=self._assert_zarr_store_direct_ok,
assert_warnings=True,
warning_msg=(
"No data opener found for format 'zarr' and data type 'mldataset'. "
"Data type is changed to the default data type 'dataset'."
),
)
self._assert_dataset_supported(
data_store,
filename_ext=".zarr",
requested_dtype_alias="mldataset",
expected_dtype_aliases={"dataset"},
expected_return_type=xr.Dataset,
expected_descriptor_type=DatasetDescriptor,
opener_id=f"dataset:zarr:{data_store.protocol}",
assert_data_ok=self._assert_zarr_store_direct_ok,
)

def test_dataset_netcdf(self):
data_store = self.create_data_store()
@@ -164,6 +197,20 @@ def test_dataset_netcdf(self):
expected_descriptor_type=DatasetDescriptor,
assert_data_ok=self._assert_zarr_store_generic_ok,
)
self._assert_dataset_supported(
data_store,
filename_ext=".nc",
requested_dtype_alias="mldataset",
expected_dtype_aliases={"dataset"},
expected_return_type=xr.Dataset,
expected_descriptor_type=DatasetDescriptor,
assert_data_ok=self._assert_zarr_store_generic_ok,
assert_warnings=True,
warning_msg=(
"No data opener found for format 'netcdf' and data type 'mldataset'. "
"Data type is changed to the default data type 'dataset'."
),
)

def test_dataset_levels(self):
data_store = self.create_data_store()
@@ -330,9 +377,12 @@ def _assert_dataset_supported(
expected_descriptor_type: Optional[
Union[type[DatasetDescriptor], type[MultiLevelDatasetDescriptor]]
] = None,
opener_id: str = None,
write_params: Optional[dict[str, Any]] = None,
open_params: Optional[dict[str, Any]] = None,
assert_data_ok: Optional[Callable[[Any], Any]] = None,
assert_warnings: bool = False,
warning_msg: str = None,
):
"""Call all DataStore operations to ensure data of type
xr.Dataset / MultiLevelDataset is supported by *data_store*.
@@ -344,9 +394,13 @@
expected_data_type_alias: The expected data type alias.
expected_return_type: The expected data type.
expected_descriptor_type: The expected data descriptor type.
opener_id: Optional opener identifier
write_params: Optional write parameters
open_params: Optional open parameters
assert_data_ok: Optional function to assert read data is ok
assert_warnings: If True, assert that the expected warnings are emitted
warning_msg: Expected warning message, checked only if
assert_warnings is True
"""

data_id = f"{DATA_PATH}/ds{filename_ext}"
@@ -388,13 +442,27 @@ def _assert_dataset_supported(
self.assertIsInstance(data_descriptors[0], DataDescriptor)
self.assertIsInstance(data_descriptors[0], expected_descriptor_type)

if requested_dtype_alias:
# noinspection PyProtectedMember
_, format_name, protocol = data_store._guess_accessor_id_parts(data_id)
opener_id = f"{requested_dtype_alias}:{format_name}:{protocol}"
if assert_warnings:
with warnings.catch_warnings(record=True) as w:
data = data_store.open_data(
data_id,
opener_id=opener_id,
data_type=requested_dtype_alias,
**open_params,
)
# if "s3" data store is tested, warnings from other
# libraries like botocore occur
if data_store.protocol is not "s3":
self.assertEqual(1, len(w))
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(warning_msg, w[0].message.args[0])
else:
opener_id = None
data = data_store.open_data(data_id, opener_id=opener_id, **open_params)
data = data_store.open_data(
data_id,
opener_id=opener_id,
data_type=requested_dtype_alias,
**open_params,
)
self.assertIsInstance(data, expected_return_type)
if assert_data_ok is not None:
assert_data_ok(data)
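The tests above build the opener identifier from three parts. Restated outside the test
class as a tiny sketch (the concrete values are illustrative only):

```python
# Opener identifiers follow the pattern "<data_type>:<format>:<protocol>",
# e.g. "dataset:zarr:file" or "mldataset:levels:s3".
data_type_alias = "mldataset"
format_name = "levels"
protocol = "file"
opener_id = f"{data_type_alias}:{format_name}:{protocol}"
assert opener_id == "mldataset:levels:file"
```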
16 changes: 15 additions & 1 deletion test/core/store/ref/test_store.py
@@ -7,7 +7,7 @@
import unittest
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
import warnings

import fsspec
import pytest
@@ -149,6 +149,20 @@ def test_open_data(self):
store = self.get_store()
sst_cube = store.open_data("sst-cube")
self.assert_sst_cube_ok(sst_cube)
sst_cube = store.open_data("sst-cube", data_type="dataset")
self.assert_sst_cube_ok(sst_cube)
with warnings.catch_warnings(record=True) as w:
sst_cube = store.open_data("sst-cube", data_type="mldataset")
self.assertEqual(1, len(w))
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
(
"ReferenceDataStore can only represent "
"the data resource as xr.Dataset."
),
w[0].message.args[0],
)
self.assert_sst_cube_ok(sst_cube)

def test_get_search_params_schema(self):
store = self.get_store()
94 changes: 68 additions & 26 deletions xcube/core/store/fs/store.py
@@ -11,9 +11,6 @@
from typing import (
Optional,
Any,
Tuple,
List,
Dict,
Union,
Callable,
)
@@ -317,7 +314,10 @@ def get_open_data_params_schema(
def open_data(
self, data_id: str, opener_id: str = None, **open_params
) -> xr.Dataset:
opener = self._find_opener(opener_id=opener_id, data_id=data_id)
data_type = open_params.pop("data_type", None)
opener = self._find_opener(
opener_id=opener_id, data_id=data_id, data_type=data_type
)
open_params_schema = self._get_open_data_params_schema(opener, data_id)
assert_valid_params(open_params, name="open_params", schema=open_params_schema)
fs_path = self._convert_data_id_into_fs_path(data_id)
@@ -447,10 +447,16 @@ def _guess_writer_id(self, data, data_id: str = None):
return extensions[0].name

def _find_opener(
self, opener_id: str = None, data_id: str = None, require: bool = True
self,
opener_id: str = None,
data_id: str = None,
data_type: DataTypeLike = None,
require: bool = True,
) -> Optional[DataOpener]:
if not opener_id:
opener_id = self._find_opener_id(data_id=data_id, require=require)
opener_id = self._find_opener_id(
data_id=data_id, data_type=data_type, require=require
)
if opener_id is None:
return None
return new_data_opener(opener_id)
@@ -530,9 +536,14 @@ def _assert_valid_data_type(self, data_type: DataType):
if data_type != ANY_TYPE:
assert_in(data_type, self.get_data_types(), name="data_type")

def _find_opener_id(self, data_id: str = None, require=True) -> Optional[str]:
def _find_opener_id(
self, data_id: str = None, data_type: DataTypeLike = None, require=True
) -> Optional[str]:
return self._find_accessor_id(
find_data_opener_extensions, data_id=data_id, require=require
find_data_opener_extensions,
data_id=data_id,
data_type=data_type,
require=require,
)

def _find_writer_id(self, data_id: str = None, require=True) -> Optional[str]:
Expand All @@ -541,10 +552,17 @@ def _find_writer_id(self, data_id: str = None, require=True) -> Optional[str]:
)

def _find_accessor_id(
self, find_data_accessor_extensions: Callable, data_id: str = None, require=True
self,
find_data_accessor_extensions: Callable,
data_id: str = None,
data_type: DataTypeLike = None,
require=True,
) -> Optional[str]:
extensions = self._find_accessor_extensions(
find_data_accessor_extensions, data_id=data_id, require=require
find_data_accessor_extensions,
data_id=data_id,
data_type=data_type,
require=require,
)
return extensions[0].name if extensions else None

Expand All @@ -559,30 +577,51 @@ def _find_writer_extensions(self, data_id: str = None, require=True):
)

def _find_accessor_extensions(
self, find_data_accessor_extensions: Callable, data_id: str = None, require=True
self,
find_data_accessor_extensions: Callable,
data_id: str = None,
data_type: DataTypeLike = None,
require=True,
) -> list[Extension]:
if data_id:
accessor_id_parts = self._guess_accessor_id_parts(data_id, require=require)
accessor_id_parts = self._guess_accessor_id_parts(
data_id, data_type=data_type, require=require
)
if not accessor_id_parts:
return []
data_type_alias = accessor_id_parts[0]
format_id = accessor_id_parts[1]
storage_id = accessor_id_parts[2]
else:
data_type_alias = _DEFAULT_DATA_TYPE
if data_type:
data_type_alias = DataType.normalize(data_type).alias
else:
data_type_alias = _DEFAULT_DATA_TYPE
format_id = _DEFAULT_FORMAT_ID
storage_id = self.protocol
predicate = get_data_accessor_predicate(
data_type=data_type_alias, format_id=format_id, storage_id=storage_id
)
extensions = find_data_accessor_extensions(predicate)

def _get_extension(type_alias: str) -> list[Extension]:
predicate = get_data_accessor_predicate(
data_type=type_alias, format_id=format_id, storage_id=storage_id
)
return find_data_accessor_extensions(predicate)

extensions = _get_extension(data_type_alias)
if not extensions:
if require:
msg = "No data accessor found"
if data_id:
msg += f" for data resource {data_id!r}"
raise DataStoreError(msg)
return []
extensions = _get_extension(_DEFAULT_DATA_TYPE)
if not extensions:
if require:
msg = "No data accessor found"
if data_id:
msg += f" for data resource {data_id!r}"
raise DataStoreError(msg)
return []
else:
warnings.warn(
f"No data opener found for format {format_id!r} and data type "
f"{data_type!r}. Data type is changed to the default data type "
f"{_DEFAULT_DATA_TYPE!r}."
)
return extensions

def _guess_data_type_for_data_id(
Expand All @@ -595,13 +634,16 @@ def _guess_data_type_for_data_id(
return DataType.normalize(data_type_alias)

def _guess_accessor_id_parts(
self, data_id: str, require=True
self, data_id: str, data_type: DataTypeLike = None, require=True
) -> Optional[tuple[str, str, str]]:
assert_given(data_id, "data_id")
ext = self._get_filename_ext(data_id)
data_type_aliases = None
if data_type:
data_type_aliases = DataType.normalize(data_type).aliases
else:
data_type_aliases = None
format_id = _FILENAME_EXT_TO_FORMAT.get(ext.lower())
if format_id is not None:
if format_id is not None and data_type_aliases is None:
data_type_aliases = _FORMAT_TO_DATA_TYPE_ALIASES.get(format_id)
if data_type_aliases is None or format_id is None:
if require:
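The fallback added to `_find_accessor_extensions` above reads as: look up accessors for
the requested data type first; if none are registered, retry with the default data type
and emit a `UserWarning`, or raise an error if even the default fails. A simplified,
self-contained sketch of that pattern — not the actual xcube implementation:

```python
import warnings

_DEFAULT_DATA_TYPE = "dataset"


def find_accessor_names(lookup, data_type_alias: str, format_id: str) -> list[str]:
    """Simplified stand-in for DataStore._find_accessor_extensions()."""
    names = lookup(data_type_alias, format_id)
    if not names:
        # fall back to the default data type, as the store code above does
        names = lookup(_DEFAULT_DATA_TYPE, format_id)
        if not names:
            raise ValueError(f"No data accessor found for format {format_id!r}")
        warnings.warn(
            f"No data opener found for format {format_id!r} and data type "
            f"{data_type_alias!r}. Data type is changed to the default data "
            f"type {_DEFAULT_DATA_TYPE!r}."
        )
    return names
```

Here `lookup` stands in for the extension query that the real code performs via
`get_data_accessor_predicate` and `find_data_opener_extensions`.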
9 changes: 7 additions & 2 deletions xcube/core/store/ref/store.py
@@ -4,15 +4,15 @@

import os
import warnings
from typing import Any, Tuple, Union, Dict, List
from typing import Any, Union
from collections.abc import Iterator, Container

import fsspec
import xarray as xr

from xcube.util.jsonschema import JsonObjectSchema

from ..datatype import DataTypeLike
from ..datatype import DataType, DataTypeLike
from ..descriptor import DataDescriptor, DatasetDescriptor
from ..descriptor import new_data_descriptor
from ..store import DataStore
@@ -103,6 +103,11 @@ def get_open_data_params_schema(
def open_data(
self, data_id: str, opener_id: str = None, **open_params
) -> xr.Dataset:
data_type = open_params.pop("data_type", None)
if DataType.normalize(data_type).alias == "mldataset":
warnings.warn(
"ReferenceDataStore can only represent the data resource as xr.Dataset."
)
if open_params:
warnings.warn(
f"open_params are not supported yet,"
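Mirroring the test in `test/core/store/ref/test_store.py`, requesting
`data_type="mldataset"` from the reference store only emits a `UserWarning` and still
returns an `xr.Dataset`. A hedged usage sketch — the store id and the parameters passed
to `new_data_store` are assumptions, not taken from this diff:

```python
import warnings

import xarray as xr

from xcube.core.store import new_data_store

# hypothetical kerchunk reference configuration, for illustration only
store = new_data_store("reference", refs=["sst-cube.json"])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cube = store.open_data("sst-cube", data_type="mldataset")

# an xr.Dataset is returned regardless; the warning explains the limitation
assert isinstance(cube, xr.Dataset)
print([str(w.message) for w in caught])
```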