From ff9c57eb13cb95c8fcd52fcc0ad9bc01eb698fd7 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 16:13:13 +0100 Subject: [PATCH 01/24] Add Azure blob storage configs Signed-off-by: Thomas Newton --- flytekit/configuration/__init__.py | 23 +++++++++++++++++++++++ flytekit/configuration/internal.py | 8 ++++++++ 2 files changed, 31 insertions(+) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 8e5ccf2fe2..35458e2f5e 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -560,6 +560,27 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: kwargs = {} kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file)) return GCSConfig(**kwargs) + +@dataclass(init=True, repr=True, eq=True, frozen=True) +class AzureBlobStorageConfig(object): + """ + Any Azure Blob Storage specific configuration. + """ + + account_name: typing.Optional[str] = None + account_key: typing.Optional[str] = None + client_id: typing.Optional[str] = None + client_secret: typing.Optional[str] = None + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.ACCOUNT_NAME.read(config_file)) + kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.ACCOUNT_KEY.read(config_file)) + kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file)) + kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file)) + return AzureBlobStorageConfig(**kwargs) @dataclass(init=True, repr=True, eq=True, frozen=True) @@ -572,11 +593,13 @@ class DataConfig(object): s3: S3Config = S3Config() gcs: GCSConfig = GCSConfig() + azure: AzureBlobStorageConfig = AzureBlobStorageConfig() @classmethod def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig: config_file = get_config_file(config_file) return DataConfig( + azure=AzureBlobStorageConfig.auto(config_file), s3=S3Config.auto(config_file), gcs=GCSConfig.auto(config_file), ) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index f34321f57b..ec08576c3f 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -57,6 +57,14 @@ class GCP(object): GSUTIL_PARALLELISM = ConfigEntry(LegacyConfigEntry(SECTION, "gsutil_parallelism", bool)) +class AZURE(object): + SECTION = "azure" + ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "account_name")) + ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "account_key")) + CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "account_key")) + CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "access_key_id")) + + class Credentials(object): SECTION = "credentials" COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command", list), YamlConfigEntry("admin.command", list)) From 2719f6af8e63bba633142c5e12648ec55e3a31dd Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 16:18:38 +0100 Subject: [PATCH 02/24] Use Azure args Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 57311ed415..25a8ea2323 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -61,6 +61,27 @@ def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): return kwargs +def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False): + kwargs: Dict[str, Any] = {} + + if azure_cfg.account_name: + kwargs["account_name"] = azure_cfg.account_name + + if azure_cfg.account_name: + kwargs["account_key"] = azure_cfg.account_key + + if azure_cfg.client_id: + kwargs["account_key"] = azure_cfg.client_id + + if azure_cfg.client_secret: + kwargs["account_key"] = azure_cfg.client_secret + + if anonymous: + kwargs[_ANON] = True + + return kwargs + + class FileAccessProvider(object): """ This is the class that is available through the FlyteContext and can be used for persisting data to the remote @@ -120,7 +141,8 @@ def get_filesystem( kwargs["token"] = _ANON return fsspec.filesystem(protocol, **kwargs) # type: ignore elif protocol == "abfs": - kwargs["anon"] = False + azurekwargs = azure_setup_args(self._data_config.azure, anonymous=anonymous) + azurekwargs.update(kwargs) return fsspec.filesystem(protocol, **kwargs) # type: ignore # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. From 1547f8877d9ab6ee8d7778d70481f347e3d0cb80 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 16:25:00 +0100 Subject: [PATCH 03/24] Add tenant id and correct some typos Signed-off-by: Thomas Newton --- flytekit/configuration/__init__.py | 2 ++ flytekit/configuration/internal.py | 5 +++-- flytekit/core/data_persistence.py | 7 +++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 35458e2f5e..7f325311e8 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -569,6 +569,7 @@ class AzureBlobStorageConfig(object): account_name: typing.Optional[str] = None account_key: typing.Optional[str] = None + tenant_id: typing.Optional[str] = None client_id: typing.Optional[str] = None client_secret: typing.Optional[str] = None @@ -578,6 +579,7 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: kwargs = {} kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.ACCOUNT_NAME.read(config_file)) kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.ACCOUNT_KEY.read(config_file)) + kwargs = set_if_exists(kwargs, "tenant_id", _internal.AZURE.TENANT_ID.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file)) kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file)) return AzureBlobStorageConfig(**kwargs) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index ec08576c3f..c654acfa65 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -61,8 +61,9 @@ class AZURE(object): SECTION = "azure" ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "account_name")) ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "account_key")) - CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "account_key")) - CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "access_key_id")) + TENANT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "tenant_id")) + CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id")) + CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret")) class Credentials(object): diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 25a8ea2323..6ed6904545 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -71,10 +71,13 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: kwargs["account_key"] = azure_cfg.account_key if azure_cfg.client_id: - kwargs["account_key"] = azure_cfg.client_id + kwargs["client_id"] = azure_cfg.client_id if azure_cfg.client_secret: - kwargs["account_key"] = azure_cfg.client_secret + kwargs["client_secret"] = azure_cfg.client_secret + + if azure_cfg.tenant_id: + kwargs["tenant_id"] = azure_cfg.tenant_id if anonymous: kwargs[_ANON] = True From 8ed4a7ddc5a6c7e9f84a9ac0ef0e18934922bdd9 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 17:28:20 +0100 Subject: [PATCH 04/24] Fix typos Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 6ed6904545..8b46ccb281 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -67,7 +67,7 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: if azure_cfg.account_name: kwargs["account_name"] = azure_cfg.account_name - if azure_cfg.account_name: + if azure_cfg.account_key: kwargs["account_key"] = azure_cfg.account_key if azure_cfg.client_id: @@ -146,7 +146,7 @@ def get_filesystem( elif protocol == "abfs": azurekwargs = azure_setup_args(self._data_config.azure, anonymous=anonymous) azurekwargs.update(kwargs) - return fsspec.filesystem(protocol, **kwargs) # type: ignore + return fsspec.filesystem(protocol, **azurekwargs) # type: ignore # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. if anonymous: From 172ee25ab63dc6f5fb912a29e0870a710d974135 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 18:07:09 +0100 Subject: [PATCH 05/24] Extract get_storage_options_for_filesystem Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 39 ++++++++++++++----------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 8b46ccb281..33d3d730a0 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -25,6 +25,7 @@ from uuid import UUID import fsspec +from fsspec.core import url_to_fs from fsspec.utils import get_protocol from flytekit import configuration @@ -85,6 +86,19 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: return kwargs +def get_storage_options_for_filesystem(protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs) -> typing.Dict[str, Any]: + if protocol == "file": + return {"auto_mkdir": True, **kwargs} + if protocol == "s3": + return {**s3_setup_args(data_config.s3, anonymous=anonymous), **kwargs} + if protocol == "gs": + if anonymous: + kwargs["token"] = _ANON + return kwargs + if protocol == "abfs": + return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs} + + class FileAccessProvider(object): """ This is the class that is available through the FlyteContext and can be used for persisting data to the remote @@ -130,30 +144,13 @@ def data_config(self) -> DataConfig: def get_filesystem( self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs - ) -> typing.Optional[fsspec.AbstractFileSystem]: + ) -> fsspec.AbstractFileSystem: if not protocol: return self._default_remote - if protocol == "file": - kwargs["auto_mkdir"] = True - elif protocol == "s3": - s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) - s3kwargs.update(kwargs) - return fsspec.filesystem(protocol, **s3kwargs) # type: ignore - elif protocol == "gs": - if anonymous: - kwargs["token"] = _ANON - return fsspec.filesystem(protocol, **kwargs) # type: ignore - elif protocol == "abfs": - azurekwargs = azure_setup_args(self._data_config.azure, anonymous=anonymous) - azurekwargs.update(kwargs) - return fsspec.filesystem(protocol, **azurekwargs) # type: ignore - - # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. - if anonymous: - return None - - return fsspec.filesystem(protocol, **kwargs) # type: ignore + + storage_options = get_storage_options_for_filesystem(protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs) + return fsspec.filesystem(protocol, **storage_options) def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) return self.get_filesystem(protocol, anonymous=anonymous, **kwargs) From c94fb7dccd21ab5e5f7f0bebdd90a0b5e3fa1a71 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 18:12:04 +0100 Subject: [PATCH 06/24] Use `get_storage_options` when encoding structured datasets Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 2 +- flytekit/types/structured/basic_dfs.py | 24 +++++++----------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 33d3d730a0..14c37e328b 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -86,7 +86,7 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: return kwargs -def get_storage_options_for_filesystem(protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs) -> typing.Dict[str, Any]: +def get_storage_options(protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs) -> typing.Dict[str, Any]: if protocol == "file": return {"auto_mkdir": True, **kwargs} if protocol == "s3": diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 98a12ae44d..06442bc61e 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -11,8 +11,7 @@ from fsspec.utils import get_protocol from flytekit import FlyteContext, logger -from flytekit.configuration import DataConfig -from flytekit.core.data_persistence import s3_setup_args +from flytekit.core.data_persistence import get_storage_options from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -27,15 +26,6 @@ T = TypeVar("T") -def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]: - protocol = get_protocol(uri) - if protocol == "s3": - kwargs = s3_setup_args(cfg.s3, anon) - if kwargs: - return kwargs - return None - - class PandasToCSVEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pd.DataFrame, None, CSV) @@ -54,7 +44,7 @@ def encode( df.to_csv( path, index=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), + storage_options=get_storage_options(protocol=get_protocol(path), data_config=ctx.file_access.data_config), ) structured_dataset_type.format = CSV return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -72,7 +62,7 @@ def decode( ) -> pd.DataFrame: uri = flyte_value.uri columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) + kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config) path = os.path.join(uri, ".csv") if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] @@ -80,7 +70,7 @@ def decode( return pd.read_csv(path, usecols=columns, storage_options=kwargs) except NoCredentialsError: logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config, anon=True) return pd.read_csv(path, usecols=columns, storage_options=kwargs) @@ -103,7 +93,7 @@ def encode( path, coerce_timestamps="us", allow_truncated_timestamps=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), + storage_options=get_storage_options(protocol=get_protocol(path), data_config=ctx.file_access.data_config), ) structured_dataset_type.format = PARQUET return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -121,14 +111,14 @@ def decode( ) -> pd.DataFrame: uri = flyte_value.uri columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) + kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] try: return pd.read_parquet(uri, columns=columns, storage_options=kwargs) except NoCredentialsError: logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config, anon=True) return pd.read_parquet(uri, columns=columns, storage_options=kwargs) From f8c8c6dbaca0a66bae03af4377743feae7cc28e8 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 12 Sep 2023 18:24:27 +0100 Subject: [PATCH 07/24] Typo Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 14c37e328b..dd9e1b072f 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -148,9 +148,10 @@ def get_filesystem( if not protocol: return self._default_remote - storage_options = get_storage_options_for_filesystem(protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs) + storage_options = get_storage_options(protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs) return fsspec.filesystem(protocol, **storage_options) + def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) return self.get_filesystem(protocol, anonymous=anonymous, **kwargs) From 83cfe5bbbb452f3cdc680316501c8058476f8377 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 19 Sep 2023 22:36:51 +0100 Subject: [PATCH 08/24] Remove unused import Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index dd9e1b072f..c4221e7d9c 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -25,7 +25,6 @@ from uuid import UUID import fsspec -from fsspec.core import url_to_fs from fsspec.utils import get_protocol from flytekit import configuration From dc0f424cf53a50e8ef8b0fc62db97a003024b787 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Fri, 22 Sep 2023 22:19:28 +0100 Subject: [PATCH 09/24] Handle unrecognised protocol Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index c4221e7d9c..5f1bde84d6 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -85,7 +85,7 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: return kwargs -def get_storage_options(protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs) -> typing.Dict[str, Any]: +def get_fsspec_storage_options(protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs) -> typing.Dict: if protocol == "file": return {"auto_mkdir": True, **kwargs} if protocol == "s3": @@ -96,6 +96,7 @@ def get_storage_options(protocol: str, data_config: typing.Optional[DataConfig] return kwargs if protocol == "abfs": return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs} + return {} class FileAccessProvider(object): @@ -147,7 +148,7 @@ def get_filesystem( if not protocol: return self._default_remote - storage_options = get_storage_options(protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs) + storage_options = get_fsspec_storage_options(protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs) return fsspec.filesystem(protocol, **storage_options) From 418689f2ac614daca341ab130af9cfb7e9b9bc5c Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Fri, 22 Sep 2023 22:34:04 +0100 Subject: [PATCH 10/24] Handle difference between pandas and fsspec storage_options Signed-off-by: Thomas Newton --- flytekit/configuration/__init__.py | 3 ++- flytekit/core/data_persistence.py | 12 +++++++--- flytekit/types/structured/basic_dfs.py | 24 +++++++++++++------ .../flytekitplugins/polars/sd_transformers.py | 5 ++-- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 7f325311e8..5dd4ba4345 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -560,7 +560,8 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: kwargs = {} kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file)) return GCSConfig(**kwargs) - + + @dataclass(init=True, repr=True, eq=True, frozen=True) class AzureBlobStorageConfig(object): """ diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 5f1bde84d6..76acce7c20 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -85,7 +85,11 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: return kwargs -def get_fsspec_storage_options(protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs) -> typing.Dict: +def get_fsspec_storage_options( + protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs +) -> typing.Dict: + data_config = data_config or DataConfig.auto() + if protocol == "file": return {"auto_mkdir": True, **kwargs} if protocol == "s3": @@ -147,8 +151,10 @@ def get_filesystem( ) -> fsspec.AbstractFileSystem: if not protocol: return self._default_remote - - storage_options = get_fsspec_storage_options(protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs) + + storage_options = get_fsspec_storage_options( + protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs + ) return fsspec.filesystem(protocol, **storage_options) diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 06442bc61e..c5592a02bb 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -11,7 +11,8 @@ from fsspec.utils import get_protocol from flytekit import FlyteContext, logger -from flytekit.core.data_persistence import get_storage_options +from flytekit.configuration import DataConfig +from flytekit.core.data_persistence import get_fsspec_storage_options from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -26,6 +27,15 @@ T = TypeVar("T") +def get_pandas_storage_options( + uri: str, data_config: DataConfig, anonymous: bool = False +) -> typing.Optional[typing.Dict]: + if pd.io.common.is_url(uri): + return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous) + # Pandas does not allow storage_options for non-url paths. + return None + + class PandasToCSVEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pd.DataFrame, None, CSV) @@ -44,7 +54,7 @@ def encode( df.to_csv( path, index=False, - storage_options=get_storage_options(protocol=get_protocol(path), data_config=ctx.file_access.data_config), + storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config), ) structured_dataset_type.format = CSV return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -62,7 +72,7 @@ def decode( ) -> pd.DataFrame: uri = flyte_value.uri columns = None - kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config) path = os.path.join(uri, ".csv") if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] @@ -70,7 +80,7 @@ def decode( return pd.read_csv(path, usecols=columns, storage_options=kwargs) except NoCredentialsError: logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config, anon=True) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True) return pd.read_csv(path, usecols=columns, storage_options=kwargs) @@ -93,7 +103,7 @@ def encode( path, coerce_timestamps="us", allow_truncated_timestamps=False, - storage_options=get_storage_options(protocol=get_protocol(path), data_config=ctx.file_access.data_config), + storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config), ) structured_dataset_type.format = PARQUET return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -111,14 +121,14 @@ def decode( ) -> pd.DataFrame: uri = flyte_value.uri columns = None - kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] try: return pd.read_parquet(uri, columns=columns, storage_options=kwargs) except NoCredentialsError: logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs = get_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config, anon=True) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True) return pd.read_parquet(uri, columns=columns, storage_options=kwargs) diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 4290c88ae4..ea644dc078 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -2,12 +2,13 @@ import pandas as pd import polars as pl +from fsspec.utils import get_protocol from flytekit import FlyteContext +from flytekit.core.data_persistence import get_fsspec_storage_options from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.basic_dfs import get_storage_options from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -64,7 +65,7 @@ def decode( current_task_metadata: StructuredDatasetMetadata, ) -> pl.DataFrame: uri = flyte_value.uri - kwargs = get_storage_options(ctx.file_access.data_config, uri) + kwargs = get_fsspec_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs) From f9da829ed66d38f2183643cc4d7adc0c61ec2e2d Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Sat, 23 Sep 2023 00:07:31 +0100 Subject: [PATCH 11/24] Add test for azure_setup_args Signed-off-by: Thomas Newton --- tests/flytekit/unit/core/test_data.py | 31 +++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 2d61b58d8c..dff7ac223e 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -8,9 +8,14 @@ import mock import pytest -from flytekit.configuration import Config, S3Config +from flytekit.configuration import AzureBlobStorageConfig, Config, S3Config from flytekit.core.context_manager import FlyteContextManager -from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args +from flytekit.core.data_persistence import ( + FileAccessProvider, + azure_setup_args, + default_local_file_access_provider, + s3_setup_args, +) from flytekit.types.directory.types import FlyteDirectory local = fsspec.filesystem("file") @@ -221,6 +226,28 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): assert kwargs == {"cache_regions": True} +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_azure_setup_env_args(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_AZURE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_ACCOUNT_KEY": "accountkey", + "FLYTE_AZURE_TENANT_ID": "tenantid", + "FLYTE_AZURE_CLIENT_ID": "clientid", + "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = azure_setup_args(AzureBlobStorageConfig.auto()) + assert kwargs == { + "account_name": "accountname", + "account_key": "accountkey", + "client_id": "clientid", + "client_secret": "clientsecret", + "tenant_id": "tenantid", + } + + def test_crawl_local_nt(source_folder): """ running this to see what it prints From 2d4f54b82c5b4a9f57e8131c84d9322f4ba367bb Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Sat, 23 Sep 2023 02:03:12 +0100 Subject: [PATCH 12/24] Add tests for initialising Azure filesystem and fix anonymous Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 3 +- .../unit/core/test_data_persistence.py | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 76acce7c20..5133ecc3d1 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -79,8 +79,7 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: if azure_cfg.tenant_id: kwargs["tenant_id"] = azure_cfg.tenant_id - if anonymous: - kwargs[_ANON] = True + kwargs[_ANON] = anonymous return kwargs diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 27b407c1ce..002153c1e5 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,3 +1,7 @@ +import os + +from azure.identity import ClientSecretCredential, DefaultAzureCredential +from mock import patch from flytekit.core.data_persistence import FileAccessProvider @@ -14,3 +18,28 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True + + +def test_initialise_azure_file_provider_with_account_key(): + with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_ACCOUNT_KEY": "accountkey"}, clear=True): + fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") + assert fp.get_filesystem().account_name == "accountname" + assert fp.get_filesystem().account_key == "accountkey" + assert fp.get_filesystem().sync_credential is None + +def test_initialise_azure_file_provider_with_service_principal(): + with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", "FLYTE_AZURE_CLIENT_ID": "clientid", + "FLYTE_AZURE_TENANT_ID": "tenantid" + }, clear=True): + fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") + assert fp.get_filesystem().account_name == "accountname" + assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential) + assert fp.get_filesystem().client_secret == "clientsecret" + assert fp.get_filesystem().client_id == "clientid" + assert fp.get_filesystem().tenant_id == "tenantid" + +def test_initialise_azure_file_provider_with_default_credential(): + with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname"}, clear=True): + fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") + assert fp.get_filesystem().account_name == "accountname" + assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) From f846352dc2ce48df228c6fc927181fccc61e0b67 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 08:51:12 +0100 Subject: [PATCH 13/24] Fix anon assertion Signed-off-by: Thomas Newton --- tests/flytekit/unit/core/test_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index dff7ac223e..5dd79a3618 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -245,6 +245,7 @@ def test_azure_setup_env_args(mock_os, mock_get_config_file): "client_id": "clientid", "client_secret": "clientsecret", "tenant_id": "tenantid", + "anon": False, } From 7ffca4093f5fcb2a339915b3f86a1c762e37dcdc Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 08:51:18 +0100 Subject: [PATCH 14/24] Autoformat Signed-off-by: Thomas Newton --- .../unit/core/test_data_persistence.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 002153c1e5..0bfa3cb9d9 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -2,6 +2,7 @@ from azure.identity import ClientSecretCredential, DefaultAzureCredential from mock import patch + from flytekit.core.data_persistence import FileAccessProvider @@ -21,16 +22,26 @@ def test_is_remote(): def test_initialise_azure_file_provider_with_account_key(): - with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_ACCOUNT_KEY": "accountkey"}, clear=True): + with patch.dict( + os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_ACCOUNT_KEY": "accountkey"}, clear=True + ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" assert fp.get_filesystem().account_key == "accountkey" assert fp.get_filesystem().sync_credential is None + def test_initialise_azure_file_provider_with_service_principal(): - with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", "FLYTE_AZURE_CLIENT_ID": "clientid", - "FLYTE_AZURE_TENANT_ID": "tenantid" - }, clear=True): + with patch.dict( + os.environ, + { + "FLYTE_AZURE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", + "FLYTE_AZURE_CLIENT_ID": "clientid", + "FLYTE_AZURE_TENANT_ID": "tenantid", + }, + clear=True, + ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential) @@ -38,6 +49,7 @@ def test_initialise_azure_file_provider_with_service_principal(): assert fp.get_filesystem().client_id == "clientid" assert fp.get_filesystem().tenant_id == "tenantid" + def test_initialise_azure_file_provider_with_default_credential(): with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname"}, clear=True): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") From 087021de7841a14223253c5676921ce227b396bd Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 09:38:55 +0100 Subject: [PATCH 15/24] Improve test coverage for getting fsspec storage options Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 8 +---- tests/flytekit/unit/core/test_data.py | 50 ++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 5133ecc3d1..923c6421c6 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -66,21 +66,15 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: if azure_cfg.account_name: kwargs["account_name"] = azure_cfg.account_name - if azure_cfg.account_key: kwargs["account_key"] = azure_cfg.account_key - if azure_cfg.client_id: kwargs["client_id"] = azure_cfg.client_id - if azure_cfg.client_secret: kwargs["client_secret"] = azure_cfg.client_secret - if azure_cfg.tenant_id: kwargs["tenant_id"] = azure_cfg.tenant_id - kwargs[_ANON] = anonymous - return kwargs @@ -97,7 +91,7 @@ def get_fsspec_storage_options( if anonymous: kwargs["token"] = _ANON return kwargs - if protocol == "abfs": + if protocol in ("abfs", "abfss"): return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs} return {} diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 5dd79a3618..72d7f4b015 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -8,12 +8,13 @@ import mock import pytest -from flytekit.configuration import AzureBlobStorageConfig, Config, S3Config +from flytekit.configuration import AzureBlobStorageConfig, Config, DataConfig, S3Config from flytekit.core.context_manager import FlyteContextManager from flytekit.core.data_persistence import ( FileAccessProvider, azure_setup_args, default_local_file_access_provider, + get_fsspec_storage_options, s3_setup_args, ) from flytekit.types.directory.types import FlyteDirectory @@ -227,19 +228,37 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): @mock.patch("flytekit.configuration.get_config_file") -@mock.patch("os.environ") -def test_azure_setup_env_args(mock_os, mock_get_config_file): +@mock.patch.dict(os.environ, { + "FLYTE_GCP_GSUTIL_PARALLELISM": "False", + }) +def test_get_fsspec_storage_options_gcs(mock_get_config_file): mock_get_config_file.return_value = None - ee = { + storage_options = get_fsspec_storage_options("gs", DataConfig.auto()) + assert storage_options == {} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch.dict(os.environ, { + "FLYTE_GCP_GSUTIL_PARALLELISM": "False", + }) +def test_get_fsspec_storage_options_gcs(mock_get_config_file): + mock_get_config_file.return_value = None + storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value") + assert storage_options == {"token": "anon", "other_argument": "value"} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch.dict(os.environ, { "FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_ACCOUNT_KEY": "accountkey", "FLYTE_AZURE_TENANT_ID": "tenantid", "FLYTE_AZURE_CLIENT_ID": "clientid", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", - } - mock_os.get.side_effect = lambda x, y: ee.get(x) - kwargs = azure_setup_args(AzureBlobStorageConfig.auto()) - assert kwargs == { + }) +def test_get_fsspec_storage_options_azure(mock_get_config_file): + mock_get_config_file.return_value = None + storage_options = get_fsspec_storage_options("abfs", DataConfig.auto()) + assert storage_options == { "account_name": "accountname", "account_key": "accountkey", "client_id": "clientid", @@ -248,6 +267,21 @@ def test_azure_setup_env_args(mock_os, mock_get_config_file): "anon": False, } +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch.dict(os.environ, { + "FLYTE_AZURE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_ACCOUNT_KEY": "accountkey", + }) +def test_get_fsspec_storage_options_azure_with_overrides(mock_get_config_file): + mock_get_config_file.return_value = None + storage_options = get_fsspec_storage_options("abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value") + assert storage_options == { + "account_name": "other_accountname", + "account_key": "accountkey", + "anon": True, + "other_argument": "value", + } + def test_crawl_local_nt(source_folder): """ From b5df7c16494e1500be20a7171e3805b127948f69 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 10:19:43 +0100 Subject: [PATCH 16/24] Re-name env variable names + autoformat Signed-off-by: Thomas Newton --- flytekit/configuration/__init__.py | 4 +-- flytekit/configuration/internal.py | 4 +-- tests/flytekit/unit/core/test_data.py | 42 ++++++++++++++++++--------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 5dd4ba4345..76beddabcd 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -578,8 +578,8 @@ class AzureBlobStorageConfig(object): def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: config_file = get_config_file(config_file) kwargs = {} - kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.ACCOUNT_NAME.read(config_file)) - kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.ACCOUNT_KEY.read(config_file)) + kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.STORAGE_ACCOUNT_NAME.read(config_file)) + kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.STORAGE_ACCOUNT_KEY.read(config_file)) kwargs = set_if_exists(kwargs, "tenant_id", _internal.AZURE.TENANT_ID.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file)) kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file)) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index c654acfa65..257eba4efd 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -59,8 +59,8 @@ class GCP(object): class AZURE(object): SECTION = "azure" - ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "account_name")) - ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "account_key")) + STORAGE_ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_name")) + STORAGE_ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_key")) TENANT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "tenant_id")) CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id")) CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret")) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 72d7f4b015..6d9f2a2cf6 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -8,11 +8,10 @@ import mock import pytest -from flytekit.configuration import AzureBlobStorageConfig, Config, DataConfig, S3Config +from flytekit.configuration import Config, DataConfig, S3Config from flytekit.core.context_manager import FlyteContextManager from flytekit.core.data_persistence import ( FileAccessProvider, - azure_setup_args, default_local_file_access_provider, get_fsspec_storage_options, s3_setup_args, @@ -228,9 +227,12 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict(os.environ, { +@mock.patch.dict( + os.environ, + { "FLYTE_GCP_GSUTIL_PARALLELISM": "False", - }) + }, +) def test_get_fsspec_storage_options_gcs(mock_get_config_file): mock_get_config_file.return_value = None storage_options = get_fsspec_storage_options("gs", DataConfig.auto()) @@ -238,23 +240,29 @@ def test_get_fsspec_storage_options_gcs(mock_get_config_file): @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict(os.environ, { +@mock.patch.dict( + os.environ, + { "FLYTE_GCP_GSUTIL_PARALLELISM": "False", - }) -def test_get_fsspec_storage_options_gcs(mock_get_config_file): + }, +) +def test_get_fsspec_storage_options_gcs_with_overrides(mock_get_config_file): mock_get_config_file.return_value = None storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value") assert storage_options == {"token": "anon", "other_argument": "value"} @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict(os.environ, { - "FLYTE_AZURE_ACCOUNT_NAME": "accountname", - "FLYTE_AZURE_ACCOUNT_KEY": "accountkey", +@mock.patch.dict( + os.environ, + { + "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", "FLYTE_AZURE_TENANT_ID": "tenantid", "FLYTE_AZURE_CLIENT_ID": "clientid", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", - }) + }, +) def test_get_fsspec_storage_options_azure(mock_get_config_file): mock_get_config_file.return_value = None storage_options = get_fsspec_storage_options("abfs", DataConfig.auto()) @@ -267,14 +275,20 @@ def test_get_fsspec_storage_options_azure(mock_get_config_file): "anon": False, } + @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict(os.environ, { +@mock.patch.dict( + os.environ, + { "FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_ACCOUNT_KEY": "accountkey", - }) + }, +) def test_get_fsspec_storage_options_azure_with_overrides(mock_get_config_file): mock_get_config_file.return_value = None - storage_options = get_fsspec_storage_options("abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value") + storage_options = get_fsspec_storage_options( + "abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value" + ) assert storage_options == { "account_name": "other_accountname", "account_key": "accountkey", From c3ee81f8340b726094d642dc5255e239ef574d03 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 10:30:21 +0100 Subject: [PATCH 17/24] Update tests for new env var names Signed-off-by: Thomas Newton --- tests/flytekit/unit/core/test_data.py | 4 ++-- tests/flytekit/unit/core/test_data_persistence.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 6d9f2a2cf6..d686fae686 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -280,8 +280,8 @@ def test_get_fsspec_storage_options_azure(mock_get_config_file): @mock.patch.dict( os.environ, { - "FLYTE_AZURE_ACCOUNT_NAME": "accountname", - "FLYTE_AZURE_ACCOUNT_KEY": "accountkey", + "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", }, ) def test_get_fsspec_storage_options_azure_with_overrides(mock_get_config_file): diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 0bfa3cb9d9..af2b2bf6f8 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -23,7 +23,9 @@ def test_is_remote(): def test_initialise_azure_file_provider_with_account_key(): with patch.dict( - os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_ACCOUNT_KEY": "accountkey"}, clear=True + os.environ, + {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"}, + clear=True, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" @@ -35,7 +37,7 @@ def test_initialise_azure_file_provider_with_service_principal(): with patch.dict( os.environ, { - "FLYTE_AZURE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", "FLYTE_AZURE_CLIENT_ID": "clientid", "FLYTE_AZURE_TENANT_ID": "tenantid", @@ -51,7 +53,7 @@ def test_initialise_azure_file_provider_with_service_principal(): def test_initialise_azure_file_provider_with_default_credential(): - with patch.dict(os.environ, {"FLYTE_AZURE_ACCOUNT_NAME": "accountname"}, clear=True): + with patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}, clear=True): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) From b689b3cd5159244b60ec3a82e454bbb775bf85c6 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 14:55:17 +0100 Subject: [PATCH 18/24] Better mocking of os.environ Signed-off-by: Thomas Newton --- tests/flytekit/unit/core/test_data.py | 52 +++++++++---------- .../unit/core/test_data_persistence.py | 10 ++-- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index d686fae686..667445321b 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -227,44 +227,41 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict( - os.environ, - { - "FLYTE_GCP_GSUTIL_PARALLELISM": "False", - }, -) -def test_get_fsspec_storage_options_gcs(mock_get_config_file): +@mock.patch("os.environ") +def test_get_fsspec_storage_options_gcs(mock_os, mock_get_config_file): mock_get_config_file.return_value = None + ee = { + "FLYTE_GCP_GSUTIL_PARALLELISM": "False", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) storage_options = get_fsspec_storage_options("gs", DataConfig.auto()) assert storage_options == {} @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict( - os.environ, - { - "FLYTE_GCP_GSUTIL_PARALLELISM": "False", - }, -) -def test_get_fsspec_storage_options_gcs_with_overrides(mock_get_config_file): +@mock.patch("os.environ") +def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_file): mock_get_config_file.return_value = None + ee = { + "FLYTE_GCP_GSUTIL_PARALLELISM": "False", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value") assert storage_options == {"token": "anon", "other_argument": "value"} @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict( - os.environ, - { +@mock.patch("os.environ") +def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", "FLYTE_AZURE_TENANT_ID": "tenantid", "FLYTE_AZURE_CLIENT_ID": "clientid", "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", - }, -) -def test_get_fsspec_storage_options_azure(mock_get_config_file): - mock_get_config_file.return_value = None + } + mock_os.get.side_effect = lambda x, y: ee.get(x) storage_options = get_fsspec_storage_options("abfs", DataConfig.auto()) assert storage_options == { "account_name": "accountname", @@ -277,15 +274,14 @@ def test_get_fsspec_storage_options_azure(mock_get_config_file): @mock.patch("flytekit.configuration.get_config_file") -@mock.patch.dict( - os.environ, - { +@mock.patch("os.environ") +def test_get_fsspec_storage_options_azure_with_overrides(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", - }, -) -def test_get_fsspec_storage_options_azure_with_overrides(mock_get_config_file): - mock_get_config_file.return_value = None + } + mock_os.get.side_effect = lambda x, y: ee.get(x) storage_options = get_fsspec_storage_options( "abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value" ) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index af2b2bf6f8..2fc8b6c452 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,7 +1,7 @@ import os +import mock from azure.identity import ClientSecretCredential, DefaultAzureCredential -from mock import patch from flytekit.core.data_persistence import FileAccessProvider @@ -22,10 +22,9 @@ def test_is_remote(): def test_initialise_azure_file_provider_with_account_key(): - with patch.dict( + with mock.patch.dict( os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"}, - clear=True, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" @@ -34,7 +33,7 @@ def test_initialise_azure_file_provider_with_account_key(): def test_initialise_azure_file_provider_with_service_principal(): - with patch.dict( + with mock.patch.dict( os.environ, { "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", @@ -42,7 +41,6 @@ def test_initialise_azure_file_provider_with_service_principal(): "FLYTE_AZURE_CLIENT_ID": "clientid", "FLYTE_AZURE_TENANT_ID": "tenantid", }, - clear=True, ): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" @@ -53,7 +51,7 @@ def test_initialise_azure_file_provider_with_service_principal(): def test_initialise_azure_file_provider_with_default_credential(): - with patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}, clear=True): + with mock.patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) From a279b47ab2bcd30f88bf9f09b5250a6d1741ec9b Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 18:03:49 +0100 Subject: [PATCH 19/24] Mostly working test for structured dataset filesystems use Signed-off-by: Thomas Newton --- .../unit/core/test_structured_dataset_handlers.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index 4b9d183ad8..508a7d01e0 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -49,6 +49,21 @@ def test_csv(): df2 = decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) assert df.equals(df2) +@pytest.mark.parametrize("format,encoder", [("parquet", basic_dfs.PandasToParquetEncodingHandler()), ("csv", basic_dfs.PandasToCSVEncodingHandler())]) +@mock.patch("pyarrow.parquet.write_table") +@mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") +def test_pandas_to_azure_initialises_filesystem_without_error(mock_get_fsspec_storage_options, mock_write_table, format, encoder): + mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"} + df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + ctx = context_manager.FlyteContextManager.current_context() + sd = StructuredDataset(dataframe=df, uri="abfs://container/path/within/container") + sd_type = StructuredDatasetType(format=format) + encoder.encode(ctx, sd, sd_type) + mock_write_table.assert_called_once() + filesystem = mock_write_table.mock_calls[0].kwargs["filesystem"] + assert filesystem.account_name == "accountname_from_fsspec_storage_options" + def test_base_isnt_instantiable(): with pytest.raises(TypeError): From 80f34d857d9bcbb3753de88d840dc5e05701c371 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 19:40:22 +0100 Subject: [PATCH 20/24] Working test for structured dataset azure encode decode Signed-off-by: Thomas Newton --- .../core/test_structured_dataset_handlers.py | 63 ++++++++++++++++--- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index 508a7d01e0..bf6dad0b5d 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -1,5 +1,6 @@ import typing +import mock import pandas as pd import pyarrow as pa import pytest @@ -49,20 +50,64 @@ def test_csv(): df2 = decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) assert df.equals(df2) -@pytest.mark.parametrize("format,encoder", [("parquet", basic_dfs.PandasToParquetEncodingHandler()), ("csv", basic_dfs.PandasToCSVEncodingHandler())]) -@mock.patch("pyarrow.parquet.write_table") + +@mock.patch("pandas.DataFrame.to_parquet") +@mock.patch("pandas.read_parquet") @mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") -def test_pandas_to_azure_initialises_filesystem_without_error(mock_get_fsspec_storage_options, mock_write_table, format, encoder): +def test_pandas_to_parquet_correct_storage_options_for_azure( + mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet +): + df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + encoder = basic_dfs.PandasToParquetEncodingHandler() + decoder = basic_dfs.ParquetToPandasDecodingHandler() + mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"} + ctx = context_manager.FlyteContextManager.current_context() + sd = StructuredDataset(dataframe=df, uri="abfs://container/parquet_df") + sd_type = StructuredDatasetType(format="parquet") + sd_lit = encoder.encode(ctx, sd, sd_type) + mock_to_parquet.assert_called_once_with( + "abfs://container/parquet_df/00000", + coerce_timestamps=mock.ANY, + allow_truncated_timestamps=mock.ANY, + storage_options={"account_name": "accountname_from_storage_options"}, + ) + + decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) + mock_read_parquet.assert_called_once_with( + "abfs://container/parquet_df", + columns=mock.ANY, + storage_options={"account_name": "accountname_from_storage_options"}, + ) + + +@mock.patch("pandas.DataFrame.to_csv") +@mock.patch("pandas.read_csv") +@mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") +def test_pandas_to_csv_correct_storage_options_for_azure( + mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet +): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + encoder = basic_dfs.PandasToCSVEncodingHandler() + decoder = basic_dfs.CSVToPandasDecodingHandler() + mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"} ctx = context_manager.FlyteContextManager.current_context() - sd = StructuredDataset(dataframe=df, uri="abfs://container/path/within/container") - sd_type = StructuredDatasetType(format=format) - encoder.encode(ctx, sd, sd_type) - mock_write_table.assert_called_once() - filesystem = mock_write_table.mock_calls[0].kwargs["filesystem"] - assert filesystem.account_name == "accountname_from_fsspec_storage_options" + sd = StructuredDataset(dataframe=df, uri="abfs://container/csv_df") + sd_type = StructuredDatasetType(format="csv") + sd_lit = encoder.encode(ctx, sd, sd_type) + mock_to_parquet.assert_called_once_with( + "abfs://container/csv_df/.csv", + index=mock.ANY, + storage_options={"account_name": "accountname_from_storage_options"}, + ) + + decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) + mock_read_parquet.assert_called_once_with( + "abfs://container/csv_df/.csv", + usecols=mock.ANY, + storage_options={"account_name": "accountname_from_storage_options"}, + ) def test_base_isnt_instantiable(): From 71d2486f4d752be9a7b8cd2762615dd82756613f Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 19:45:39 +0100 Subject: [PATCH 21/24] Fix get_pandas_storage_options Signed-off-by: Thomas Newton --- flytekit/types/structured/basic_dfs.py | 7 +++++-- .../flytekit/unit/core/test_structured_dataset_handlers.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index c5592a02bb..b037817474 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -30,9 +30,12 @@ def get_pandas_storage_options( uri: str, data_config: DataConfig, anonymous: bool = False ) -> typing.Optional[typing.Dict]: - if pd.io.common.is_url(uri): + print(f"check is uri, uri = {uri}") + print(get_fsspec_storage_options) + if pd.io.common.is_fsspec_url(uri): return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous) - # Pandas does not allow storage_options for non-url paths. + + # Pandas does not allow storage_options for on-fsspec paths e.g. local. return None diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index bf6dad0b5d..6c0be791d8 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -54,7 +54,7 @@ def test_csv(): @mock.patch("pandas.DataFrame.to_parquet") @mock.patch("pandas.read_parquet") @mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") -def test_pandas_to_parquet_correct_storage_options_for_azure( +def test_pandas_to_parquet_azure_storage_options( mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet ): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) @@ -84,7 +84,7 @@ def test_pandas_to_parquet_correct_storage_options_for_azure( @mock.patch("pandas.DataFrame.to_csv") @mock.patch("pandas.read_csv") @mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") -def test_pandas_to_csv_correct_storage_options_for_azure( +def test_pandas_to_csv_azure_storage_options( mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet ): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) From 5dd00ed915bd28b79a1cf62a13c66c14cadf8902 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 25 Sep 2023 22:54:47 +0100 Subject: [PATCH 22/24] Tidy Signed-off-by: Thomas Newton --- flytekit/types/structured/basic_dfs.py | 4 +--- .../unit/core/test_structured_dataset_handlers.py | 8 ++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index b037817474..2161c5b58a 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -30,12 +30,10 @@ def get_pandas_storage_options( uri: str, data_config: DataConfig, anonymous: bool = False ) -> typing.Optional[typing.Dict]: - print(f"check is uri, uri = {uri}") - print(get_fsspec_storage_options) if pd.io.common.is_fsspec_url(uri): return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous) - # Pandas does not allow storage_options for on-fsspec paths e.g. local. + # Pandas does not allow storage_options for non-fsspec paths e.g. local. return None diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index 6c0be791d8..1083bfa4b9 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -54,9 +54,7 @@ def test_csv(): @mock.patch("pandas.DataFrame.to_parquet") @mock.patch("pandas.read_parquet") @mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") -def test_pandas_to_parquet_azure_storage_options( - mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet -): +def test_pandas_to_parquet_azure_storage_options(mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) encoder = basic_dfs.PandasToParquetEncodingHandler() decoder = basic_dfs.ParquetToPandasDecodingHandler() @@ -84,9 +82,7 @@ def test_pandas_to_parquet_azure_storage_options( @mock.patch("pandas.DataFrame.to_csv") @mock.patch("pandas.read_csv") @mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") -def test_pandas_to_csv_azure_storage_options( - mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet -): +def test_pandas_to_csv_azure_storage_options(mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) encoder = basic_dfs.PandasToCSVEncodingHandler() decoder = basic_dfs.CSVToPandasDecodingHandler() From 572725bad5354aefaaf35481d2c39f1c8b66f285 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Tue, 26 Sep 2023 23:00:10 +0100 Subject: [PATCH 23/24] Only assert on `storage_options` in structured dataset handler tests Signed-off-by: Thomas Newton --- .../core/test_structured_dataset_handlers.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index 1083bfa4b9..b26349ceeb 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -64,19 +64,14 @@ def test_pandas_to_parquet_azure_storage_options(mock_get_fsspec_storage_options sd = StructuredDataset(dataframe=df, uri="abfs://container/parquet_df") sd_type = StructuredDatasetType(format="parquet") sd_lit = encoder.encode(ctx, sd, sd_type) - mock_to_parquet.assert_called_once_with( - "abfs://container/parquet_df/00000", - coerce_timestamps=mock.ANY, - allow_truncated_timestamps=mock.ANY, - storage_options={"account_name": "accountname_from_storage_options"}, - ) + mock_to_parquet.assert_called_once() + write_storage_options = mock_to_parquet.call_args.kwargs["storage_options"] + assert write_storage_options == {"account_name": "accountname_from_storage_options"} decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) - mock_read_parquet.assert_called_once_with( - "abfs://container/parquet_df", - columns=mock.ANY, - storage_options={"account_name": "accountname_from_storage_options"}, - ) + mock_read_parquet.assert_called_once() + read_storage_options = mock_read_parquet.call_args.kwargs["storage_options"] + read_storage_options == {"account_name": "accountname_from_storage_options"} @mock.patch("pandas.DataFrame.to_csv") @@ -92,18 +87,14 @@ def test_pandas_to_csv_azure_storage_options(mock_get_fsspec_storage_options, mo sd = StructuredDataset(dataframe=df, uri="abfs://container/csv_df") sd_type = StructuredDatasetType(format="csv") sd_lit = encoder.encode(ctx, sd, sd_type) - mock_to_parquet.assert_called_once_with( - "abfs://container/csv_df/.csv", - index=mock.ANY, - storage_options={"account_name": "accountname_from_storage_options"}, - ) + mock_to_parquet.assert_called_once() + write_storage_options = mock_to_parquet.call_args.kwargs["storage_options"] + assert write_storage_options == {"account_name": "accountname_from_storage_options"} decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) - mock_read_parquet.assert_called_once_with( - "abfs://container/csv_df/.csv", - usecols=mock.ANY, - storage_options={"account_name": "accountname_from_storage_options"}, - ) + mock_read_parquet.assert_called_once() + read_storage_options = mock_read_parquet.call_args.kwargs["storage_options"] + read_storage_options == {"account_name": "accountname_from_storage_options"} def test_base_isnt_instantiable(): From 72a440eedab0a7a0c2cf92edee0ba3db9c763d49 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Wed, 27 Sep 2023 23:04:44 +0100 Subject: [PATCH 24/24] Type hints Signed-off-by: Thomas Newton --- flytekit/core/data_persistence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 923c6421c6..f7e04d6403 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -41,7 +41,7 @@ _ANON = "anon" -def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): +def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]: kwargs: Dict[str, Any] = { "cache_regions": True, } @@ -61,7 +61,7 @@ def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): return kwargs -def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False): +def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]: kwargs: Dict[str, Any] = {} if azure_cfg.account_name: @@ -80,7 +80,7 @@ def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: def get_fsspec_storage_options( protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs -) -> typing.Dict: +) -> Dict[str, Any]: data_config = data_config or DataConfig.auto() if protocol == "file":