Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Azure blob storage support #1842

Merged
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ff9c57e
Add Azure blob storage configs
Tom-Newton Sep 12, 2023
2719f6a
Use Azure args
Tom-Newton Sep 12, 2023
1547f88
Add tenant id and correct some typos
Tom-Newton Sep 12, 2023
8ed4a7d
Fix typos
Tom-Newton Sep 12, 2023
172ee25
Extract get_storage_options_for_filesystem
Tom-Newton Sep 12, 2023
c94fb7d
Use `get_storage_options` when encoding structured datasets
Tom-Newton Sep 12, 2023
f8c8c6d
Typo
Tom-Newton Sep 12, 2023
83cfe5b
Remove unused import
Tom-Newton Sep 19, 2023
dc0f424
Handle unrecognised protocol
Tom-Newton Sep 22, 2023
418689f
Handle difference between pandas and fsspec storage_options
Tom-Newton Sep 22, 2023
f9da829
Add test for azure_setup_args
Tom-Newton Sep 22, 2023
2d4f54b
Add tests for initialising Azure filesystem and fix anonymous
Tom-Newton Sep 23, 2023
f846352
Fix anon assertion
Tom-Newton Sep 25, 2023
7ffca40
Autoformat
Tom-Newton Sep 25, 2023
087021d
Improve test coverage for getting fsspec storage options
Tom-Newton Sep 25, 2023
b5df7c1
Re-name env variable names + autoformat
Tom-Newton Sep 25, 2023
c3ee81f
Update tests for new env var names
Tom-Newton Sep 25, 2023
b689b3c
Better mocking of os.environ
Tom-Newton Sep 25, 2023
a279b47
Mostly working test for structured dataset filesystems use
Tom-Newton Sep 25, 2023
80f34d8
Working test for structured dataset azure encode decode
Tom-Newton Sep 25, 2023
71d2486
Fix get_pandas_storage_options
Tom-Newton Sep 25, 2023
5dd00ed
Tidy
Tom-Newton Sep 25, 2023
572725b
Only assert on `storage_options` in structured dataset handler tests
Tom-Newton Sep 26, 2023
72a440e
Type hints
Tom-Newton Sep 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,30 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
return GCSConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class AzureBlobStorageConfig(object):
"""
Any Azure Blob Storage specific configuration.
"""

account_name: typing.Optional[str] = None
account_key: typing.Optional[str] = None
tenant_id: typing.Optional[str] = None
client_id: typing.Optional[str] = None
client_secret: typing.Optional[str] = None

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.STORAGE_ACCOUNT_NAME.read(config_file))
kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.STORAGE_ACCOUNT_KEY.read(config_file))
kwargs = set_if_exists(kwargs, "tenant_id", _internal.AZURE.TENANT_ID.read(config_file))
kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file))
kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file))
return AzureBlobStorageConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class DataConfig(object):
"""
Expand All @@ -572,11 +596,13 @@ class DataConfig(object):

s3: S3Config = S3Config()
gcs: GCSConfig = GCSConfig()
azure: AzureBlobStorageConfig = AzureBlobStorageConfig()

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig:
config_file = get_config_file(config_file)
return DataConfig(
azure=AzureBlobStorageConfig.auto(config_file),
s3=S3Config.auto(config_file),
gcs=GCSConfig.auto(config_file),
)
Expand Down
9 changes: 9 additions & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ class GCP(object):
GSUTIL_PARALLELISM = ConfigEntry(LegacyConfigEntry(SECTION, "gsutil_parallelism", bool))


class AZURE(object):
SECTION = "azure"
STORAGE_ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_name"))
Tom-Newton marked this conversation as resolved.
Show resolved Hide resolved
STORAGE_ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_key"))
TENANT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "tenant_id"))
CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id"))
CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret"))


class Credentials(object):
SECTION = "credentials"
COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command", list), YamlConfigEntry("admin.command", list))
Expand Down
60 changes: 41 additions & 19 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,41 @@ def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False):
return kwargs


def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False):
Tom-Newton marked this conversation as resolved.
Show resolved Hide resolved
kwargs: Dict[str, Any] = {}

if azure_cfg.account_name:
kwargs["account_name"] = azure_cfg.account_name
if azure_cfg.account_key:
kwargs["account_key"] = azure_cfg.account_key
if azure_cfg.client_id:
kwargs["client_id"] = azure_cfg.client_id
if azure_cfg.client_secret:
kwargs["client_secret"] = azure_cfg.client_secret
if azure_cfg.tenant_id:
kwargs["tenant_id"] = azure_cfg.tenant_id
kwargs[_ANON] = anonymous
return kwargs


def get_fsspec_storage_options(
protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs
) -> typing.Dict:
Tom-Newton marked this conversation as resolved.
Show resolved Hide resolved
data_config = data_config or DataConfig.auto()

if protocol == "file":
return {"auto_mkdir": True, **kwargs}
if protocol == "s3":
return {**s3_setup_args(data_config.s3, anonymous=anonymous), **kwargs}
if protocol == "gs":
if anonymous:
kwargs["token"] = _ANON
return kwargs
if protocol in ("abfs", "abfss"):
return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs}
return {}


class FileAccessProvider(object):
"""
This is the class that is available through the FlyteContext and can be used for persisting data to the remote
Expand Down Expand Up @@ -106,28 +141,15 @@ def data_config(self) -> DataConfig:

