Skip to content

Commit

Permalink
[FlyteSchema] Fix numpy problems (#2619)
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier authored Jul 29, 2024
1 parent 77d056a commit d507328
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 21 deletions.
5 changes: 3 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
is_ambiguous = False
res = None
res_type = None
t = None
for i in range(len(get_args(python_type))):
try:
t = get_args(python_type)[i]
Expand All @@ -1504,8 +1505,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
if found_res:
is_ambiguous = True
found_res = True
except Exception:
logger.debug(f"Failed to convert from {python_val} to {t}", exc_info=True)
except Exception as e:
logger.debug(f"Failed to convert from {python_val} to {t} with error: {e}", exc_info=True)
continue

if is_ambiguous:
Expand Down
6 changes: 4 additions & 2 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def convert(
if isinstance(value, ArtifactQuery):
return value

if " " in value:
if isinstance(value, str) and " " in value:
import re

m = re.match(self._FLOATING_FORMAT_PATTERN, value)
Expand All @@ -193,7 +193,9 @@ def convert(
if parts[1] == "-":
return dt - delta
return dt + delta
raise click.BadParameter(f"Expected format {self.formats}, got {value}")
else:
value = datetime.datetime.fromisoformat(value)

return self._datetime_from_format(value, param, ctx)


Expand Down
41 changes: 26 additions & 15 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pathlib import Path
from typing import Type

import numpy as _np
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
Expand Down Expand Up @@ -349,27 +348,39 @@ def as_readonly(self) -> FlyteSchema:
return s


def _get_numpy_type_mappings() -> typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType]:
try:
import numpy as _np

return {
_np.int32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.int64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.uint32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.uint64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.float32: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
_np.float64: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
_np.bool_: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, # type: ignore
_np.datetime64: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
_np.timedelta64: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
_np.bytes_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
_np.str_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
_np.object_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
}
except ImportError as e:
logger.warning("Numpy not found, skipping numpy type mappings, error: %s", e)
return {}


class FlyteSchemaTransformer(TypeTransformer[FlyteSchema]):
_SUPPORTED_TYPES: typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType] = {
_np.int32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.int64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.uint32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.uint64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.float32: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
_np.float64: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
float: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
_np.bool_: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, # type: ignore
int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
bool: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN,
_np.datetime64: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
str: SchemaType.SchemaColumn.SchemaColumnType.STRING,
datetime.datetime: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
_np.timedelta64: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
datetime.timedelta: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
_np.bytes_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
_np.str_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
_np.object_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
str: SchemaType.SchemaColumn.SchemaColumnType.STRING,
}
_SUPPORTED_TYPES.update(_get_numpy_type_mappings())

def __init__(self):
super().__init__("FlyteSchema Transformer", FlyteSchema)
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-envd/tests/test_image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_image_spec():
apt_packages=["git"],
python_version="3.8",
base_image=base_image,
pip_index="https://private-pip-index/simple",
pip_index="https://pypi.python.org/simple",
source_root=os.path.dirname(os.path.realpath(__file__)),
)

Expand All @@ -58,7 +58,7 @@ def build():
install.python_packages(name=["pandas"])
install.apt_packages(name=["git"])
runtime.environ(env={{'PYTHONPATH': '/root:', '_F_IMG_ID': '{image_name}'}}, extra_path=['/root'])
config.pip_index(url="https://private-pip-index/simple")
config.pip_index(url="https://pypi.python.org/simple")
install.python(version="3.8")
io.copy(source="./", target="/root")
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses_json import DataClassJsonMixin
from mashumaro.mixins.json import DataClassJSONMixin
import os
import sys
import tempfile
from dataclasses import dataclass
from typing import Annotated, List, Dict, Optional
Expand Down Expand Up @@ -882,3 +883,33 @@ class NestedFlyteTypesWithDataClassJson(DataClassJsonMixin):
transformer = DataclassTransformer()
lt = transformer.get_literal_type(NestedFlyteTypesWithDataClassJson)
assert lt.metadata is not None
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10 or higher")
def test_numpy_import_issue_from_flyte_schema_in_dataclass():
from dataclasses import dataclass

from flytekit import task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile

@dataclass
class MyDataClass:
output_file: FlyteFile
output_directory: FlyteDirectory

@task
def my_flyte_workflow(b: bool) -> list[MyDataClass | None]:
if b:
return [MyDataClass(__file__, ".")]
return [None]

@task
def my_flyte_task(inputs: list[MyDataClass | None]) -> bool:
return inputs and (inputs[0] is not None) # type: ignore

@workflow
def main_flyte_workflow(b: bool = False) -> bool:
inputs = my_flyte_workflow(b=b)
return my_flyte_task(inputs=inputs)

assert main_flyte_workflow(b=True) == True
assert main_flyte_workflow(b=False) == False
3 changes: 3 additions & 0 deletions tests/flytekit/unit/interaction/test_click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def test_datetime_type():
with pytest.raises(click.BadParameter):
t.convert("aaa + 1d", None, None)

fmt_v = "2024-07-29 13:47:07.643004+00:00"
d = t.convert(fmt_v, None, None)
_datetime_helper(t, fmt_v, d)


def test_json_type():
Expand Down

0 comments on commit d507328

Please sign in to comment.