Skip to content
This repository has been archived by the owner on Jul 19, 2024. It is now read-only.

Commit

Permalink
Better Azure blob storage support (flyteorg#1842)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Newton <[email protected]>
Signed-off-by: Future Outlier <[email protected]>
  • Loading branch information
Tom-Newton authored and Future Outlier committed Oct 3, 2023
1 parent dce7e23 commit 483082a
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 37 deletions.
26 changes: 26 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,30 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
return GCSConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class AzureBlobStorageConfig(object):
"""
Any Azure Blob Storage specific configuration.
"""

account_name: typing.Optional[str] = None
account_key: typing.Optional[str] = None
tenant_id: typing.Optional[str] = None
client_id: typing.Optional[str] = None
client_secret: typing.Optional[str] = None

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.STORAGE_ACCOUNT_NAME.read(config_file))
kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.STORAGE_ACCOUNT_KEY.read(config_file))
kwargs = set_if_exists(kwargs, "tenant_id", _internal.AZURE.TENANT_ID.read(config_file))
kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file))
kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file))
return AzureBlobStorageConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class DataConfig(object):
"""
Expand All @@ -575,11 +599,13 @@ class DataConfig(object):

s3: S3Config = S3Config()
gcs: GCSConfig = GCSConfig()
azure: AzureBlobStorageConfig = AzureBlobStorageConfig()

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig:
config_file = get_config_file(config_file)
return DataConfig(
azure=AzureBlobStorageConfig.auto(config_file),
s3=S3Config.auto(config_file),
gcs=GCSConfig.auto(config_file),
)
Expand Down
9 changes: 9 additions & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ class GCP(object):
GSUTIL_PARALLELISM = ConfigEntry(LegacyConfigEntry(SECTION, "gsutil_parallelism", bool))


class AZURE(object):
SECTION = "azure"
STORAGE_ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_name"))
STORAGE_ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_key"))
TENANT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "tenant_id"))
CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id"))
CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret"))


class Credentials(object):
SECTION = "credentials"
COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command", list), YamlConfigEntry("admin.command", list))
Expand Down
62 changes: 42 additions & 20 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
_ANON = "anon"


def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False):
def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {
"cache_regions": True,
}
Expand All @@ -61,6 +61,41 @@ def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False):
return kwargs


def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {}

if azure_cfg.account_name:
kwargs["account_name"] = azure_cfg.account_name
if azure_cfg.account_key:
kwargs["account_key"] = azure_cfg.account_key
if azure_cfg.client_id:
kwargs["client_id"] = azure_cfg.client_id
if azure_cfg.client_secret:
kwargs["client_secret"] = azure_cfg.client_secret
if azure_cfg.tenant_id:
kwargs["tenant_id"] = azure_cfg.tenant_id
kwargs[_ANON] = anonymous
return kwargs


def get_fsspec_storage_options(
protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs
) -> Dict[str, Any]:
data_config = data_config or DataConfig.auto()

if protocol == "file":
return {"auto_mkdir": True, **kwargs}
if protocol == "s3":
return {**s3_setup_args(data_config.s3, anonymous=anonymous), **kwargs}
if protocol == "gs":
if anonymous:
kwargs["token"] = _ANON
return kwargs
if protocol in ("abfs", "abfss"):
return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs}
return {}


class FileAccessProvider(object):
"""
This is the class that is available through the FlyteContext and can be used for persisting data to the remote
Expand Down Expand Up @@ -106,28 +141,15 @@ def data_config(self) -> DataConfig:

def get_filesystem(
self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs
) -> typing.Optional[fsspec.AbstractFileSystem]:
) -> fsspec.AbstractFileSystem:
if not protocol:
return self._default_remote
if protocol == "file":
kwargs["auto_mkdir"] = True
elif protocol == "s3":
s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous)
s3kwargs.update(kwargs)
return fsspec.filesystem(protocol, **s3kwargs) # type: ignore
elif protocol == "gs":
if anonymous:
kwargs["token"] = _ANON
return fsspec.filesystem(protocol, **kwargs) # type: ignore
elif protocol == "abfs":
kwargs["anon"] = False
return fsspec.filesystem(protocol, **kwargs) # type: ignore

# Preserve old behavior of returning None for file systems that don't have an explicit anonymous option.
if anonymous:
return None

return fsspec.filesystem(protocol, **kwargs) # type: ignore
storage_options = get_fsspec_storage_options(
protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
)

return fsspec.filesystem(protocol, **storage_options)

def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
protocol = get_protocol(path)
Expand Down
27 changes: 14 additions & 13 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from flytekit import FlyteContext, logger
from flytekit.configuration import DataConfig
from flytekit.core.data_persistence import s3_setup_args
from flytekit.core.data_persistence import get_fsspec_storage_options
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
Expand All @@ -27,12 +27,13 @@
T = TypeVar("T")


def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]:
protocol = get_protocol(uri)
if protocol == "s3":
kwargs = s3_setup_args(cfg.s3, anon)
if kwargs:
return kwargs
def get_pandas_storage_options(
uri: str, data_config: DataConfig, anonymous: bool = False
) -> typing.Optional[typing.Dict]:
if pd.io.common.is_fsspec_url(uri):
return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous)

