diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py
index 0058bdc551..f842f1451b 100644
--- a/flytekit/configuration/__init__.py
+++ b/flytekit/configuration/__init__.py
@@ -565,6 +565,30 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
         return GCSConfig(**kwargs)
 
 
+@dataclass(init=True, repr=True, eq=True, frozen=True)
+class AzureBlobStorageConfig(object):
+    """
+    Any Azure Blob Storage specific configuration.
+    """
+
+    account_name: typing.Optional[str] = None
+    account_key: typing.Optional[str] = None
+    tenant_id: typing.Optional[str] = None
+    client_id: typing.Optional[str] = None
+    client_secret: typing.Optional[str] = None
+
+    @classmethod
+    def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> AzureBlobStorageConfig:
+        config_file = get_config_file(config_file)
+        kwargs = {}
+        kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.STORAGE_ACCOUNT_NAME.read(config_file))
+        kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.STORAGE_ACCOUNT_KEY.read(config_file))
+        kwargs = set_if_exists(kwargs, "tenant_id", _internal.AZURE.TENANT_ID.read(config_file))
+        kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file))
+        kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file))
+        return AzureBlobStorageConfig(**kwargs)
+
+
 @dataclass(init=True, repr=True, eq=True, frozen=True)
 class DataConfig(object):
     """
@@ -575,11 +599,13 @@ class DataConfig(object):
 
     s3: S3Config = S3Config()
     gcs: GCSConfig = GCSConfig()
+    azure: AzureBlobStorageConfig = AzureBlobStorageConfig()
 
     @classmethod
     def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig:
         config_file = get_config_file(config_file)
         return DataConfig(
+            azure=AzureBlobStorageConfig.auto(config_file),
             s3=S3Config.auto(config_file),
             gcs=GCSConfig.auto(config_file),
         )
diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py
index 9d1980c450..b12103a3fd 100644
--- a/flytekit/configuration/internal.py
+++ b/flytekit/configuration/internal.py
@@ -57,6 +57,15 @@ class GCP(object):
     GSUTIL_PARALLELISM = ConfigEntry(LegacyConfigEntry(SECTION, "gsutil_parallelism", bool))
 
 
+class AZURE(object):
+    SECTION = "azure"
+    STORAGE_ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_name"))
+    STORAGE_ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_key"))
+    TENANT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "tenant_id"))
+    CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id"))
+    CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret"))
+
+
 class Credentials(object):
     SECTION = "credentials"
     COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command", list), YamlConfigEntry("admin.command", list))
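For illustration: each `[azure]` entry above follows flytekit's legacy-config convention, so it can also be supplied through a `FLYTE_AZURE_*` environment variable, which is the mechanism the unit tests further down exercise. A minimal sketch, assuming a flytekit build that includes this change:

```python
# Sketch: AzureBlobStorageConfig.auto() reads the new [azure] section, with each
# key overridable via a FLYTE_AZURE_* environment variable.
import os

os.environ["FLYTE_AZURE_STORAGE_ACCOUNT_NAME"] = "accountname"
os.environ["FLYTE_AZURE_STORAGE_ACCOUNT_KEY"] = "accountkey"

from flytekit.configuration import AzureBlobStorageConfig

azure_cfg = AzureBlobStorageConfig.auto()
assert azure_cfg.account_name == "accountname"
assert azure_cfg.account_key == "accountkey"
assert azure_cfg.tenant_id is None  # entries that are unset stay None
```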
diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py
index 57311ed415..f7e04d6403 100644
--- a/flytekit/core/data_persistence.py
+++ b/flytekit/core/data_persistence.py
@@ -41,7 +41,7 @@
 _ANON = "anon"
 
 
-def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False):
+def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
     kwargs: Dict[str, Any] = {
         "cache_regions": True,
     }
@@ -61,6 +61,41 @@ def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False):
     return kwargs
 
 
+def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]:
+    kwargs: Dict[str, Any] = {}
+
+    if azure_cfg.account_name:
+        kwargs["account_name"] = azure_cfg.account_name
+    if azure_cfg.account_key:
+        kwargs["account_key"] = azure_cfg.account_key
+    if azure_cfg.client_id:
+        kwargs["client_id"] = azure_cfg.client_id
+    if azure_cfg.client_secret:
+        kwargs["client_secret"] = azure_cfg.client_secret
+    if azure_cfg.tenant_id:
+        kwargs["tenant_id"] = azure_cfg.tenant_id
+    kwargs[_ANON] = anonymous
+    return kwargs
+
+
+def get_fsspec_storage_options(
+    protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs
+) -> Dict[str, Any]:
+    data_config = data_config or DataConfig.auto()
+
+    if protocol == "file":
+        return {"auto_mkdir": True, **kwargs}
+    if protocol == "s3":
+        return {**s3_setup_args(data_config.s3, anonymous=anonymous), **kwargs}
+    if protocol == "gs":
+        if anonymous:
+            kwargs["token"] = _ANON
+        return kwargs
+    if protocol in ("abfs", "abfss"):
+        return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs}
+    return {}
+
+
 class FileAccessProvider(object):
     """
     This is the class that is available through the FlyteContext and can be used for persisting data to the remote
@@ -106,28 +141,15 @@ def data_config(self) -> DataConfig:
     def get_filesystem(
         self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs
-    ) -> typing.Optional[fsspec.AbstractFileSystem]:
+    ) -> fsspec.AbstractFileSystem:
         if not protocol:
             return self._default_remote
-        if protocol == "file":
-            kwargs["auto_mkdir"] = True
-        elif protocol == "s3":
-            s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous)
-            s3kwargs.update(kwargs)
-            return fsspec.filesystem(protocol, **s3kwargs)  # type: ignore
-        elif protocol == "gs":
-            if anonymous:
-                kwargs["token"] = _ANON
-            return fsspec.filesystem(protocol, **kwargs)  # type: ignore
-        elif protocol == "abfs":
-            kwargs["anon"] = False
-            return fsspec.filesystem(protocol, **kwargs)  # type: ignore
-
-        # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option.
-        if anonymous:
-            return None
-        return fsspec.filesystem(protocol, **kwargs)  # type: ignore
+
+        storage_options = get_fsspec_storage_options(
+            protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
+        )
+
+        return fsspec.filesystem(protocol, **storage_options)
 
     def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
         protocol = get_protocol(path)
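The consolidated helper is what `get_filesystem` now delegates to for every protocol. A usage sketch (assuming adlfs is installed for the abfs protocol and the `FLYTE_AZURE_*` variables from above are set):

```python
# Sketch: per-protocol storage options are resolved in one place; the result is
# exactly what fsspec.filesystem(protocol, **opts) expects.
from flytekit.configuration import DataConfig
from flytekit.core.data_persistence import get_fsspec_storage_options

opts = get_fsspec_storage_options("abfs", DataConfig.auto())
# With the FLYTE_AZURE_* variables set, opts carries account_name, account_key,
# client/tenant ids and "anon": False; unhandled protocols fall through to {}.
```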
diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py
index 98a12ae44d..2161c5b58a 100644
--- a/flytekit/types/structured/basic_dfs.py
+++ b/flytekit/types/structured/basic_dfs.py
@@ -12,7 +12,7 @@
 
 from flytekit import FlyteContext, logger
 from flytekit.configuration import DataConfig
-from flytekit.core.data_persistence import s3_setup_args
+from flytekit.core.data_persistence import get_fsspec_storage_options
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
@@ -27,12 +27,13 @@
 
 T = TypeVar("T")
 
 
-def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]:
-    protocol = get_protocol(uri)
-    if protocol == "s3":
-        kwargs = s3_setup_args(cfg.s3, anon)
-        if kwargs:
-            return kwargs
+def get_pandas_storage_options(
+    uri: str, data_config: DataConfig, anonymous: bool = False
+) -> typing.Optional[typing.Dict]:
+    if pd.io.common.is_fsspec_url(uri):
+        return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous)
+
+    # Pandas does not allow storage_options for non-fsspec paths, e.g. local paths.
     return None
 
@@ -54,7 +55,7 @@ def encode(
         df.to_csv(
             path,
             index=False,
-            storage_options=get_storage_options(ctx.file_access.data_config, path),
+            storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config),
         )
         structured_dataset_type.format = CSV
         return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
@@ -72,7 +73,7 @@ def decode(
     ) -> pd.DataFrame:
         uri = flyte_value.uri
         columns = None
-        kwargs = get_storage_options(ctx.file_access.data_config, uri)
+        kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
         path = os.path.join(uri, ".csv")
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
@@ -80,7 +81,7 @@ def decode(
         try:
             return pd.read_csv(path, usecols=columns, storage_options=kwargs)
         except NoCredentialsError:
             logger.debug("S3 source detected, attempting anonymous S3 access")
-            kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
+            kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True)
             return pd.read_csv(path, usecols=columns, storage_options=kwargs)
@@ -103,7 +104,7 @@ def encode(
             path,
             coerce_timestamps="us",
             allow_truncated_timestamps=False,
-            storage_options=get_storage_options(ctx.file_access.data_config, path),
+            storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config),
         )
         structured_dataset_type.format = PARQUET
         return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
@@ -121,14 +122,14 @@ def decode(
     ) -> pd.DataFrame:
         uri = flyte_value.uri
         columns = None
-        kwargs = get_storage_options(ctx.file_access.data_config, uri)
+        kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
         try:
             return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
         except NoCredentialsError:
             logger.debug("S3 source detected, attempting anonymous S3 access")
-            kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
+            kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True)
             return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
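A behavioural note on the renamed helper: pandas accepts `storage_options` only for fsspec-style URLs, so local paths now short-circuit to `None` instead of receiving an options dict. A quick sketch:

```python
# Sketch: local paths yield None (pandas rejects storage_options for them),
# while fsspec URLs are routed through get_fsspec_storage_options.
from flytekit.configuration import DataConfig
from flytekit.types.structured.basic_dfs import get_pandas_storage_options

cfg = DataConfig.auto()
assert get_pandas_storage_options(uri="/tmp/df.parquet", data_config=cfg) is None
print(get_pandas_storage_options(uri="abfs://container/df.parquet", data_config=cfg))
```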
diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index 4290c88ae4..ea644dc078 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -2,12 +2,13 @@
 
 import pandas as pd
 import polars as pl
+from fsspec.utils import get_protocol
 
 from flytekit import FlyteContext
+from flytekit.core.data_persistence import get_fsspec_storage_options
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
-from flytekit.types.structured.basic_dfs import get_storage_options
 from flytekit.types.structured.structured_dataset import (
     PARQUET,
     StructuredDataset,
@@ -64,7 +65,7 @@ def decode(
         current_task_metadata: StructuredDatasetMetadata,
     ) -> pl.DataFrame:
         uri = flyte_value.uri
-        kwargs = get_storage_options(ctx.file_access.data_config, uri)
+        kwargs = get_fsspec_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config)
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
             return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs)
diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py
index 2d61b58d8c..667445321b 100644
--- a/tests/flytekit/unit/core/test_data.py
+++ b/tests/flytekit/unit/core/test_data.py
@@ -8,9 +8,14 @@
 import mock
 import pytest
 
-from flytekit.configuration import Config, S3Config
+from flytekit.configuration import Config, DataConfig, S3Config
 from flytekit.core.context_manager import FlyteContextManager
-from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args
+from flytekit.core.data_persistence import (
+    FileAccessProvider,
+    default_local_file_access_provider,
+    get_fsspec_storage_options,
+    s3_setup_args,
+)
 from flytekit.types.directory.types import FlyteDirectory
 
 local = fsspec.filesystem("file")
@@ -221,6 +226,73 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file):
     assert kwargs == {"cache_regions": True}
 
 
+@mock.patch("flytekit.configuration.get_config_file")
+@mock.patch("os.environ")
+def test_get_fsspec_storage_options_gcs(mock_os, mock_get_config_file):
+    mock_get_config_file.return_value = None
+    ee = {
+        "FLYTE_GCP_GSUTIL_PARALLELISM": "False",
+    }
+    mock_os.get.side_effect = lambda x, y: ee.get(x)
+    storage_options = get_fsspec_storage_options("gs", DataConfig.auto())
+    assert storage_options == {}
+
+
+@mock.patch("flytekit.configuration.get_config_file")
+@mock.patch("os.environ")
+def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_file):
+    mock_get_config_file.return_value = None
+    ee = {
+        "FLYTE_GCP_GSUTIL_PARALLELISM": "False",
+    }
+    mock_os.get.side_effect = lambda x, y: ee.get(x)
+    storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value")
+    assert storage_options == {"token": "anon", "other_argument": "value"}
+
+
+@mock.patch("flytekit.configuration.get_config_file")
+@mock.patch("os.environ")
+def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file):
+    mock_get_config_file.return_value = None
+    ee = {
+        "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
+        "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
+        "FLYTE_AZURE_TENANT_ID": "tenantid",
+        "FLYTE_AZURE_CLIENT_ID": "clientid",
+        "FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
+    }
+    mock_os.get.side_effect = lambda x, y: ee.get(x)
+    storage_options = get_fsspec_storage_options("abfs", DataConfig.auto())
+    assert storage_options == {
+        "account_name": "accountname",
+        "account_key": "accountkey",
+        "client_id": "clientid",
+        "client_secret": "clientsecret",
+        "tenant_id": "tenantid",
+        "anon": False,
+    }
+
+
+@mock.patch("flytekit.configuration.get_config_file")
+@mock.patch("os.environ")
+def test_get_fsspec_storage_options_azure_with_overrides(mock_os, mock_get_config_file):
+    mock_get_config_file.return_value = None
+    ee = {
+        "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
+        "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
+    }
+    mock_os.get.side_effect = lambda x, y: ee.get(x)
+    storage_options = get_fsspec_storage_options(
+        "abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value"
+    )
+    assert storage_options == {
+        "account_name": "other_accountname",
+        "account_key": "accountkey",
+        "anon": True,
+        "other_argument": "value",
+    }
+
+
 def test_crawl_local_nt(source_folder):
     """
     running this to see what it prints
diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py
index 27b407c1ce..2fc8b6c452 100644
--- a/tests/flytekit/unit/core/test_data_persistence.py
+++ b/tests/flytekit/unit/core/test_data_persistence.py
@@ -1,3 +1,8 @@
+import os
+
+import mock
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+
 from flytekit.core.data_persistence import FileAccessProvider
 
 
@@ -14,3 +19,39 @@ def test_is_remote():
     assert fp.is_remote("/tmp/foo/bar") is False
     assert fp.is_remote("file://foo/bar") is False
     assert fp.is_remote("s3://my-bucket/foo/bar") is True
+
+
+def test_initialise_azure_file_provider_with_account_key():
+    with mock.patch.dict(
+        os.environ,
+        {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"},
+    ):
+        fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
+        assert fp.get_filesystem().account_name == "accountname"
+        assert fp.get_filesystem().account_key == "accountkey"
+        assert fp.get_filesystem().sync_credential is None
+
+
+def test_initialise_azure_file_provider_with_service_principal():
+    with mock.patch.dict(
+        os.environ,
+        {
+            "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
+            "FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
+            "FLYTE_AZURE_CLIENT_ID": "clientid",
+            "FLYTE_AZURE_TENANT_ID": "tenantid",
+        },
+    ):
+        fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
+        assert fp.get_filesystem().account_name == "accountname"
+        assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential)
+        assert fp.get_filesystem().client_secret == "clientsecret"
+        assert fp.get_filesystem().client_id == "clientid"
+        assert fp.get_filesystem().tenant_id == "tenantid"
+
+
+def test_initialise_azure_file_provider_with_default_credential():
+    with mock.patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}):
+        fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
+        assert fp.get_filesystem().account_name == "accountname"
+        assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential)
diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py
index 4b9d183ad8..b26349ceeb 100644
--- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py
+++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py
@@ -1,5 +1,6 @@
 import typing
 
+import mock
 import pandas as pd
 import pyarrow as pa
 import pytest
@@ -50,6 +51,52 @@ def test_csv():
     assert df.equals(df2)
 
 
+@mock.patch("pandas.DataFrame.to_parquet")
+@mock.patch("pandas.read_parquet")
+@mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options")
+def test_pandas_to_parquet_azure_storage_options(mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet):
+    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+    encoder = basic_dfs.PandasToParquetEncodingHandler()
+    decoder = basic_dfs.ParquetToPandasDecodingHandler()
+
+    mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"}
+    ctx = context_manager.FlyteContextManager.current_context()
+    sd = StructuredDataset(dataframe=df, uri="abfs://container/parquet_df")
+    sd_type = StructuredDatasetType(format="parquet")
+    sd_lit = encoder.encode(ctx, sd, sd_type)
+    mock_to_parquet.assert_called_once()
+    write_storage_options = mock_to_parquet.call_args.kwargs["storage_options"]
+    assert write_storage_options == {"account_name": "accountname_from_storage_options"}
+
+    decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type))
+    mock_read_parquet.assert_called_once()
+    read_storage_options = mock_read_parquet.call_args.kwargs["storage_options"]
+    assert read_storage_options == {"account_name": "accountname_from_storage_options"}
+
+
+@mock.patch("pandas.DataFrame.to_csv")
+@mock.patch("pandas.read_csv")
+@mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options")
+def test_pandas_to_csv_azure_storage_options(mock_get_fsspec_storage_options, mock_read_csv, mock_to_csv):
+    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+    encoder = basic_dfs.PandasToCSVEncodingHandler()
+    decoder = basic_dfs.CSVToPandasDecodingHandler()
+
+    mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"}
+    ctx = context_manager.FlyteContextManager.current_context()
+    sd = StructuredDataset(dataframe=df, uri="abfs://container/csv_df")
+    sd_type = StructuredDatasetType(format="csv")
+    sd_lit = encoder.encode(ctx, sd, sd_type)
+    mock_to_csv.assert_called_once()
+    write_storage_options = mock_to_csv.call_args.kwargs["storage_options"]
+    assert write_storage_options == {"account_name": "accountname_from_storage_options"}
+
+    decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type))
+    mock_read_csv.assert_called_once()
+    read_storage_options = mock_read_csv.call_args.kwargs["storage_options"]
+    assert read_storage_options == {"account_name": "accountname_from_storage_options"}
+
+
 def test_base_isnt_instantiable():
     with pytest.raises(TypeError):
         StructuredDatasetEncoder(pd.DataFrame, "", "")
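Taken together, a sketch of the end-to-end behaviour these tests pin down (the URI and credentials are placeholders, and adlfs must be installed):

```python
# Sketch: with account credentials in the environment, FileAccessProvider now
# returns an authenticated adlfs filesystem instead of an anonymous one.
import os

os.environ["FLYTE_AZURE_STORAGE_ACCOUNT_NAME"] = "accountname"
os.environ["FLYTE_AZURE_STORAGE_ACCOUNT_KEY"] = "accountkey"

from flytekit.core.data_persistence import FileAccessProvider

fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
fs = fp.get_filesystem("abfs")  # adlfs.AzureBlobFileSystem under the hood
assert fs.account_name == "accountname"
```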