From 2fbf2ed2afde64dbda3de2190f7b7b8348c7a41a Mon Sep 17 00:00:00 2001 From: kds1010 Date: Mon, 2 Sep 2024 12:33:08 +0900 Subject: [PATCH] add dataset to load data from snowflake --- dev-requirements.txt | 6 +- tests/test_datasets.py | 279 ++++++++++++++++++++++++++++++++- wicker/core/column_files.py | 3 + wicker/core/datasets.py | 259 +++++++++++++++++++++++++++++- wicker/schema/dataloading.py | 6 + wicker/schema/dataparsing.py | 8 + wicker/schema/schema.py | 145 +++++++++++++++++ wicker/schema/serialization.py | 5 + 8 files changed, 705 insertions(+), 6 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index a9cf0cb..bbdc41c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -24,7 +24,7 @@ mypy-extensions==0.4.3 numpy==1.21.2 packaging==21.0 pathspec==0.9.0 -platformdirs==2.4.0 +platformdirs==2.6.0 pluggy==1.0.0 py==1.10.0 py4j==0.10.9.2 @@ -54,7 +54,9 @@ sphinxcontrib-serializinghtml==1.1.5 toml==0.10.2 tomli==1.2.2 types-retry==0.9.2 -typing-extensions==3.10.0.2 +typing-extensions==4.3.0 urllib3==1.26.7 wandb==0.12.21 tqdm +snowflake-connector-python[pandas]==3.12.1 +pytest-mock==3.14.0 diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 14bd22c..d17048d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,19 +2,30 @@ import tempfile import unittest from contextlib import contextmanager -from typing import Any, Iterator, NamedTuple, Tuple +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple from unittest.mock import patch import numpy as np import pyarrow as pa # type: ignore import pyarrow.fs as pafs # type: ignore import pyarrow.parquet as papq # type: ignore +import pytest +from pytest_mock import MockFixture from wicker.core.column_files import ColumnBytesFileWriter -from wicker.core.datasets import S3Dataset +from wicker.core.datasets import S3Dataset, SFDataset from wicker.core.definitions import DatasetID, DatasetPartition from wicker.core.storage import S3PathFactory from wicker.schema import schema, serialization +from wicker.schema.schema import ( + BoolField, + DatasetSchema, + FloatField, + IntField, + SchemaField, + SfNumpyField, + StringField, +) from wicker.testing.storage import FakeS3DataStorage FAKE_NAME = "dataset_name" @@ -194,3 +205,267 @@ def get_dir_size(path="."): expected_bytes = get_dir_size(fake_s3_storage._tmpdir) assert expected_bytes == dataset_size + + +@pytest.mark.parametrize( + ( + "dataset_name", + "version", + "dataset_partition_name", + "table_name", + "connection_parameters", + "columns_to_load", + "optional_condition", + "primary_keys", + ), + [ + ("test", "0.0.0", "partition", "test_table", {}, ["column0", "column1"], "column0 = test", ["key1"]), + ], +) +def test_sf_dataset_constructor( + mocker: MockFixture, + dataset_name: str, + version: str, + dataset_partition_name: str, + table_name: str, + columns_to_load: List[str], + primary_keys: List[str], + connection_parameters: Dict[str, str], + optional_condition: str, +) -> None: + mocker.patch("wicker.core.datasets.SFDataset.connection") + mocker.patch("wicker.core.datasets.SFDataset.schema") + dataset = SFDataset( + dataset_name=dataset_name, + dataset_version=version, + dataset_partition_name=dataset_partition_name, + table_name=table_name, + columns_to_load=columns_to_load, + primary_keys=primary_keys, + connection_parameters=connection_parameters, + optional_condition=optional_condition, + ) + assert dataset._dataset_id.name == dataset_name + assert dataset._dataset_id.version == version + assert dataset._partition.dataset_id.name == dataset_name + assert dataset._partition.dataset_id.version == version + assert dataset._partition.partition == dataset_partition_name + assert dataset._table_name == table_name + assert dataset._columns_to_load == columns_to_load + assert dataset._primary_keys == primary_keys + assert dataset._connection_parameters == connection_parameters + assert dataset._optional_condition == optional_condition + assert dataset._dataset_definition.dataset_id.name == dataset_name + assert dataset._dataset_definition.dataset_id.version == version + + +@pytest.fixture +def sf_dataset(mocker: MockFixture) -> SFDataset: + return SFDataset("test", "0.0.0", "", "", {}, schema=mocker.MagicMock()) + + +@pytest.mark.parametrize( + ("connection", "connection_parameter", "expectation"), + [ + (None, {}, True), + (1, {}, 1), + (1, None, 1), + ], +) +def test_connection( + mocker: MockFixture, sf_dataset: SFDataset, connection: Any, connection_parameter: Dict[str, str], expectation: Any +) -> None: + mocker.patch("snowflake.connector.connect", return_value=expectation) + sf_dataset._connection = connection + sf_dataset._connection_parameters = connection_parameter + assert sf_dataset.connection == expectation + + +@pytest.mark.parametrize( + ("schema", "schema_fields", "primary_keys", "expectation"), + [ + (DatasetSchema([], [], True), [], [], DatasetSchema([], [], True)), + (None, [StringField("key1")], ["key1"], DatasetSchema([StringField("key1")], ["key1"])), + ], +) +def test_schema( + mocker: MockFixture, + sf_dataset: SFDataset, + schema: Any, + schema_fields: List[SchemaField], + primary_keys: List[str], + expectation: DatasetSchema, +) -> None: + mocker.patch.object(sf_dataset, "_get_schema_from_database") + mocker.patch.object(sf_dataset, "_get_schema_fields", return_value=schema_fields) + mocker.patch.object(sf_dataset, "_primary_keys", new=primary_keys) + sf_dataset._schema = schema + ret = sf_dataset.schema() + assert ret == expectation + + +@pytest.mark.parametrize( + ("input_table", "table", "schema", "expectation"), + [ + ( + pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]), + None, + None, + pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]), + ), + ( + None, + pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]), + DatasetSchema([StringField("col1"), StringField("col2")], primary_keys=["col1"]), + pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]), + ), + ( + None, + pa.Table.from_arrays([[1, 2], [3, 4], ["[1]", "[2]"]], names=["col1", "col2", "sf1"]), + DatasetSchema( + [StringField("col1"), StringField("col2"), SfNumpyField("sf1", (1, -1), "float")], primary_keys=["col1"] + ), + pa.Table.from_arrays([[1, 2], [3, 4], ["[1]".encode(), "[2]".encode()]], names=["col1", "col2", "sf1"]), + ), + ], +) +def test_arrow_table( + mocker: MockFixture, + sf_dataset: SFDataset, + input_table: Optional[pa.Table], + table: Optional[pa.Table], + schema: DatasetSchema, + expectation: pa.Table, +) -> None: + mocker.patch.object(sf_dataset, "_arrow_table", new=input_table) + mocker.patch.object(sf_dataset, "_get_data") + mocker.patch.object(sf_dataset, "_get_lower_case_columns", return_value=table) + mocker.patch.object(sf_dataset, "schema", return_value=schema) + ret = sf_dataset.arrow_table() + assert ret == expectation + + +@pytest.mark.parametrize( + ("table", "expectation"), + [ + (pa.Table.from_arrays([[1, 2]], names=["col1"]), 2), + (pa.Table.from_arrays([[1, 2, 3, 4]], names=["col1"]), 4), + ], +) +def test_len(mocker: MockFixture, sf_dataset: SFDataset, table: pa.Table, expectation: int) -> None: + mocker.patch.object(sf_dataset, "arrow_table", return_value=table) + assert len(sf_dataset) == expectation + + +@pytest.mark.parametrize( + ("table", "columns_to_load", "schema", "expectations"), + [ + ( + pa.Table.from_arrays([[1, 2]], names=["col1"]), + ["col1"], + DatasetSchema([IntField("col1")], ["col1"]), + [{"col1": 1}, {"col1": 2}], + ), + ( + pa.Table.from_arrays([[1, 2], ["[1]", "[2]"]], names=["col1", "sf1"]), + ["col1"], + DatasetSchema([IntField("col1"), SfNumpyField("sf1", (1, -1), "int")], ["col1"]), + [{"col1": 1}, {"col1": 2}], + ), + ( + pa.Table.from_arrays([[1, 2], ["[1]", "[2]"]], names=["col1", "sf1"]), + ["col1", "sf1"], + DatasetSchema([IntField("col1"), SfNumpyField("sf1", (1, -1), "int")], ["col1"]), + [{"col1": 1, "sf1": np.array([1])}, {"col1": 2, "sf1": np.array([2])}], + ), + ], +) +def test_getitem( + mocker: MockFixture, + sf_dataset: SFDataset, + table: pa.Table, + columns_to_load: List[str], + schema: DatasetSchema, + expectations: List[Dict[str, Any]], +) -> None: + mocker.patch.object(sf_dataset, "arrow_table", return_value=table) + mocker.patch.object(sf_dataset, "schema", return_value=schema) + sf_dataset._columns_to_load = columns_to_load + for idx, expectation in enumerate(expectations): + assert sf_dataset[idx] == expectation + + +@pytest.mark.parametrize( + ("table", "columns_to_load", "expectation"), + [ + ( + pa.Table.from_arrays([["COL1", "COL2"], ["str", "int"]], names=["name", "type"]), + None, + pa.Table.from_arrays( + [["COL1", "COL2"], ["str", "int"], ["col1", "col2"]], names=["name", "type", "lowercase_name"] + ), + ), + ( + pa.Table.from_arrays([["COL1", "COL2"], ["str", "int"]], names=["name", "type"]), + ["col1"], + pa.Table.from_arrays([["COL1"], ["str"], ["col1"]], names=["name", "type", "lowercase_name"]), + ), + ], +) +def test_get_schema_from_database( + mocker: MockFixture, + sf_dataset: SFDataset, + table: pa.Table, + columns_to_load: Optional[List[str]], + expectation: pa.Table, +) -> None: + base_conn_mock = mocker.MagicMock() + conn_mock = mocker.MagicMock() + cur_mock = mocker.MagicMock() + mocker.patch.object(cur_mock, "fetch_arrow_all", return_value=table) + mocker.patch.object(conn_mock, "__enter__", return_value=cur_mock) + mocker.patch.object(base_conn_mock, "cursor", return_value=conn_mock) + sf_dataset._connection = base_conn_mock + sf_dataset._columns_to_load = columns_to_load + assert sf_dataset._get_schema_from_database() == expectation + + +@pytest.mark.parametrize( + ("table", "expectation"), + [ + ( + pa.Table.from_arrays( + [["col1", "col2", "col3", "col4"], ["VARCHAR", "NUMBER(10,2)", "NUMBER(10,0)", "VARIANT"]], + names=["lowercase_name", "type"], + ), + [StringField("col1"), FloatField("col2"), IntField("col3"), SfNumpyField("col4", (1, -1), "float32")], + ), + ], +) +def test_get_schema_fields(sf_dataset: SFDataset, table: pa.Table, expectation: List[SchemaField]) -> None: + assert sf_dataset._get_schema_fields(table) == expectation + + +@pytest.mark.parametrize( + ("type", "expectation"), + [ + ("varchar", StringField), + ("boolean", BoolField), + ("number(1,2)", FloatField), + ("number(1,0)", IntField), + ("variant", SfNumpyField), + ], +) +def test_get_schema_type(sf_dataset: SFDataset, type: str, expectation: SchemaField) -> None: + assert sf_dataset._get_schema_type(type) == expectation + + +@pytest.mark.parametrize( + ("table", "expectation"), + [ + (pa.Table.from_arrays([[1, 2]], names=["COL1"]), pa.Table.from_arrays([[1, 2]], names=["col1"])), + (pa.Table.from_arrays([[1, 2]], names=["col1"]), pa.Table.from_arrays([[1, 2]], names=["col1"])), + ], +) +def test_get_lower_case_columns(sf_dataset: SFDataset, table: pa.Table, expectation: pa.Table) -> None: + assert sf_dataset._get_lower_case_columns(table) == expectation diff --git a/wicker/core/column_files.py b/wicker/core/column_files.py index 83fbd65..012093c 100644 --- a/wicker/core/column_files.py +++ b/wicker/core/column_files.py @@ -318,3 +318,6 @@ def process_object_field(self, field: schema.ObjectField) -> Any: return data cbf_info = ColumnBytesFileLocationV1.from_bytes(data) return self.cbf_cache.read(cbf_info) + + def process_sf_variant_field(self, field: schema.VariantField) -> Any: + return self._current_data diff --git a/wicker/core/datasets.py b/wicker/core/datasets.py index 8864cd6..5e9f0b6 100644 --- a/wicker/core/datasets.py +++ b/wicker/core/datasets.py @@ -4,19 +4,29 @@ from functools import cached_property from multiprocessing import Pool, cpu_count from multiprocessing.pool import ThreadPool -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import boto3 import pyarrow # type: ignore +import pyarrow.compute as pc import pyarrow.fs as pafs # type: ignore import pyarrow.parquet as papq # type: ignore +import snowflake.connector from wicker.core.column_files import ColumnBytesFileCache, ColumnBytesFileLocationV1 from wicker.core.config import get_config # type: ignore from wicker.core.definitions import DatasetDefinition, DatasetID, DatasetPartition from wicker.core.storage import S3DataStorage, S3PathFactory from wicker.schema import dataloading, serialization -from wicker.schema.schema import DatasetSchema +from wicker.schema.schema import ( + BoolField, + DatasetSchema, + FloatField, + IntField, + SchemaField, + SfNumpyField, + StringField, +) # How long to wait before timing out on filelocks in seconds FILE_LOCK_TIMEOUT_SECONDS = 300 @@ -303,3 +313,248 @@ def dataset_size(self) -> int: int: total dataset size in bytes """ return self._get_dataset_partition_size() + + +class SFDataset(AbstractDataset): + """Loading dataset from Snowflake table""" + + def __init__( + self, + dataset_name: str, + dataset_version: str, + dataset_partition_name: str, + table_name: str, + connection_parameters: Optional[Dict[str, str]] = None, + connection: Any = None, + schema: Optional[DatasetSchema] = None, + columns_to_load: Optional[List[str]] = None, + optional_condition: str = "", + primary_keys: List[str] = ["scene_id", "frame_idx"], + ): + """Initialize SFDataset. + + Args: + dataset_name (str): Name of the dataset to load. + dataset_version (str): Version of the dataset to load. + dataset_partition_name (str): Partition name. + table_name (str): Name of the table in database to load. + connection_parameters (Optional[Dict[str, str]], optional): + Parameters to connect to the database. Defaults to None. + connection (Any, optional): Established connection to the database. Defaults to None. + schema (Optional[DatasetSchema], optional): Expected DatasetSchema. Defaults to None. + columns_to_load (Optional[List[str]], optional): Name of columns to load. Defaults to None. + optional_condition (str, optional): Query condition to specify the dataset. Defaults to "". + primary_keys (List[str], optional): Primary key of the dataset. + Defaults to ["scene_id", "frame_idx"]. + """ + self._arrow_table: Optional[pyarrow.Table] = None + + self._dataset_id = DatasetID(name=dataset_name, version=dataset_version) + self._partition = DatasetPartition(dataset_id=self._dataset_id, partition=dataset_partition_name) + self._schema = schema + self._table_name = table_name + self._columns_to_load = columns_to_load + self._primary_keys = primary_keys + self._connection_parameters = connection_parameters + self._optional_condition = optional_condition + self._connection = connection + self._dataset_definition = DatasetDefinition( + self._dataset_id, + schema=self.schema(), + ) + + @property + def connection(self): + """Returns a connection with the database. + It tries to connect if no connection exists. + + Raises: + ValueError: Error if both connection and connection_parameters are None. + + Returns: + _type_: Connection to the database. + """ + if self._connection is None: + if self._connection_parameters is not None: + self._connection = snowflake.connector.connect(**self._connection_parameters) + else: + raise ValueError( + "Invalid input: Both 'connection' and 'connection_parameters' cannot be None.", + "At least one of them must be provided.", + ) + return self._connection + + def schema(self) -> DatasetSchema: + """Returns schema of the dataset. It can be inputted by end-user to ensure the expected schema. + + Returns: + DatasetSchema: Schema of the dataset. + """ + if self._schema is None: + schema_table = self._get_schema_from_database() + schema_fields = self._get_schema_fields(schema_table) + self._schema = DatasetSchema(schema_fields, self._primary_keys) + return self._schema + + def arrow_table(self) -> pyarrow.Table: + """Returns a table of the dataset as pyarrow table. + + Returns: + pyarrow.Table: Contents of the dataset. + """ + if self._arrow_table is None: + arrow_table = self._get_data() + arrow_table = self._get_lower_case_columns(arrow_table) + df = arrow_table.to_pandas() + # TODO: Here is workaround because ObjectField accepts only bytes as input + for col in self.schema().schema_record.fields: + if isinstance(col, SfNumpyField): + df[col.name] = df[col.name].apply(lambda x: x.encode("utf-8")) + self._arrow_table = pyarrow.Table.from_pandas(df) + return self._arrow_table + + def __len__(self) -> int: + """Returns a number of rows of the dataset. + + Returns: + int: Number of rows of the dataset. + """ + return len(self.arrow_table()) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """Returns contents of a row of the dataset. + + Args: + idx (int): Index of the dataset to load. + + Returns: + Dict[str, Any]: Parsed/Transformed contents of a row. + """ + tbl = self.arrow_table() + columns = self._columns_to_load if self._columns_to_load is not None else tbl.column_names + row = {col: tbl[col][idx].as_py() for col in columns} + example = dataloading.load_example(row, self.schema()) + return example + + def _get_schema_from_database(self) -> pyarrow.Table: + """Returns schema of the dataset as a table. + It converts the name of columns to lower case because Snowflake returns each column as upper case. + + Returns: + pyarrow.Table: Schema of the dataset. + """ + with self.connection.cursor() as cur: + cur.execute(f"describe table {self._table_name};") + cur.execute( + """ + select + * + from + table(result_scan(last_query_id())) + ; + """ + ) + schema_table = cur.fetch_arrow_all() # type: ignore + lowercase_column = pc.utf8_lower(schema_table["name"]) # type: ignore + schema_table = schema_table.append_column("lowercase_name", lowercase_column) + if self._columns_to_load is not None: + mask = pc.is_in( # type: ignore + schema_table["lowercase_name"], + value_set=pyarrow.array(self._columns_to_load), + ) + schema_table = schema_table.filter(mask) + return schema_table + + def _get_schema_fields(self, schema_table: pyarrow.Table) -> List[SchemaField]: + """Returns fields of the schema by mapping name of types. + + Args: + schema_table (pyarrow.Table): Table explains name and type of the columns. + + Returns: + List[SchemaField]: Fields of schema of the dataset. + """ + schema_fields = [] + for name, type in zip(schema_table["lowercase_name"], schema_table["type"]): + schema = self._get_schema_type(type.as_py()) + if schema == SfNumpyField: + schema_instance = schema(name=name.as_py(), shape=(1, -1), dtype="float32") # type: ignore + else: + schema_instance = schema(name=name.as_py()) # type: ignore + schema_fields.append(schema_instance) + return schema_fields + + def _get_schema_type(self, type: str) -> Type[SchemaField]: + """Returns type of schema filed by mapping name of type in the database. + + Args: + type (str): Name of type in the database. + + Raises: + NotImplementedError: Error if the type is not supported. + + Returns: + Type[SchemaField]: Type of SchemaField corresponding to the input type. + """ + upper_type_name = type.upper() + if upper_type_name.startswith("VARCHAR"): + return StringField + elif upper_type_name.startswith("BOOLEAN"): + return BoolField + elif upper_type_name.startswith("NUMBER"): + if upper_type_name.endswith("0)"): + return IntField + else: + return FloatField + elif upper_type_name.startswith("VARIANT"): + return SfNumpyField + else: + raise NotImplementedError(f"{upper_type_name} is not Implemented as schema") + + def _get_data(self) -> pyarrow.Table: + """Returns a table contains dataset contents loaded from the database. + It accepts parameters to specify the dataset. + + Returns: + pyarrow.Table: Contents of the dataset. + """ + columns = "*" + if self._columns_to_load is not None: + columns = ",".join([f'{f} as "{f}"' for f in self._columns_to_load]) + optional_query = "" + if self._optional_condition: + optional_query = f"and {self._optional_condition}" + with self.connection.cursor() as cur: + sql = f""" + select + {columns} + from + {self._table_name} + where + partition = '{self._partition.partition}' + and dataset_name = '{self._dataset_id.name}' + and version = '{self._dataset_id.version}' + {optional_query} + ; + """ + cur.execute(sql) + arrow_table = cur.fetch_arrow_all() # type: ignore + return arrow_table + + def _get_lower_case_columns(self, arrow_table: pyarrow.Table) -> pyarrow.Table: + """Returns a table by renaming columns to lower case. + + Args: + arrow_table (pyarrow.Table): Input table. + + Returns: + pyarrow.Table: Table with column names with lower case. + """ + new_schema = pyarrow.schema( + [ + pyarrow.field(name.lower(), field.type) + for name, field in zip(arrow_table.schema.names, arrow_table.schema) + ] + ) + arrow_table = pyarrow.Table.from_arrays(arrow_table.columns, schema=new_schema) + return arrow_table diff --git a/wicker/schema/dataloading.py b/wicker/schema/dataloading.py index 0c5503f..1c3c089 100644 --- a/wicker/schema/dataloading.py +++ b/wicker/schema/dataloading.py @@ -87,6 +87,12 @@ def process_object_field(self, field: schema.ObjectField) -> Optional[Any]: return data return field.codec.decode_object(data) + def process_sf_variant_field(self, field: schema.VariantField) -> Optional[Any]: + data = validation.validate_field_type(self._current_data, str, field.required, self._current_path) + if data is None: + return data + return field.codec.decode_object(data.encode()) + def process_array_field(self, field: schema.ArrayField) -> Optional[List[Any]]: current_data = validation.validate_field_type(self._current_data, list, field.required, self._current_path) if current_data is None: diff --git a/wicker/schema/dataparsing.py b/wicker/schema/dataparsing.py index d5555a0..ca4030c 100644 --- a/wicker/schema/dataparsing.py +++ b/wicker/schema/dataparsing.py @@ -149,6 +149,14 @@ def process_array_field(self, field: schema.ArrayField) -> Optional[List[Any]]: res.append(field.element_field._accept_visitor(self)) return res + def process_sf_variant_field(self, field: schema.VariantField) -> Any: + data = validation.validate_field_type( + self._current_data, field.codec.object_type(), field.required, self._current_path + ) + if data is None: + return None + return field.codec.validate_and_encode_object(data) + class ParseExampleMetadataVisitor(ParseExampleVisitor): """Specialization of ParseExampleVisitor which skips over certain fields that are now parsed as metadata""" diff --git a/wicker/schema/schema.py b/wicker/schema/schema.py index c18b526..7fd7dcb 100644 --- a/wicker/schema/schema.py +++ b/wicker/schema/schema.py @@ -404,6 +404,147 @@ def __init__( ) +class VariantField(SchemaField): + """A field that contains variant of a specific type.""" + + def __init__( + self, + name: str, + codec: codecs.Codec, + description: str = "", + required: bool = True, + is_heavy_pointer: bool = True, + ) -> None: + """Create an VariantField. The VariantField is parametrized with a Codec that it will use when + serializing/deserializing data to/from storage. Users are responsible for providing the Codec + at both write and read time. + + :param name: name of the field + :param codec: Encoder/decoder to serialize/deserialize the object. See codecs.Codec for more details on + the codecs. + """ + if not codec.get_codec_name(): + raise WickerSchemaException("Codec names must be non-empty. Encountered at field={name}") + + custom_field_tags = { + "_l5ml_metatype": "object", + "_codec_params": json.dumps(codec.save_codec_to_dict()), + "_codec_name": codec.get_codec_name(), + } + super().__init__( + name, + description=description, + required=required, + custom_field_tags=custom_field_tags, + ) + self._is_heavy_pointer = is_heavy_pointer + self.codec = codec + + def _accept_visitor(self, visitor: DatasetSchemaVisitor[_T]) -> _T: + """Processes the current schema field with the visitor object""" + return visitor.process_sf_variant_field(self) + + def __eq__(self, other: Any) -> bool: + return super().__eq__(other) and self.codec == other.codec + + +class SfNumpyCodec(codecs.Codec): + def __init__(self, shape: Optional[Tuple[int, ...]], dtype=str): + self.shape = shape + # Validate the dtype is a valid numpy dtype + try: + self.dtype = np.dtype(dtype) + except TypeError: + raise WickerSchemaException(f"Specified dtype: {dtype} not understood by numpy") + + @staticmethod + def _codec_name() -> str: + return "sf_numpy" + + def save_codec_to_dict(self) -> Dict[str, Any]: + """If you want to save some parameters of this codec with the dataset + schema, return the fields here. The returned dictionary should be JSON compatible. + Note that this is a dataset-level value, not a per example value.""" + return { + "shape": [d for d in self.shape] if self.shape is not None else None, + "dtype": str(self.dtype), + } + + @staticmethod + def load_codec_from_dict(data: Dict[str, Any]) -> SfNumpyCodec: + """Create a new instance of this codec with the given parameters.""" + return SfNumpyCodec( + shape=tuple(data["shape"]) if data["shape"] is not None else None, + dtype=data["dtype"], + ) + + def validate_and_encode_object(self, obj: np.ndarray) -> bytes: + """Encode the given object into json string. The function is also responsible for validating the data. + :param obj: Object to encode + :return: The encoded json string for the given object.""" + if obj.dtype != self.dtype: + raise WickerSchemaException( + f"Example provided a numpy array with dtype {obj.dtype}, " f"expected {self.dtype}" + ) + if self.shape is not None: + if len(obj.shape) != len(self.shape): + raise WickerSchemaException( + f"Example provided a numpy array with shape {obj.shape}, " + f"which has a different number of dimensions from expected shape {self.shape}" + ) + for arr_dim_size, field_dim_size in zip(obj.shape, self.shape): + if field_dim_size == -1: + continue + if arr_dim_size != field_dim_size: + raise WickerSchemaException( + f"Example provided a numpy array with shape {obj.shape}, " + f"which is incompatible with expected shape {self.shape}" + ) + + # Serialize array as json string + return json.dumps(obj.tolist()).encode() + + def decode_object(self, data: bytes) -> np.ndarray: + """Decode an object from the given bytes. This is the opposite of validate_and_encode_object. + We expect obj == decode_object(validate_and_encode_object(obj)) + :param data: bytes to decode. + :return: Decoded object.""" + return np.array(json.loads(data)) + + def object_type(self) -> Type[Any]: + """Return the expected type of the objects handled by this codec. + This method can be overriden to match more specific classes.""" + return np.ndarray + + +class SfNumpyField(VariantField): + """An ObjectField that uses a Codec for encoding Numpy arrays""" + + def __init__( + self, + name: str, + shape: Optional[Tuple[int, ...]], + dtype: str, + description: str = "", + required: bool = True, + is_heavy_pointer: bool = False, + ) -> None: + """Create a NumpyField + + :param name: name of the field + :param shape: shape of the numpy array that we expect, or None to indicate that all shapes are acceptable, + `-1` denotes that a given dimension can have any size. + :param dtype: dtype of the numpy array that we expect, + """ + super().__init__( + name=name, + codec=SfNumpyCodec(shape=shape, dtype=dtype), + description=description, + required=required, + is_heavy_pointer=is_heavy_pointer, + ) + + class DatasetSchema: """A schema definition that serializes into an Avro-compatible schema""" @@ -507,6 +648,10 @@ def process_double_field(self, field: DoubleField) -> _T: def process_object_field(self, field: ObjectField) -> _T: pass + @abc.abstractmethod + def process_sf_variant_field(self, field: VariantField) -> _T: + pass + @abc.abstractmethod def process_record_field(self, field: RecordField) -> _T: pass diff --git a/wicker/schema/serialization.py b/wicker/schema/serialization.py index 8917529..f8d0f1d 100644 --- a/wicker/schema/serialization.py +++ b/wicker/schema/serialization.py @@ -286,3 +286,8 @@ def process_array_field(self, field: schema.ArrayField) -> Dict[str, Any]: "name": field.name, "type": field_type, } + + def process_sf_variant_field(self, field: schema.VariantField) -> Any: + return { + **self.process_schema_field(field, "sf_numpy"), + }