From 1b67f16ce768225f034fe4698a43e30047817fab Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 1 Aug 2024 06:55:26 +0800 Subject: [PATCH] Fix Snowflake Agent Bug (#2605) * fix snowflake agent bug Signed-off-by: Future-Outlier * a work version Signed-off-by: Future-Outlier * Snowflake work version Signed-off-by: Future-Outlier * fix secret encode Signed-off-by: Future-Outlier * all works, I am so happy Signed-off-by: Future-Outlier * improve additional protocol Signed-off-by: Future-Outlier * fix tests Signed-off-by: Future-Outlier * Fix Tests Signed-off-by: Future-Outlier * update agent Signed-off-by: Kevin Su * Add snowflake test Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * sd Signed-off-by: Kevin Su * snowflake loglinks Signed-off-by: Future-Outlier * add metadata Signed-off-by: Future-Outlier * secret Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * remove table Signed-off-by: Future-Outlier * add comment for get private key Signed-off-by: Future-Outlier * update comments: Signed-off-by: Future-Outlier * Fix Tests Signed-off-by: Future-Outlier * update comments Signed-off-by: Future-Outlier * update comments Signed-off-by: Future-Outlier * Better Secrets Signed-off-by: Future-Outlier * use union secret Signed-off-by: Future-Outlier * Update Changes Signed-off-by: Future-Outlier * use if not get_plugin().secret_requires_group() Signed-off-by: Future-Outlier * Use Union SDK Signed-off-by: Future-Outlier * Update Signed-off-by: Future-Outlier * Fix Secrets Signed-off-by: Future-Outlier * Fix Secrets Signed-off-by: Future-Outlier * remove pacakge.json Signed-off-by: Future-Outlier * lint Signed-off-by: Future-Outlier * add snowflake-connector-python Signed-off-by: Future-Outlier * fix test_snowflake Signed-off-by: Future-Outlier * Try to fix tests Signed-off-by: Future-Outlier * fix tests Signed-off-by: Future-Outlier * Try Fix snowflake Import Signed-off-by: Future-Outlier * snowflake test passed Signed-off-by: Future-Outlier --------- Signed-off-by: Future-Outlier Signed-off-by: Kevin Su Co-authored-by: Kevin Su --- dev-requirements.in | 1 + flytekit/core/context_manager.py | 6 + flytekit/core/type_engine.py | 6 + flytekit/types/structured/__init__.py | 14 +++ flytekit/types/structured/snowflake.py | 106 ++++++++++++++++++ .../types/structured/structured_dataset.py | 17 ++- .../flytekitplugins/bigquery/task.py | 2 +- .../flytekitplugins/snowflake/agent.py | 64 +++++++---- .../flytekitplugins/snowflake/task.py | 33 +++--- plugins/flytekit-snowflake/setup.py | 2 +- .../flytekit-snowflake/tests/test_agent.py | 8 +- .../tests/test_snowflake.py | 24 +++- .../structured_dataset/test_snowflake.py | 70 ++++++++++++ 13 files changed, 298 insertions(+), 55 deletions(-) create mode 100644 flytekit/types/structured/snowflake.py create mode 100644 tests/flytekit/unit/types/structured_dataset/test_snowflake.py diff --git a/dev-requirements.in b/dev-requirements.in index a5758758e9..ce4171018b 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -16,6 +16,7 @@ pre-commit codespell google-cloud-bigquery google-cloud-bigquery-storage +snowflake-connector-python IPython keyrings.alt setuptools_scm diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 340046e941..13691162d5 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -367,6 +367,12 @@ def get( Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError param encode_mode, defines the mode to open files, it can either be "r" to read file, or "rb" to read binary file """ + + from flytekit.configuration.plugin import get_plugin + + if not get_plugin().secret_requires_group(): + group, group_version = None, None + env_var = self.get_secrets_env_var(group, key, group_version) fpath = self.get_secrets_file(group, key, group_version) v = os.environ.get(env_var) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 15c03059bb..c8bc881791 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -983,6 +983,7 @@ def lazy_import_transformers(cls): register_arrow_handlers, register_bigquery_handlers, register_pandas_handlers, + register_snowflake_handlers, ) from flytekit.types.structured.structured_dataset import DuplicateHandlerError @@ -1015,6 +1016,11 @@ def lazy_import_transformers(cls): from flytekit.types import numpy # noqa: F401 if is_imported("PIL"): from flytekit.types.file import image # noqa: F401 + if is_imported("snowflake.connector"): + try: + register_snowflake_handlers() + except DuplicateHandlerError: + logger.debug("Transformer for snowflake is already registered.") @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 7dffa49eec..05d1fa86e3 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -68,3 +68,17 @@ def register_bigquery_handlers(): "We won't register bigquery handler for structured dataset because " "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" ) + + +def register_snowflake_handlers(): + try: + from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler + + StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler()) + StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers()) + + except ImportError: + logger.info( + "We won't register snowflake handler for structured dataset because " + "we can't find package snowflake-connector-python" + ) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py new file mode 100644 index 0000000000..c603b55669 --- /dev/null +++ b/flytekit/types/structured/snowflake.py @@ -0,0 +1,106 @@ +import re +import typing + +import pandas as pd +import snowflake.connector +from snowflake.connector.pandas_tools import write_pandas + +import flytekit +from flytekit import FlyteContext +from flytekit.models import literals +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.structured_dataset import ( + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetMetadata, +) + +SNOWFLAKE = "snowflake" +PROTOCOL_SEP = "\\/|://|:" + + +def get_private_key() -> bytes: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + pk_string = flytekit.current_context().secrets.get("private_key", "snowflake", encode_mode="r") + + # Cryptography needs the string to be stripped and converted to bytes + pk_string = pk_string.strip().encode() + p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return pkb + + +def _write_to_sf(structured_dataset: StructuredDataset): + if structured_dataset.uri is None: + raise ValueError("structured_dataset.uri cannot be None.") + + uri = structured_dataset.uri + _, user, account, warehouse, database, schema, table = re.split(PROTOCOL_SEP, uri) + df = structured_dataset.dataframe + + conn = snowflake.connector.connect( + user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse + ) + + write_pandas(conn, df, table) + + +def _read_from_sf( + flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata +) -> pd.DataFrame: + if flyte_value.uri is None: + raise ValueError("structured_dataset.uri cannot be None.") + + uri = flyte_value.uri + _, user, account, warehouse, database, schema, query_id = re.split(PROTOCOL_SEP, uri) + + conn = snowflake.connector.connect( + user=user, + account=account, + private_key=get_private_key(), + database=database, + schema=schema, + warehouse=warehouse, + ) + + cs = conn.cursor() + cs.get_results_from_sfqid(query_id) + return cs.fetch_pandas_all() + + +class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): + def __init__(self): + super().__init__(python_type=pd.DataFrame, protocol=SNOWFLAKE, supported_format="") + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + _write_to_sf(structured_dataset) + return literals.StructuredDataset( + uri=typing.cast(str, structured_dataset.uri), metadata=StructuredDatasetMetadata(structured_dataset_type) + ) + + +class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder): + def __init__(self): + super().__init__(pd.DataFrame, protocol=SNOWFLAKE, supported_format="") + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> pd.DataFrame: + return _read_from_sf(flyte_value, current_task_metadata) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index c11519462e..128ddab168 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -6,7 +6,7 @@ import typing from abc import ABC, abstractmethod from dataclasses import dataclass, field, is_dataclass -from typing import Dict, Generator, Optional, Type, Union +from typing import Dict, Generator, List, Optional, Type, Union from dataclasses_json import config from fsspec.utils import get_protocol @@ -222,7 +222,12 @@ def extract_cols_and_format( class StructuredDatasetEncoder(ABC): - def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None): + def __init__( + self, + python_type: Type[T], + protocol: Optional[str] = None, + supported_format: Optional[str] = None, + ): """ Extend this abstract class, implement the encode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -284,7 +289,13 @@ def encode( class StructuredDatasetDecoder(ABC): - def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None): + def __init__( + self, + python_type: Type[DF], + protocol: Optional[str] = None, + supported_format: Optional[str] = None, + additional_protocols: Optional[List[str]] = None, + ): """ Extend this abstract class, implement the decode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 5ae03b3f88..c1707f09af 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -38,7 +38,7 @@ def __init__( self, name: str, query_template: str, - task_config: Optional[BigQueryConfig], + task_config: BigQueryConfig, inputs: Optional[Dict[str, Type]] = None, output_structured_dataset_type: Optional[Type[StructuredDataset]] = None, **kwargs, diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 71eba91186..831b431afa 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -1,18 +1,17 @@ from dataclasses import dataclass from typing import Optional -from flyteidl.core.execution_pb2 import TaskExecution +from flyteidl.core.execution_pb2 import TaskExecution, TaskLog -from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger +from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta -from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.models.types import LiteralType, StructuredDatasetType - -snowflake_connector = lazy_module("snowflake.connector") +from snowflake import connector as sc TASK_TYPE = "snowflake" SNOWFLAKE_PRIVATE_KEY = "snowflake_private_key" @@ -25,17 +24,17 @@ class SnowflakeJobMetadata(ResourceMeta): database: str schema: str warehouse: str - table: str query_id: str + has_output: bool def get_private_key(): from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization - import flytekit - - pk_string = flytekit.current_context().secrets.get(SNOWFLAKE_PRIVATE_KEY, encode_mode="rb") + pk_string = get_agent_secret(SNOWFLAKE_PRIVATE_KEY) + # cryptography needs str to be stripped and converted to bytes + pk_string = pk_string.strip().encode() p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) pkb = p_key.private_bytes( @@ -47,8 +46,8 @@ def get_private_key(): return pkb -def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector: - return snowflake_connector.connect( +def get_connection(metadata: SnowflakeJobMetadata) -> sc: + return sc.connect( user=metadata.user, account=metadata.account, private_key=get_private_key(), @@ -69,10 +68,11 @@ async def create( ) -> SnowflakeJobMetadata: ctx = FlyteContextManager.current_context() literal_types = task_template.interface.inputs - params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs else None + + params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs.literals else None config = task_template.config - conn = snowflake_connector.connect( + conn = sc.connect( user=config["user"], account=config["account"], private_key=get_private_key(), @@ -82,7 +82,7 @@ async def create( ) cs = conn.cursor() - cs.execute_async(task_template.sql.statement, params=params) + cs.execute_async(task_template.sql.statement, params) return SnowflakeJobMetadata( user=config["user"], @@ -90,35 +90,42 @@ async def create( database=config["database"], schema=config["schema"], warehouse=config["warehouse"], - table=config["table"], - query_id=str(cs.sfqid), + query_id=cs.sfqid, + has_output=task_template.interface.outputs is not None and len(task_template.interface.outputs) > 0, ) async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource: conn = get_connection(resource_meta) try: query_status = conn.get_query_status_throw_if_error(resource_meta.query_id) - except snowflake_connector.ProgrammingError as err: + except sc.ProgrammingError as err: logger.error("Failed to get snowflake job status with error:", err.msg) return Resource(phase=TaskExecution.FAILED) + + log_link = TaskLog( + uri=construct_query_link(resource_meta=resource_meta), + name="Snowflake Query Details", + ) + # The snowflake job's state is determined by query status. + # https://github.com/snowflakedb/snowflake-connector-python/blob/main/src/snowflake/connector/constants.py#L373 cur_phase = convert_to_flyte_phase(str(query_status.name)) res = None - if cur_phase == TaskExecution.SUCCEEDED: + if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output: ctx = FlyteContextManager.current_context() - output_metadata = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.table}" + uri = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}" res = literals.LiteralMap( { "results": TypeEngine.to_literal( ctx, - StructuredDataset(uri=output_metadata), + StructuredDataset(uri=uri), StructuredDataset, LiteralType(structured_dataset_type=StructuredDatasetType(format="")), ) } - ).to_flyte_idl() + ) - return Resource(phase=cur_phase, outputs=res) + return Resource(phase=cur_phase, outputs=res, log_links=[log_link]) async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs): conn = get_connection(resource_meta) @@ -131,4 +138,17 @@ async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs): conn.close() +def construct_query_link(resource_meta: SnowflakeJobMetadata) -> str: + base_url = "https://app.snowflake.com" + + # Extract the account and region (assuming the format is account-region, you might need to adjust this based on your actual account format) + account_parts = resource_meta.account.split("-") + account = account_parts[0] + region = account_parts[1] if len(account_parts) > 1 else "" + + url = f"{base_url}/{region}/{account}/#/compute/history/queries/{resource_meta.query_id}/detail" + + return url + + AgentRegistry.register(SnowflakeAgent()) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 9ac9980a88..13cd15bee0 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -12,27 +12,27 @@ _DATABASE_FIELD = "database" _SCHEMA_FIELD = "schema" _WAREHOUSE_FIELD = "warehouse" -_TABLE_FIELD = "table" @dataclass class SnowflakeConfig(object): """ SnowflakeConfig should be used to configure a Snowflake Task. + You can use the query below to retrieve all metadata for this config. + + SELECT + CURRENT_USER() AS "User", + CONCAT(CURRENT_ORGANIZATION_NAME(), '-', CURRENT_ACCOUNT_NAME()) AS "Account", + CURRENT_DATABASE() AS "Database", + CURRENT_SCHEMA() AS "Schema", + CURRENT_WAREHOUSE() AS "Warehouse"; """ - # The user to query against - user: Optional[str] = None - # The account to query against - account: Optional[str] = None - # The database to query against - database: Optional[str] = None - # The optional schema to separate query execution. - schema: Optional[str] = None - # The optional warehouse to set for the given Snowflake query - warehouse: Optional[str] = None - # The optional table to set for the given Snowflake query - table: Optional[str] = None + user: str + account: str + database: str + schema: str + warehouse: str class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): @@ -47,7 +47,7 @@ def __init__( self, name: str, query_template: str, - task_config: Optional[SnowflakeConfig] = None, + task_config: SnowflakeConfig, inputs: Optional[Dict[str, Type]] = None, output_schema_type: Optional[Type[StructuredDataset]] = None, **kwargs, @@ -63,13 +63,13 @@ def __init__( :param output_schema_type: If some data is produced by this query, then you can specify the output schema type :param kwargs: All other args required by Parent type - SQLTask """ + outputs = None if output_schema_type is not None: outputs = { "results": output_schema_type, } - if task_config is None: - task_config = SnowflakeConfig() + super().__init__( name=name, task_config=task_config, @@ -88,7 +88,6 @@ def get_config(self, settings: SerializationSettings) -> Dict[str, str]: _DATABASE_FIELD: self.task_config.database, _SCHEMA_FIELD: self.task_config.schema, _WAREHOUSE_FIELD: self.task_config.warehouse, - _TABLE_FIELD: self.task_config.table, } def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index b5265c299e..ec1d6e0158 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>1.10.7", "snowflake-connector-python>=3.1.0"] +plugin_requires = ["flytekit>1.13.1", "snowflake-connector-python>=3.11.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index adc699061f..e63ddb9f85 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -55,7 +55,6 @@ async def test_snowflake_agent(mock_get_private_key): "database": "dummy_database", "schema": "dummy_schema", "warehouse": "dummy_warehouse", - "table": "dummy_table", } int_type = types.LiteralType(types.SimpleType.INTEGER) @@ -86,11 +85,11 @@ async def test_snowflake_agent(mock_get_private_key): snowflake_metadata = SnowflakeJobMetadata( user="dummy_user", account="dummy_account", - table="dummy_table", database="dummy_database", schema="dummy_schema", warehouse="dummy_warehouse", query_id="dummy_id", + has_output=False, ) metadata = await agent.create(dummy_template, task_inputs) @@ -98,10 +97,7 @@ async def test_snowflake_agent(mock_get_private_key): resource = await agent.get(metadata) assert resource.phase == TaskExecution.SUCCEEDED - assert ( - resource.outputs.literals["results"].scalar.structured_dataset.uri - == "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" - ) + assert resource.outputs == None delete_response = await agent.delete(snowflake_metadata) assert delete_response is None diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index 672f4a19ad..61db311c68 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -21,7 +21,11 @@ def test_serialization(): name="flytekit.demo.snowflake_task.query", inputs=kwtypes(ds=str), task_config=SnowflakeConfig( - account="snowflake", warehouse="my_warehouse", schema="my_schema", database="my_database" + account="snowflake", + user="my_user", + warehouse="my_warehouse", + schema="my_schema", + database="my_database", ), query_template=query_template, # the schema literal's backend uri will be equal to the value of .raw_output_data @@ -64,6 +68,13 @@ def test_local_exec(): snowflake_task = SnowflakeTask( name="flytekit.demo.snowflake_task.query2", inputs=kwtypes(ds=str), + task_config=SnowflakeConfig( + account="TEST-ACCOUNT", + user="FLYTE", + database="FLYTEAGENT", + schema="PUBLIC", + warehouse="COMPUTE_WH", + ), query_template="select 1\n", # the schema literal's backend uri will be equal to the value of .raw_output_data output_schema_type=FlyteSchema, @@ -73,15 +84,18 @@ def test_local_exec(): assert snowflake_task.query_template == "select 1" assert len(snowflake_task.interface.outputs) == 1 - # will not run locally - with pytest.raises(Exception): - snowflake_task() - def test_sql_template(): snowflake_task = SnowflakeTask( name="flytekit.demo.snowflake_task.query2", inputs=kwtypes(ds=str), + task_config=SnowflakeConfig( + account="TEST-ACCOUNT", + user="FLYTE", + database="FLYTEAGENT", + schema="PUBLIC", + warehouse="COMPUTE_WH", + ), query_template="""select 1 from\t custom where column = 1""", output_schema_type=FlyteSchema, diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py new file mode 100644 index 0000000000..ab85f9e013 --- /dev/null +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -0,0 +1,70 @@ +import mock +import pytest +from typing_extensions import Annotated +import sys + +from flytekit import StructuredDataset, kwtypes, task, workflow + +try: + import numpy as np + numpy_installed = True +except ImportError: + numpy_installed = False + +skip_if_wrong_numpy_version = pytest.mark.skipif( + not numpy_installed or np.__version__ > '1.26.4', + reason="Test skipped because either NumPy is not installed or the installed version is greater than 1.26.4. " + "Ensure that NumPy is installed and the version is <= 1.26.4, as required by the Snowflake connector." + +) + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +@skip_if_wrong_numpy_version +@mock.patch("flytekit.types.structured.snowflake.get_private_key", return_value="pb") +@mock.patch("snowflake.connector.connect") +def test_sf_wf(mock_connect, mock_get_private_key): + import pandas as pd + from flytekit.lazy_import.lazy_module import is_imported + from flytekit.types.structured import register_snowflake_handlers + from flytekit.types.structured.structured_dataset import DuplicateHandlerError + + if is_imported("snowflake.connector"): + try: + register_snowflake_handlers() + except DuplicateHandlerError: + pass + + + pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + my_cols = kwtypes(Name=str, Age=int) + + @task + def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: + return pd_df + + @task + def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: + return StructuredDataset( + dataframe=df, + uri="snowflake://dummy_user/dummy_account/COMPUTE_WH/FLYTEAGENT/PUBLIC/TEST" + ) + + @task + def t2(sd: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: + return sd.open(pd.DataFrame).all() + + @workflow + def wf() -> pd.DataFrame: + df = gen_df() + sd = t1(df=df) + return t2(sd=sd) + + class mock_dataframe: + def to_dataframe(self): + return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + mock_connect_instance = mock_connect.return_value + mock_coursor_instance = mock_connect_instance.cursor.return_value + mock_coursor_instance.fetch_pandas_all.return_value = mock_dataframe().to_dataframe() + + assert wf().equals(pd_df)