def get_filesystem(
self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs
) -> typing.Optional[fsspec.AbstractFileSystem]:
) -> fsspec.AbstractFileSystem:
if not protocol:
return self._default_remote
if protocol == "file":
kwargs["auto_mkdir"] = True
elif protocol == "s3":
s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous)
s3kwargs.update(kwargs)
return fsspec.filesystem(protocol, **s3kwargs) # type: ignore
elif protocol == "gs":
if anonymous:
kwargs["token"] = _ANON
return fsspec.filesystem(protocol, **kwargs) # type: ignore
elif protocol == "abfs":
kwargs["anon"] = False
return fsspec.filesystem(protocol, **kwargs) # type: ignore

# Preserve old behavior of returning None for file systems that don't have an explicit anonymous option.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the type hint on get_filesystem_for_path it looks like we assume get_filesystem never returns None so I thought probably best to delete this. If anyone thinks its important I'm happy to add it back.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wild-endeavor I think you wanted to delete this already? So this seems like a good idea

if anonymous:
return None

return fsspec.filesystem(protocol, **kwargs) # type: ignore
storage_options = get_fsspec_storage_options(
protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
)

return fsspec.filesystem(protocol, **storage_options)

def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
protocol = get_protocol(path)
Expand Down
27 changes: 14 additions & 13 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from flytekit import FlyteContext, logger
from flytekit.configuration import DataConfig
from flytekit.core.data_persistence import s3_setup_args
from flytekit.core.data_persistence import get_fsspec_storage_options
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
Expand All @@ -27,12 +27,13 @@
T = TypeVar("T")


def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]:
protocol = get_protocol(uri)
if protocol == "s3":
kwargs = s3_setup_args(cfg.s3, anon)
if kwargs:
return kwargs
def get_pandas_storage_options(
uri: str, data_config: DataConfig, anonymous: bool = False
) -> typing.Optional[typing.Dict]:
if pd.io.common.is_fsspec_url(uri):
return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous)

# Pandas does not allow storage_options for non-fsspec paths e.g. local.
return None


Expand All @@ -54,7 +55,7 @@ def encode(
df.to_csv(
path,
index=False,
storage_options=get_storage_options(ctx.file_access.data_config, path),
storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config),
)
structured_dataset_type.format = CSV
return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
Expand All @@ -72,15 +73,15 @@ def decode(
) -> pd.DataFrame:
uri = flyte_value.uri
columns = None
kwargs = get_storage_options(ctx.file_access.data_config, uri)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
path = os.path.join(uri, ".csv")
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
return pd.read_csv(path, usecols=columns, storage_options=kwargs)
except NoCredentialsError:
logger.debug("S3 source detected, attempting anonymous S3 access")
kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True)
return pd.read_csv(path, usecols=columns, storage_options=kwargs)


Expand All @@ -103,7 +104,7 @@ def encode(
path,
coerce_timestamps="us",
allow_truncated_timestamps=False,
storage_options=get_storage_options(ctx.file_access.data_config, path),
storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config),
)
structured_dataset_type.format = PARQUET
return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
Expand All @@ -121,14 +122,14 @@ def decode(
) -> pd.DataFrame:
uri = flyte_value.uri
columns = None
kwargs = get_storage_options(ctx.file_access.data_config, uri)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
except NoCredentialsError:
logger.debug("S3 source detected, attempting anonymous S3 access")
kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True)
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import pandas as pd
import polars as pl
from fsspec.utils import get_protocol