# Pandas does not allow storage_options for non-fsspec paths e.g. local.
return None


Expand All @@ -54,7 +55,7 @@ def encode(
df.to_csv(
path,
index=False,
storage_options=get_storage_options(ctx.file_access.data_config, path),
storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config),
)
structured_dataset_type.format = CSV
return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
Expand All @@ -72,15 +73,15 @@ def decode(
) -> pd.DataFrame:
uri = flyte_value.uri
columns = None
kwargs = get_storage_options(ctx.file_access.data_config, uri)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
path = os.path.join(uri, ".csv")
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
return pd.read_csv(path, usecols=columns, storage_options=kwargs)
except NoCredentialsError:
logger.debug("S3 source detected, attempting anonymous S3 access")
kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True)
return pd.read_csv(path, usecols=columns, storage_options=kwargs)


Expand All @@ -103,7 +104,7 @@ def encode(
path,
coerce_timestamps="us",
allow_truncated_timestamps=False,
storage_options=get_storage_options(ctx.file_access.data_config, path),
storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config),
)
structured_dataset_type.format = PARQUET
return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
Expand All @@ -121,14 +122,14 @@ def decode(
) -> pd.DataFrame:
uri = flyte_value.uri
columns = None
kwargs = get_storage_options(ctx.file_access.data_config, uri)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
except NoCredentialsError:
logger.debug("S3 source detected, attempting anonymous S3 access")
kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True)
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import pandas as pd
import polars as pl
from fsspec.utils import get_protocol

from flytekit import FlyteContext
from flytekit.core.data_persistence import get_fsspec_storage_options
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.basic_dfs import get_storage_options
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
Expand Down Expand Up @@ -64,7 +65,7 @@ def decode(
current_task_metadata: StructuredDatasetMetadata,
) -> pl.DataFrame:
uri = flyte_value.uri
kwargs = get_storage_options(ctx.file_access.data_config, uri)
kwargs = get_fsspec_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config)
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs)
Expand Down
76 changes: 74 additions & 2 deletions tests/flytekit/unit/core/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
import mock
import pytest

from flytekit.configuration import Config, S3Config
from flytekit.configuration import Config, DataConfig, S3Config
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args
from flytekit.core.data_persistence import (
FileAccessProvider,
default_local_file_access_provider,
get_fsspec_storage_options,
s3_setup_args,
)
from flytekit.types.directory.types import FlyteDirectory

local = fsspec.filesystem("file")
Expand Down Expand Up @@ -221,6 +226,73 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file):
assert kwargs == {"cache_regions": True}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_gcs(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_GCP_GSUTIL_PARALLELISM": "False",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options("gs", DataConfig.auto())
assert storage_options == {}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_GCP_GSUTIL_PARALLELISM": "False",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value")
assert storage_options == {"token": "anon", "other_argument": "value"}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
"FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
"FLYTE_AZURE_TENANT_ID": "tenantid",
"FLYTE_AZURE_CLIENT_ID": "clientid",
"FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options("abfs", DataConfig.auto())
assert storage_options == {
"account_name": "accountname",
"account_key": "accountkey",
"client_id": "clientid",
"client_secret": "clientsecret",
"tenant_id": "tenantid",
"anon": False,
}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_azure_with_overrides(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
"FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options(
"abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value"
)
assert storage_options == {
"account_name": "other_accountname",
"account_key": "accountkey",
"anon": True,
"other_argument": "value",
}


def test_crawl_local_nt(source_folder):
"""
running this to see what it prints
Expand Down
41 changes: 41 additions & 0 deletions tests/flytekit/unit/core/test_data_persistence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import os

import mock
from azure.identity import ClientSecretCredential, DefaultAzureCredential

from flytekit.core.data_persistence import FileAccessProvider


Expand All @@ -14,3 +19,39 @@ def test_is_remote():
assert fp.is_remote("/tmp/foo/bar") is False
assert fp.is_remote("file://foo/bar") is False
assert fp.is_remote("s3://my-bucket/foo/bar") is True


def test_initialise_azure_file_provider_with_account_key():
with mock.patch.dict(
os.environ,
{"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"},
):
fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
assert fp.get_filesystem().account_name == "accountname"
assert fp.get_filesystem().account_key == "accountkey"
assert fp.get_filesystem().sync_credential is None


def test_initialise_azure_file_provider_with_service_principal():
with mock.patch.dict(
os.environ,
{
"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
"FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
"FLYTE_AZURE_CLIENT_ID": "clientid",
"FLYTE_AZURE_TENANT_ID": "tenantid",
},
):
fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
assert fp.get_filesystem().account_name == "accountname"
assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential)
assert fp.get_filesystem().client_secret == "clientsecret"
assert fp.get_filesystem().client_id == "clientid"
assert fp.get_filesystem().tenant_id == "tenantid"


def test_initialise_azure_file_provider_with_default_credential():
with mock.patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}):
fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
assert fp.get_filesystem().account_name == "accountname"
assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential)
Loading

0 comments on commit 483082a

Please sign in to comment.