Fix Snowflake Agent Bug (#2605)
* fix snowflake agent bug

Signed-off-by: Future-Outlier <[email protected]>

* a work version

Signed-off-by: Future-Outlier <[email protected]>

* Snowflake work version

Signed-off-by: Future-Outlier <[email protected]>

* fix secret encode

Signed-off-by: Future-Outlier <[email protected]>

* all works, I am so happy

Signed-off-by: Future-Outlier <[email protected]>

* improve additional protocol

Signed-off-by: Future-Outlier <[email protected]>

* fix tests

Signed-off-by: Future-Outlier <[email protected]>

* Fix Tests

Signed-off-by: Future-Outlier <[email protected]>

* update agent

Signed-off-by: Kevin Su <[email protected]>

* Add snowflake test

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* sd

Signed-off-by: Kevin Su <[email protected]>

* snowflake loglinks

Signed-off-by: Future-Outlier <[email protected]>

* add metadata

Signed-off-by: Future-Outlier <[email protected]>

* secret

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* remove table

Signed-off-by: Future-Outlier <[email protected]>

* add comment for get private key

Signed-off-by: Future-Outlier <[email protected]>

* update comments:

Signed-off-by: Future-Outlier <[email protected]>

* Fix Tests

Signed-off-by: Future-Outlier <[email protected]>

* update comments

Signed-off-by: Future-Outlier <[email protected]>

* update comments

Signed-off-by: Future-Outlier <[email protected]>

* Better Secrets

Signed-off-by: Future-Outlier <[email protected]>

* use union secret

Signed-off-by: Future-Outlier <[email protected]>

* Update Changes

Signed-off-by: Future-Outlier <[email protected]>

* use if not get_plugin().secret_requires_group()

Signed-off-by: Future-Outlier <[email protected]>

* Use Union SDK

Signed-off-by: Future-Outlier <[email protected]>

* Update

Signed-off-by: Future-Outlier <[email protected]>

* Fix Secrets

Signed-off-by: Future-Outlier <[email protected]>

* Fix Secrets

Signed-off-by: Future-Outlier <[email protected]>

* remove package.json

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* add snowflake-connector-python

Signed-off-by: Future-Outlier <[email protected]>

* fix test_snowflake

Signed-off-by: Future-Outlier <[email protected]>

* Try to fix tests

Signed-off-by: Future-Outlier <[email protected]>

* fix tests

Signed-off-by: Future-Outlier <[email protected]>

* Try Fix snowflake Import

Signed-off-by: Future-Outlier <[email protected]>

* snowflake test passed

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Future-Outlier and pingsutw authored Jul 31, 2024
1 parent 4f96b33 commit 1b67f16
Showing 13 changed files with 298 additions and 55 deletions.
1 change: 1 addition & 0 deletions dev-requirements.in
@@ -16,6 +16,7 @@ pre-commit
codespell
google-cloud-bigquery
google-cloud-bigquery-storage
snowflake-connector-python
IPython
keyrings.alt
setuptools_scm
6 changes: 6 additions & 0 deletions flytekit/core/context_manager.py
@@ -367,6 +367,12 @@ def get(
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
param encode_mode, defines the mode to open files, it can either be "r" to read file, or "rb" to read binary file
"""

from flytekit.configuration.plugin import get_plugin

if not get_plugin().secret_requires_group():
group, group_version = None, None

env_var = self.get_secrets_env_var(group, key, group_version)
fpath = self.get_secrets_file(group, key, group_version)
v = os.environ.get(env_var)
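A minimal sketch of what this change enables, assuming a registered plugin whose secret_requires_group() returns False (the secret key is illustrative):

# Sketch only: when the active plugin does not require a secret group,
# a secret can be resolved by key alone; group and group_version fall back to None.
import flytekit

ctx = flytekit.current_context()
private_key = ctx.secrets.get(key="snowflake_private_key")  # hypothetical key, no group needed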
6 changes: 6 additions & 0 deletions flytekit/core/type_engine.py
@@ -983,6 +983,7 @@ def lazy_import_transformers(cls):
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
register_snowflake_handlers,
)
from flytekit.types.structured.structured_dataset import DuplicateHandlerError

@@ -1015,6 +1016,11 @@ def lazy_import_transformers(cls):
from flytekit.types import numpy # noqa: F401
if is_imported("PIL"):
from flytekit.types.file import image # noqa: F401
if is_imported("snowflake.connector"):
try:
register_snowflake_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for snowflake is already registered.")

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
14 changes: 14 additions & 0 deletions flytekit/types/structured/__init__.py
@@ -68,3 +68,17 @@ def register_bigquery_handlers():
"We won't register bigquery handler for structured dataset because "
"we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery"
)


def register_snowflake_handlers():
try:
from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler

StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers())

except ImportError:
logger.info(
"We won't register snowflake handler for structured dataset because "
"we can't find package snowflake-connector-python"
)
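For illustration, a hedged sketch of invoking this registration by hand; it mirrors what the lazy import in type_engine.py does once snowflake.connector is imported, and is not a required step:

from flytekit.types.structured import register_snowflake_handlers
from flytekit.types.structured.structured_dataset import DuplicateHandlerError

try:
    register_snowflake_handlers()
except DuplicateHandlerError:
    # Already registered, e.g. via TypeEngine.lazy_import_transformers().
    pass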
106 changes: 106 additions & 0 deletions flytekit/types/structured/snowflake.py
@@ -0,0 +1,106 @@
import re
import typing

import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

import flytekit
from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
)

SNOWFLAKE = "snowflake"
PROTOCOL_SEP = "\\/|://|:"


def get_private_key() -> bytes:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

pk_string = flytekit.current_context().secrets.get("private_key", "snowflake", encode_mode="r")

# Cryptography needs the string to be stripped and converted to bytes
pk_string = pk_string.strip().encode()
p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend())

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

return pkb


def _write_to_sf(structured_dataset: StructuredDataset):
if structured_dataset.uri is None:
raise ValueError("structured_dataset.uri cannot be None.")

uri = structured_dataset.uri
_, user, account, warehouse, database, schema, table = re.split(PROTOCOL_SEP, uri)
df = structured_dataset.dataframe

conn = snowflake.connector.connect(
user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse
)

write_pandas(conn, df, table)


def _read_from_sf(
flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata
) -> pd.DataFrame:
if flyte_value.uri is None:
raise ValueError("structured_dataset.uri cannot be None.")

uri = flyte_value.uri
_, user, account, warehouse, database, schema, query_id = re.split(PROTOCOL_SEP, uri)

conn = snowflake.connector.connect(
user=user,
account=account,
private_key=get_private_key(),
database=database,
schema=schema,
warehouse=warehouse,
)

cs = conn.cursor()
cs.get_results_from_sfqid(query_id)
return cs.fetch_pandas_all()


class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder):
def __init__(self):
super().__init__(python_type=pd.DataFrame, protocol=SNOWFLAKE, supported_format="")

def encode(
self,
ctx: FlyteContext,
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
_write_to_sf(structured_dataset)
return literals.StructuredDataset(
uri=typing.cast(str, structured_dataset.uri), metadata=StructuredDatasetMetadata(structured_dataset_type)
)


class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder):
def __init__(self):
super().__init__(pd.DataFrame, protocol=SNOWFLAKE, supported_format="")

def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pd.DataFrame:
return _read_from_sf(flyte_value, current_task_metadata)
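To see the new handlers in context, a hedged usage sketch: a task returns a pandas DataFrame as a StructuredDataset whose URI follows the snowflake://<user>:<account>/<warehouse>/<database>/<schema>/<table> layout that _write_to_sf splits on PROTOCOL_SEP (all identifiers below are placeholders):

import pandas as pd

from flytekit import StructuredDataset, task


@task
def write_to_snowflake() -> StructuredDataset:
    df = pd.DataFrame({"name": ["Alice", "Bob"], "age": [30, 31]})
    # Placeholder identifiers; the encoder splits this URI into
    # user, account, warehouse, database, schema, and table.
    uri = "snowflake://my_user:my_account/my_warehouse/my_db/my_schema/my_table"
    return StructuredDataset(dataframe=df, uri=uri)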
17 changes: 14 additions & 3 deletions flytekit/types/structured/structured_dataset.py
@@ -6,7 +6,7 @@
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, is_dataclass
from typing import Dict, Generator, Optional, Type, Union
from typing import Dict, Generator, List, Optional, Type, Union

from dataclasses_json import config
from fsspec.utils import get_protocol
@@ -222,7 +222,12 @@ def extract_cols_and_format(


class StructuredDatasetEncoder(ABC):
def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None):
def __init__(
self,
python_type: Type[T],
protocol: Optional[str] = None,
supported_format: Optional[str] = None,
):
"""
Extend this abstract class, implement the encode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
@@ -284,7 +289,13 @@ def encode(


class StructuredDatasetDecoder(ABC):
def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None):
def __init__(
self,
python_type: Type[DF],
protocol: Optional[str] = None,
supported_format: Optional[str] = None,
additional_protocols: Optional[List[str]] = None,
):
"""
Extend this abstract class, implement the decode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
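A brief sketch of the widened decoder constructor with a hypothetical "myproto" protocol; the new additional_protocols argument lets a single decoder claim extra URI schemes:

import pandas as pd

from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.types.structured.structured_dataset import (
    StructuredDatasetDecoder,
    StructuredDatasetMetadata,
)


class MyProtoToPandasDecoder(StructuredDatasetDecoder):
    def __init__(self):
        # "myproto" and "myproto-legacy" are placeholder protocol names.
        super().__init__(
            pd.DataFrame,
            protocol="myproto",
            supported_format="",
            additional_protocols=["myproto-legacy"],
        )

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> pd.DataFrame:
        # Placeholder decode; a real decoder would read from flyte_value.uri.
        return pd.DataFrame()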
2 changes: 1 addition & 1 deletion plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py
@@ -38,7 +38,7 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[BigQueryConfig],
task_config: BigQueryConfig,
inputs: Optional[Dict[str, Type]] = None,
output_structured_dataset_type: Optional[Type[StructuredDataset]] = None,
**kwargs,
64 changes: 42 additions & 22 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
@@ -1,18 +1,17 @@
from dataclasses import dataclass
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog

from flytekit import FlyteContextManager, StructuredDataset, lazy_module, logger
from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType

snowflake_connector = lazy_module("snowflake.connector")
from snowflake import connector as sc

TASK_TYPE = "snowflake"
SNOWFLAKE_PRIVATE_KEY = "snowflake_private_key"
@@ -25,17 +24,17 @@ class SnowflakeJobMetadata(ResourceMeta):
database: str
schema: str
warehouse: str
table: str
query_id: str
has_output: bool


def get_private_key():
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

import flytekit

pk_string = flytekit.current_context().secrets.get(SNOWFLAKE_PRIVATE_KEY, encode_mode="rb")
pk_string = get_agent_secret(SNOWFLAKE_PRIVATE_KEY)
# cryptography needs str to be stripped and converted to bytes
pk_string = pk_string.strip().encode()
p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend())

pkb = p_key.private_bytes(
Expand All @@ -47,8 +46,8 @@ def get_private_key():
return pkb


def get_connection(metadata: SnowflakeJobMetadata) -> snowflake_connector:
return snowflake_connector.connect(
def get_connection(metadata: SnowflakeJobMetadata) -> sc:
return sc.connect(
user=metadata.user,
account=metadata.account,
private_key=get_private_key(),
@@ -69,10 +68,11 @@ async def create(
) -> SnowflakeJobMetadata:
ctx = FlyteContextManager.current_context()
literal_types = task_template.interface.inputs
params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs else None

params = TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs.literals else None

config = task_template.config
conn = snowflake_connector.connect(
conn = sc.connect(
user=config["user"],
account=config["account"],
private_key=get_private_key(),
@@ -82,43 +82,50 @@ async def create(
)

cs = conn.cursor()
cs.execute_async(task_template.sql.statement, params=params)
cs.execute_async(task_template.sql.statement, params)

return SnowflakeJobMetadata(
user=config["user"],
account=config["account"],
database=config["database"],
schema=config["schema"],
warehouse=config["warehouse"],
table=config["table"],
query_id=str(cs.sfqid),
query_id=cs.sfqid,
has_output=task_template.interface.outputs is not None and len(task_template.interface.outputs) > 0,
)

async def get(self, resource_meta: SnowflakeJobMetadata, **kwargs) -> Resource:
conn = get_connection(resource_meta)
try:
query_status = conn.get_query_status_throw_if_error(resource_meta.query_id)
except snowflake_connector.ProgrammingError as err:
except sc.ProgrammingError as err:
logger.error("Failed to get snowflake job status with error:", err.msg)
return Resource(phase=TaskExecution.FAILED)

log_link = TaskLog(
uri=construct_query_link(resource_meta=resource_meta),
name="Snowflake Query Details",
)
# The snowflake job's state is determined by query status.
# https://github.com/snowflakedb/snowflake-connector-python/blob/main/src/snowflake/connector/constants.py#L373
cur_phase = convert_to_flyte_phase(str(query_status.name))
res = None

if cur_phase == TaskExecution.SUCCEEDED:
if cur_phase == TaskExecution.SUCCEEDED and resource_meta.has_output:
ctx = FlyteContextManager.current_context()
output_metadata = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.table}"
uri = f"snowflake://{resource_meta.user}:{resource_meta.account}/{resource_meta.warehouse}/{resource_meta.database}/{resource_meta.schema}/{resource_meta.query_id}"
res = literals.LiteralMap(
{
"results": TypeEngine.to_literal(
ctx,
StructuredDataset(uri=output_metadata),
StructuredDataset(uri=uri),
StructuredDataset,
LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
)
}
).to_flyte_idl()
)

return Resource(phase=cur_phase, outputs=res)
return Resource(phase=cur_phase, outputs=res, log_links=[log_link])

async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs):
conn = get_connection(resource_meta)
Expand All @@ -131,4 +138,17 @@ async def delete(self, resource_meta: SnowflakeJobMetadata, **kwargs):
conn.close()


def construct_query_link(resource_meta: SnowflakeJobMetadata) -> str:
base_url = "https://app.snowflake.com"

# Extract the account and region (assumes an account identifier of the form "<account>-<region>"; adjust if your account format differs)
account_parts = resource_meta.account.split("-")
account = account_parts[0]
region = account_parts[1] if len(account_parts) > 1 else ""

url = f"{base_url}/{region}/{account}/#/compute/history/queries/{resource_meta.query_id}/detail"

return url


AgentRegistry.register(SnowflakeAgent())
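Finally, a hedged sketch of how the fixed agent would typically be exercised; the SnowflakeConfig field names mirror the config keys the agent reads (user, account, database, schema, warehouse), but the plugin's exact constructor signature is an assumption here:

from flytekitplugins.snowflake import SnowflakeConfig, SnowflakeTask

# All identifiers are placeholders.
snowflake_task = SnowflakeTask(
    name="example_snowflake_query",
    query_template="SELECT 1",
    task_config=SnowflakeConfig(
        user="my_user",
        account="my_account",
        database="my_database",
        schema="my_schema",
        warehouse="my_warehouse",
    ),
)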
