Add a Dataset Class to Load Datasets from Snowflake #71

Open · wants to merge 1 commit into base: main
6 changes: 4 additions & 2 deletions dev-requirements.txt
@@ -24,7 +24,7 @@ mypy-extensions==0.4.3
numpy==1.21.2
packaging==21.0
pathspec==0.9.0
platformdirs==2.4.0
platformdirs==2.6.0
pluggy==1.0.0
py==1.10.0
py4j==0.10.9.2
@@ -54,7 +54,9 @@ sphinxcontrib-serializinghtml==1.1.5
toml==0.10.2
tomli==1.2.2
types-retry==0.9.2
typing-extensions==3.10.0.2
typing-extensions==4.3.0
urllib3==1.26.7
wandb==0.12.21
tqdm
snowflake-connector-python[pandas]==3.12.1
Collaborator (Author) commented:
We need to add this library to communicate with Snowflake.

pytest-mock==3.14.0
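
The following is a minimal sketch (not part of the PR) of how snowflake-connector-python[pandas] is typically used to pull query results as a pyarrow Table, which is the access pattern the new SFDataset relies on. The helper name and connection parameters are placeholders.

```python
from typing import Any, Dict

import pyarrow as pa  # type: ignore
import snowflake.connector


def fetch_arrow_table(connection_parameters: Dict[str, Any], query: str) -> pa.Table:
    """Run a query against Snowflake and return the result set as a pyarrow.Table."""
    conn = snowflake.connector.connect(**connection_parameters)
    try:
        with conn.cursor() as cur:
            cur.execute(query)
            # fetch_arrow_all() requires the pyarrow dependency pulled in by the
            # [pandas] extra of snowflake-connector-python.
            return cur.fetch_arrow_all()
    finally:
        conn.close()


# Example (placeholder credentials):
# table = fetch_arrow_table(
#     {"account": "...", "user": "...", "password": "..."},
#     "SELECT * FROM test_table",
# )
```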
279 changes: 277 additions & 2 deletions tests/test_datasets.py
@@ -2,19 +2,30 @@
import tempfile
import unittest
from contextlib import contextmanager
from typing import Any, Iterator, NamedTuple, Tuple
from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple
from unittest.mock import patch

import numpy as np
import pyarrow as pa # type: ignore
import pyarrow.fs as pafs # type: ignore
import pyarrow.parquet as papq # type: ignore
import pytest
from pytest_mock import MockFixture

from wicker.core.column_files import ColumnBytesFileWriter
from wicker.core.datasets import S3Dataset
from wicker.core.datasets import S3Dataset, SFDataset
from wicker.core.definitions import DatasetID, DatasetPartition
from wicker.core.storage import S3PathFactory
from wicker.schema import schema, serialization
from wicker.schema.schema import (
BoolField,
DatasetSchema,
FloatField,
IntField,
SchemaField,
SfNumpyField,
StringField,
)
from wicker.testing.storage import FakeS3DataStorage

FAKE_NAME = "dataset_name"
@@ -194,3 +205,267 @@ def get_dir_size(path="."):

expected_bytes = get_dir_size(fake_s3_storage._tmpdir)
assert expected_bytes == dataset_size


@pytest.mark.parametrize(
(
"dataset_name",
"version",
"dataset_partition_name",
"table_name",
"connection_parameters",
"columns_to_load",
"optional_condition",
"primary_keys",
),
[
("test", "0.0.0", "partition", "test_table", {}, ["column0", "column1"], "column0 = test", ["key1"]),
],
)
def test_sf_dataset_constructor(
mocker: MockFixture,
dataset_name: str,
version: str,
dataset_partition_name: str,
table_name: str,
columns_to_load: List[str],
primary_keys: List[str],
connection_parameters: Dict[str, str],
optional_condition: str,
) -> None:
mocker.patch("wicker.core.datasets.SFDataset.connection")
mocker.patch("wicker.core.datasets.SFDataset.schema")
dataset = SFDataset(
dataset_name=dataset_name,
dataset_version=version,
dataset_partition_name=dataset_partition_name,
table_name=table_name,
columns_to_load=columns_to_load,
primary_keys=primary_keys,
connection_parameters=connection_parameters,
optional_condition=optional_condition,
)
assert dataset._dataset_id.name == dataset_name
assert dataset._dataset_id.version == version
assert dataset._partition.dataset_id.name == dataset_name
assert dataset._partition.dataset_id.version == version
assert dataset._partition.partition == dataset_partition_name
assert dataset._table_name == table_name
assert dataset._columns_to_load == columns_to_load
assert dataset._primary_keys == primary_keys
assert dataset._connection_parameters == connection_parameters
assert dataset._optional_condition == optional_condition
assert dataset._dataset_definition.dataset_id.name == dataset_name
assert dataset._dataset_definition.dataset_id.version == version


@pytest.fixture
def sf_dataset(mocker: MockFixture) -> SFDataset:
return SFDataset("test", "0.0.0", "", "", {}, schema=mocker.MagicMock())


@pytest.mark.parametrize(
("connection", "connection_parameter", "expectation"),
[
(None, {}, True),
(1, {}, 1),
(1, None, 1),
],
)
def test_connection(
mocker: MockFixture, sf_dataset: SFDataset, connection: Any, connection_parameter: Dict[str, str], expectation: Any
) -> None:
mocker.patch("snowflake.connector.connect", return_value=expectation)
sf_dataset._connection = connection
sf_dataset._connection_parameters = connection_parameter
assert sf_dataset.connection == expectation


@pytest.mark.parametrize(
("schema", "schema_fields", "primary_keys", "expectation"),
[
(DatasetSchema([], [], True), [], [], DatasetSchema([], [], True)),
(None, [StringField("key1")], ["key1"], DatasetSchema([StringField("key1")], ["key1"])),
],
)
def test_schema(
mocker: MockFixture,
sf_dataset: SFDataset,
schema: Any,
schema_fields: List[SchemaField],
primary_keys: List[str],
expectation: DatasetSchema,
) -> None:
mocker.patch.object(sf_dataset, "_get_schema_from_database")
mocker.patch.object(sf_dataset, "_get_schema_fields", return_value=schema_fields)
mocker.patch.object(sf_dataset, "_primary_keys", new=primary_keys)
sf_dataset._schema = schema
ret = sf_dataset.schema()
assert ret == expectation


@pytest.mark.parametrize(
("input_table", "table", "schema", "expectation"),
[
(
pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]),
None,
None,
pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]),
),
(
None,
pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]),
DatasetSchema([StringField("col1"), StringField("col2")], primary_keys=["col1"]),
pa.Table.from_arrays([[1, 2], [3, 4]], names=["col1", "col2"]),
),
(
None,
pa.Table.from_arrays([[1, 2], [3, 4], ["[1]", "[2]"]], names=["col1", "col2", "sf1"]),
DatasetSchema(
[StringField("col1"), StringField("col2"), SfNumpyField("sf1", (1, -1), "float")], primary_keys=["col1"]
),
pa.Table.from_arrays([[1, 2], [3, 4], ["[1]".encode(), "[2]".encode()]], names=["col1", "col2", "sf1"]),
),
],
)
def test_arrow_table(
mocker: MockFixture,
sf_dataset: SFDataset,
input_table: Optional[pa.Table],
table: Optional[pa.Table],
schema: DatasetSchema,
expectation: pa.Table,
) -> None:
mocker.patch.object(sf_dataset, "_arrow_table", new=input_table)
mocker.patch.object(sf_dataset, "_get_data")
mocker.patch.object(sf_dataset, "_get_lower_case_columns", return_value=table)
mocker.patch.object(sf_dataset, "schema", return_value=schema)
ret = sf_dataset.arrow_table()
assert ret == expectation


@pytest.mark.parametrize(
("table", "expectation"),
[
(pa.Table.from_arrays([[1, 2]], names=["col1"]), 2),
(pa.Table.from_arrays([[1, 2, 3, 4]], names=["col1"]), 4),
],
)
def test_len(mocker: MockFixture, sf_dataset: SFDataset, table: pa.Table, expectation: int) -> None:
mocker.patch.object(sf_dataset, "arrow_table", return_value=table)
assert len(sf_dataset) == expectation


@pytest.mark.parametrize(
("table", "columns_to_load", "schema", "expectations"),
[
(
pa.Table.from_arrays([[1, 2]], names=["col1"]),
["col1"],
DatasetSchema([IntField("col1")], ["col1"]),
[{"col1": 1}, {"col1": 2}],
),
(
pa.Table.from_arrays([[1, 2], ["[1]", "[2]"]], names=["col1", "sf1"]),
["col1"],
DatasetSchema([IntField("col1"), SfNumpyField("sf1", (1, -1), "int")], ["col1"]),
[{"col1": 1}, {"col1": 2}],
),
(
pa.Table.from_arrays([[1, 2], ["[1]", "[2]"]], names=["col1", "sf1"]),
["col1", "sf1"],
DatasetSchema([IntField("col1"), SfNumpyField("sf1", (1, -1), "int")], ["col1"]),
[{"col1": 1, "sf1": np.array([1])}, {"col1": 2, "sf1": np.array([2])}],
),
],
)
def test_getitem(
mocker: MockFixture,
sf_dataset: SFDataset,
table: pa.Table,
columns_to_load: List[str],
schema: DatasetSchema,
expectations: List[Dict[str, Any]],
) -> None:
mocker.patch.object(sf_dataset, "arrow_table", return_value=table)
mocker.patch.object(sf_dataset, "schema", return_value=schema)
sf_dataset._columns_to_load = columns_to_load
for idx, expectation in enumerate(expectations):
assert sf_dataset[idx] == expectation


@pytest.mark.parametrize(
("table", "columns_to_load", "expectation"),
[
(
pa.Table.from_arrays([["COL1", "COL2"], ["str", "int"]], names=["name", "type"]),
None,
pa.Table.from_arrays(
[["COL1", "COL2"], ["str", "int"], ["col1", "col2"]], names=["name", "type", "lowercase_name"]
),
),
(
pa.Table.from_arrays([["COL1", "COL2"], ["str", "int"]], names=["name", "type"]),
["col1"],
pa.Table.from_arrays([["COL1"], ["str"], ["col1"]], names=["name", "type", "lowercase_name"]),
),
],
)
def test_get_schema_from_database(
mocker: MockFixture,
sf_dataset: SFDataset,
table: pa.Table,
columns_to_load: Optional[List[str]],
expectation: pa.Table,
) -> None:
base_conn_mock = mocker.MagicMock()
conn_mock = mocker.MagicMock()
cur_mock = mocker.MagicMock()
mocker.patch.object(cur_mock, "fetch_arrow_all", return_value=table)
mocker.patch.object(conn_mock, "__enter__", return_value=cur_mock)
mocker.patch.object(base_conn_mock, "cursor", return_value=conn_mock)
sf_dataset._connection = base_conn_mock
sf_dataset._columns_to_load = columns_to_load
assert sf_dataset._get_schema_from_database() == expectation


@pytest.mark.parametrize(
("table", "expectation"),
[
(
pa.Table.from_arrays(
[["col1", "col2", "col3", "col4"], ["VARCHAR", "NUMBER(10,2)", "NUMBER(10,0)", "VARIANT"]],
names=["lowercase_name", "type"],
),
[StringField("col1"), FloatField("col2"), IntField("col3"), SfNumpyField("col4", (1, -1), "float32")],
),
],
)
def test_get_schema_fields(sf_dataset: SFDataset, table: pa.Table, expectation: List[SchemaField]) -> None:
assert sf_dataset._get_schema_fields(table) == expectation


@pytest.mark.parametrize(
("type", "expectation"),
[
("varchar", StringField),
("boolean", BoolField),
("number(1,2)", FloatField),
("number(1,0)", IntField),
("variant", SfNumpyField),
],
)
def test_get_schema_type(sf_dataset: SFDataset, type: str, expectation: SchemaField) -> None:
assert sf_dataset._get_schema_type(type) == expectation


@pytest.mark.parametrize(
("table", "expectation"),
[
(pa.Table.from_arrays([[1, 2]], names=["COL1"]), pa.Table.from_arrays([[1, 2]], names=["col1"])),
(pa.Table.from_arrays([[1, 2]], names=["col1"]), pa.Table.from_arrays([[1, 2]], names=["col1"])),
],
)
def test_get_lower_case_columns(sf_dataset: SFDataset, table: pa.Table, expectation: pa.Table) -> None:
assert sf_dataset._get_lower_case_columns(table) == expectation
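
Taken together, the tests above imply the following usage pattern for the new class. This is a hypothetical sketch based only on the constructor arguments and methods the tests exercise; the argument values mirror the parametrized test case and are placeholders.

```python
from wicker.core.datasets import SFDataset

dataset = SFDataset(
    dataset_name="test",
    dataset_version="0.0.0",
    dataset_partition_name="partition",
    table_name="test_table",
    columns_to_load=["column0", "column1"],
    primary_keys=["key1"],
    connection_parameters={"account": "...", "user": "...", "password": "..."},
    optional_condition="column0 = test",
)

print(len(dataset))   # number of rows in the backing Arrow table
row = dataset[0]      # dict of column name -> value (VARIANT columns decoded to numpy arrays)
```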
3 changes: 3 additions & 0 deletions wicker/core/column_files.py
@@ -318,3 +318,6 @@ def process_object_field(self, field: schema.ObjectField) -> Any:
return data
cbf_info = ColumnBytesFileLocationV1.from_bytes(data)
return self.cbf_cache.read(cbf_info)

def process_sf_variant_field(self, field: schema.VariantField) -> Any:
Collaborator commented:
Do you want to store the reference in the S3-backed column files? I don't think we should store numpy data in the column files; that would generate tons of small files, hurting both data governance and loading performance.

return self._current_data
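
For reference, here is a hypothetical variant of process_sf_variant_field that resolves an S3-backed reference the same way process_object_field does above. This is only a sketch of the alternative the reviewer is alluding to, not the PR's implementation, and whether variant data should be stored in column files at all is the open question.

```python
def process_sf_variant_field(self, field: schema.VariantField) -> Any:
    data = self._current_data
    if not isinstance(data, bytes):
        # Inline value: return it directly, as the PR currently does.
        return data
    # Reference-style storage: decode a pointer into an S3-backed column bytes
    # file and read the payload through the cache.
    cbf_info = ColumnBytesFileLocationV1.from_bytes(data)
    return self.cbf_cache.read(cbf_info)
```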