from flytekit import FlyteContext
from flytekit.core.data_persistence import get_fsspec_storage_options
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.basic_dfs import get_storage_options
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
Expand Down Expand Up @@ -64,7 +65,7 @@ def decode(
current_task_metadata: StructuredDatasetMetadata,
) -> pl.DataFrame:
uri = flyte_value.uri
kwargs = get_storage_options(ctx.file_access.data_config, uri)
kwargs = get_fsspec_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config)
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs)
Expand Down
76 changes: 74 additions & 2 deletions tests/flytekit/unit/core/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
import mock
import pytest

from flytekit.configuration import Config, S3Config
from flytekit.configuration import Config, DataConfig, S3Config
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args
from flytekit.core.data_persistence import (
FileAccessProvider,
default_local_file_access_provider,
get_fsspec_storage_options,
s3_setup_args,
)
from flytekit.types.directory.types import FlyteDirectory

local = fsspec.filesystem("file")
Expand Down Expand Up @@ -221,6 +226,73 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file):
assert kwargs == {"cache_regions": True}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_gcs(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_GCP_GSUTIL_PARALLELISM": "False",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options("gs", DataConfig.auto())
assert storage_options == {}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_GCP_GSUTIL_PARALLELISM": "False",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value")
assert storage_options == {"token": "anon", "other_argument": "value"}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
"FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
"FLYTE_AZURE_TENANT_ID": "tenantid",
"FLYTE_AZURE_CLIENT_ID": "clientid",
"FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options("abfs", DataConfig.auto())
assert storage_options == {
"account_name": "accountname",
"account_key": "accountkey",
"client_id": "clientid",
"client_secret": "clientsecret",
"tenant_id": "tenantid",
"anon": False,
}


@mock.patch("flytekit.configuration.get_config_file")
@mock.patch("os.environ")
def test_get_fsspec_storage_options_azure_with_overrides(mock_os, mock_get_config_file):
mock_get_config_file.return_value = None
ee = {
"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
"FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey",
}
mock_os.get.side_effect = lambda x, y: ee.get(x)
storage_options = get_fsspec_storage_options(
"abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value"
)
assert storage_options == {
"account_name": "other_accountname",
"account_key": "accountkey",
"anon": True,
"other_argument": "value",
}


def test_crawl_local_nt(source_folder):
"""
running this to see what it prints
Expand Down
41 changes: 41 additions & 0 deletions tests/flytekit/unit/core/test_data_persistence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import os

import mock
from azure.identity import ClientSecretCredential, DefaultAzureCredential

from flytekit.core.data_persistence import FileAccessProvider


Expand All @@ -14,3 +19,39 @@ def test_is_remote():
assert fp.is_remote("/tmp/foo/bar") is False
assert fp.is_remote("file://foo/bar") is False
assert fp.is_remote("s3://my-bucket/foo/bar") is True


def test_initialise_azure_file_provider_with_account_key():
with mock.patch.dict(
os.environ,
{"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"},
):
fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
assert fp.get_filesystem().account_name == "accountname"
assert fp.get_filesystem().account_key == "accountkey"
assert fp.get_filesystem().sync_credential is None


def test_initialise_azure_file_provider_with_service_principal():
with mock.patch.dict(
os.environ,
{
"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname",
"FLYTE_AZURE_CLIENT_SECRET": "clientsecret",
"FLYTE_AZURE_CLIENT_ID": "clientid",
"FLYTE_AZURE_TENANT_ID": "tenantid",
},
):
fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
assert fp.get_filesystem().account_name == "accountname"
assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential)
assert fp.get_filesystem().client_secret == "clientsecret"
assert fp.get_filesystem().client_id == "clientid"
assert fp.get_filesystem().tenant_id == "tenantid"


def test_initialise_azure_file_provider_with_default_credential():
with mock.patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}):
fp = FileAccessProvider("/tmp", "abfs://container/path/within/container")
assert fp.get_filesystem().account_name == "accountname"
assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential)
Loading