Async/data persistence (#2829)
Signed-off-by: Yee Hing Tong <[email protected]>
wild-endeavor authored Oct 22, 2024
1 parent f1adbe8 commit 3fc51af
Showing 19 changed files with 266 additions and 98 deletions.
104 changes: 76 additions & 28 deletions flytekit/core/data_persistence.py
@@ -18,6 +18,7 @@
"""

import asyncio
import io
import os
import pathlib
@@ -29,6 +30,7 @@

import fsspec
from decorator import decorator
from fsspec.asyn import AsyncFileSystem
from fsspec.utils import get_protocol
from typing_extensions import Unpack

@@ -40,6 +42,7 @@
from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException
from flytekit.interfaces.random import random
from flytekit.loggers import logger
from flytekit.utils.asyn import loop_manager

# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198
# for key and secret
@@ -208,8 +211,17 @@ def get_filesystem(
storage_options = get_fsspec_storage_options(
protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
)
kwargs.update(storage_options)

return fsspec.filesystem(protocol, **storage_options)
return fsspec.filesystem(protocol, **kwargs)

async def get_async_filesystem_for_path(
self, path: str = "", anonymous: bool = False, **kwargs
) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]:
protocol = get_protocol(path)
loop = asyncio.get_running_loop()

return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs)

def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
protocol = get_protocol(path)
@@ -282,8 +294,8 @@ def exists(self, path: str) -> bool:
raise oe

@retry_request
def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_filesystem_for_path(from_path)
async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = await self.get_async_filesystem_for_path(from_path)
if recursive:
from_path, to_path = self.recursive_paths(from_path, to_path)
try:
@@ -294,23 +306,33 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
)
logger.info(f"Getting {from_path} to {to_path}")
dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(file_system, AsyncFileSystem):
dst = await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(dst, (str, pathlib.Path)):
return dst
return to_path
except OSError as oe:
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
if not file_system.exists(from_path):
raise FlyteDataNotFoundException(from_path)
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True)
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
if file_system is not None:
logger.debug(f"Attempting anonymous get with {file_system}")
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(file_system, AsyncFileSystem):
return await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
raise oe

@retry_request
def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_filesystem_for_path(to_path)
async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
"""
More of an internal function to be called by put_data and put_raw_data
This does not need a separate sync function.
"""
file_system = await self.get_async_filesystem_for_path(to_path)
from_path = self.strip_file_header(from_path)
if recursive:
# Only check this for the local filesystem
@@ -327,13 +349,16 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
if "metadata" not in kwargs:
kwargs["metadata"] = {}
kwargs["metadata"].update(self._execution_metadata)
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(file_system, AsyncFileSystem):
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(dst, (str, pathlib.Path)):
return dst
else:
return to_path

def put_raw_data(
async def async_put_raw_data(
self,
lpath: Uploadable,
upload_prefix: Optional[str] = None,
@@ -364,7 +389,7 @@ def put_raw_data(
:param read_chunk_size_bytes: If lpath is a buffer, this is the chunk size to read from it
:param encoding: If lpath is a io.StringIO, this is the encoding to use to encode it to binary.
:param skip_raw_data_prefix: If True, the raw data prefix will not be prepended to the upload_prefix
:param kwargs: Additional kwargs are passed into the the fsspec put() call or the open() call
:param kwargs: Additional kwargs are passed into the fsspec put() call or the open() call
:return: Returns the final path data was written to.
"""
# First figure out what the destination path should be, then call put.
@@ -388,42 +413,60 @@ def put_raw_data(
raise FlyteAssertion(f"File {from_path} is a symlink, can't upload")
if p.is_dir():
logger.debug(f"Detected directory {from_path}, using recursive put")
r = self.put(from_path, to_path, recursive=True, **kwargs)
r = await self._put(from_path, to_path, recursive=True, **kwargs)
else:
logger.debug(f"Detected file {from_path}, call put non-recursive")
r = self.put(from_path, to_path, **kwargs)
r = await self._put(from_path, to_path, **kwargs)
return r or to_path

# raw bytes
if isinstance(lpath, bytes):
fs = self.get_filesystem_for_path(to_path)
with fs.open(to_path, "wb", **kwargs) as s:
s.write(lpath)
fs = await self.get_async_filesystem_for_path(to_path)
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
s.write(lpath)
else:
with fs.open(to_path, "wb", **kwargs) as s:
s.write(lpath)

return to_path

# If lpath is a buffered reader of some kind
if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO):
if not lpath.readable():
raise FlyteAssertion("Buffered reader must be readable")
fs = self.get_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)
lpath.seek(0)
with fs.open(to_path, "wb", **kwargs) as s:
while data := lpath.read(read_chunk_size_bytes):
s.write(data)
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
while data := lpath.read(read_chunk_size_bytes):
s.write(data)
else:
with fs.open(to_path, "wb", **kwargs) as s:
while data := lpath.read(read_chunk_size_bytes):
s.write(data)
return to_path

if isinstance(lpath, io.StringIO):
if not lpath.readable():
raise FlyteAssertion("Buffered reader must be readable")
fs = self.get_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)
lpath.seek(0)
with fs.open(to_path, "wb", **kwargs) as s:
while data_str := lpath.read(read_chunk_size_bytes):
s.write(data_str.encode(encoding))
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
while data_str := lpath.read(read_chunk_size_bytes):
s.write(data_str.encode(encoding))
else:
with fs.open(to_path, "wb", **kwargs) as s:
while data_str := lpath.read(read_chunk_size_bytes):
s.write(data_str.encode(encoding))
return to_path

raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}")

# Public synchronous version
put_raw_data = loop_manager.synced(async_put_raw_data)

@staticmethod
def get_random_string() -> str:
return UUID(int=random.getrandbits(128)).hex
@@ -549,7 +592,7 @@ def upload_directory(self, local_path: str, remote_path: str, **kwargs):
"""
return self.put_data(local_path, remote_path, is_multipart=True, **kwargs)

def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
async def async_get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
"""
:param remote_path:
:param local_path:
@@ -558,7 +601,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
try:
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
with timeit(f"Download data to local from {remote_path}"):
self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
await self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
except FlyteDataNotFoundException:
raise
except Exception as ex:
@@ -567,7 +610,9 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
f"Original exception: {str(ex)}"
)

def put_data(
get_data = loop_manager.synced(async_get_data)

async def async_put_data(
self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs
) -> str:
"""
@@ -581,7 +626,7 @@ def put_data(
try:
local_path = str(local_path)
with timeit(f"Upload data to {remote_path}"):
put_result = self.put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs)
put_result = await self._put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs)
# This is an unfortunate workaround to ensure that we return the correct path for the remote location
# Callers of this put_data function in flytekit have been changed to assign the remote path to the
# output
@@ -595,6 +640,9 @@ def put_data(
f"Original exception: {str(ex)}"
) from ex

# Public synchronous version
put_data = loop_manager.synced(async_put_data)


flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-")
default_local_file_access_provider = FileAccessProvider(
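The changes in this file follow one pattern: build the filesystem with `asynchronous=True` bound to the running loop, call fsspec's underscore-prefixed coroutine methods (`_get`/`_put`, `open_async`) when the filesystem is an `AsyncFileSystem`, fall back to the blocking call otherwise, and keep a synchronous public entry point via `loop_manager.synced`. Below is a minimal sketch of that shape; `MiniAccessor` and `copy_in` are illustrative names, not part of this commit.

```python
import asyncio

import fsspec
from fsspec.asyn import AsyncFileSystem
from fsspec.utils import get_protocol

from flytekit.utils.asyn import loop_manager


class MiniAccessor:
    async def async_copy_in(self, from_path: str, to_path: str) -> str:
        protocol = get_protocol(from_path)
        # Async-capable filesystems (e.g. s3fs) are constructed against the running loop.
        fs = fsspec.filesystem(protocol, asynchronous=True, loop=asyncio.get_running_loop())
        if isinstance(fs, AsyncFileSystem):
            # The underscore-prefixed methods are fsspec's coroutine API.
            await fs._get(from_path, to_path)
        else:
            # Sync-only filesystems (e.g. the local one) keep the blocking call.
            fs.get(from_path, to_path)
        return to_path

    # Synchronous entry point for non-async callers, mirroring get_data/put_data above.
    copy_in = loop_manager.synced(async_copy_in)


# Sync code:  MiniAccessor().copy_in("s3://bucket/key", "/tmp/key")
# Async code: await MiniAccessor().async_copy_in("s3://bucket/key", "/tmp/key")
```

The branch matters because the local filesystem is not async, so every call site has to check for the coroutine API rather than assume it exists.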
10 changes: 6 additions & 4 deletions flytekit/core/type_engine.py
@@ -1908,7 +1908,9 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return None, None

@staticmethod
def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool) -> Literal:
async def dict_to_binary_literal(
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
) -> Literal:
"""
Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding.
Falls back to Pickle if encoding fails and `allow_pickle` is True.
@@ -1922,7 +1924,7 @@ def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict],
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
remote_path = await FlytePickle.to_pickle(ctx, v)
return Literal(
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
@@ -1980,7 +1982,7 @@ async def async_to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)
return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)

lit_map = {}
for k, v in python_val.items():
@@ -2036,7 +2038,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
from flytekit.types.pickle import FlytePickle

uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
return FlytePickle.from_pickle(uri)
return await FlytePickle.from_pickle(uri)

try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
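`dict_to_binary_literal` keeps the same encode-or-fallback logic; the change is only that the pickle branch is now awaited. Below is a small self-contained sketch of that logic using plain msgpack and pickle. The flytekit version wraps the bytes in a `Binary` literal and, in the fallback, awaits `FlytePickle.to_pickle` and stores the returned remote path rather than inlining pickle bytes.

```python
import pickle
from typing import Tuple

import msgpack


def to_binary(value: dict, allow_pickle: bool = False) -> Tuple[bytes, str]:
    """Return a (payload, tag) pair: msgpack when possible, pickle as an opt-in fallback."""
    try:
        return msgpack.dumps(value), "msgpack"
    except TypeError:
        if not allow_pickle:
            raise
        # flytekit's version awaits FlytePickle.to_pickle(ctx, v) here and stores
        # the returned remote path in a generic Struct literal instead.
        return pickle.dumps(value), "pickle"


print(to_binary({"a": 1}))                            # msgpack handles plain data
print(to_binary({"a": object()}, allow_pickle=True))  # unserializable value falls back to pickle
```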
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
@@ -368,7 +368,7 @@ async def _create(
literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
path = ctx.file_access.get_random_local_path()
utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
await ctx.file_access.async_put_data(path, f"{output_prefix}/inputs.pb")
task_template = render_task_template(task_template, output_prefix)
else:
literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
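With `async_put_data` available, the agent's `_create` path awaits the upload instead of stalling its event loop on the synchronous `put_data`. A minimal sketch of awaiting it from your own async code, using temporary local paths as stand-ins for the real `output_prefix` and proto bytes:

```python
import asyncio
import tempfile
from pathlib import Path

from flytekit.core.context_manager import FlyteContextManager


async def upload_inputs() -> str:
    ctx = FlyteContextManager.current_context()
    src = Path(tempfile.mkdtemp()) / "inputs.pb"
    src.write_bytes(b"serialized literal map")         # stand-in for the real proto bytes
    dst = str(Path(tempfile.mkdtemp()) / "inputs.pb")  # stand-in for f"{output_prefix}/inputs.pb"
    return await ctx.file_access.async_put_data(str(src), dst)


print(asyncio.run(upload_inputs()))
```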
12 changes: 6 additions & 6 deletions flytekit/extras/tensorflow/model.py
@@ -4,13 +4,13 @@
import tensorflow as tf

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]):
class TensorFlowModelTransformer(AsyncTypeTransformer[tf.keras.Model]):
TENSORFLOW_FORMAT = "TensorFlowModel"

def __init__(self):
@@ -24,7 +24,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
)
)

def to_literal(
async def async_to_literal(
self,
ctx: FlyteContext,
python_val: tf.keras.Model,
@@ -44,10 +44,10 @@ def to_literal(
# save model in SavedModel format
tf.keras.models.save_model(python_val, local_path)

remote_path = ctx.file_access.put_raw_data(local_path)
remote_path = await ctx.file_access.async_put_raw_data(local_path)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(
async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model]
) -> tf.keras.Model:
try:
@@ -56,7 +56,7 @@ def to_python_value(
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=True)
await ctx.file_access.async_get_data(uri, local_path, is_multipart=True)

# load model
return tf.keras.models.load_model(local_path)
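The TensorFlow transformer shows the general migration recipe for custom transformers: subclass `AsyncTypeTransformer`, rename `to_literal`/`to_python_value` to their `async_` variants, and await the `async_` file-access helpers. Below is a hedged sketch for a hypothetical JSON-blob transformer; `JsonBlobTransformer` and its format tag are illustrative, not flytekit APIs.

```python
import json
from typing import Type

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import AsyncTypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class JsonBlobTransformer(AsyncTypeTransformer[dict]):
    """Stores a dict as a single-part JSON blob (illustrative only)."""

    def __init__(self):
        super().__init__(name="json-blob", t=dict)

    def get_literal_type(self, t: Type[dict]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format="JsonBlob", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

    async def async_to_literal(self, ctx: FlyteContext, python_val: dict, python_type, expected) -> Literal:
        local_path = ctx.file_access.get_random_local_path() + ".json"
        with open(local_path, "w") as f:
            json.dump(python_val, f)
        # Upload without blocking the event loop.
        remote_path = await ctx.file_access.async_put_raw_data(local_path)
        meta = BlobMetadata(type=self.get_literal_type(python_type).blob)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type) -> dict:
        local_path = ctx.file_access.get_random_local_path()
        await ctx.file_access.async_get_data(lv.scalar.blob.uri, local_path, is_multipart=False)
        with open(local_path) as f:
            return json.load(f)
```

Registration via `TypeEngine.register(...)` is unchanged; only the transformer body moves to coroutines.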
12 changes: 7 additions & 5 deletions flytekit/types/directory/types.py
@@ -19,7 +19,7 @@
from flytekit import BlobType
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_batch_size
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
@@ -407,7 +407,7 @@ def __str__(self):
return str(self.path)


class FlyteDirToMultipartBlobTransformer(TypeTransformer[FlyteDirectory]):
class FlyteDirToMultipartBlobTransformer(AsyncTypeTransformer[FlyteDirectory]):
"""
This transformer handles conversion between the Python native FlyteDirectory class defined above, and the Flyte
IDL literal/type of Multipart Blob. Please see the FlyteDirectory comments for additional information.
@@ -444,7 +444,7 @@ def assert_type(self, t: typing.Type[FlyteDirectory], v: typing.Union[FlyteDirec
def get_literal_type(self, t: typing.Type[FlyteDirectory]) -> LiteralType:
return _type_models.LiteralType(blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t)))

def to_literal(
async def async_to_literal(
self,
ctx: FlyteContext,
python_val: FlyteDirectory,
@@ -499,7 +499,9 @@ def to_literal(
remote_directory = ctx.file_access.get_random_remote_directory()
if not pathlib.Path(source_path).is_dir():
raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path))
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size)
await ctx.file_access.async_put_data(
source_path, remote_directory, is_multipart=True, batch_size=batch_size
)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory)))

# If not uploading, then we can only take the original source path as the uri.
@@ -535,7 +537,7 @@ def from_binary_idl(
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(
async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory]
) -> FlyteDirectory:
if lv.scalar.binary:
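Callers should not need to change: the public synchronous `TypeEngine` entry points still accept and return the same values, with the engine bridging to the async transformer internally. A quick local sanity check, assuming a default local context (paths are illustrative):

```python
import tempfile

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.types.directory import FlyteDirectory

ctx = FlyteContextManager.current_context()

# Make a small directory to round-trip through the type engine.
src = tempfile.mkdtemp()
with open(f"{src}/data.txt", "w") as f:
    f.write("hello")

lt = TypeEngine.to_literal_type(FlyteDirectory)
lit = TypeEngine.to_literal(ctx, FlyteDirectory(src), FlyteDirectory, lt)  # uploads the directory
fd = TypeEngine.to_python_value(ctx, lit, FlyteDirectory)                  # returns a lazy handle
print(fd.path)
```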