diff --git a/flytekit/common/types/impl/schema.py b/flytekit/common/types/impl/schema.py index a888215798..f3e6315647 100644 --- a/flytekit/common/types/impl/schema.py +++ b/flytekit/common/types/impl/schema.py @@ -68,7 +68,6 @@ def get_supported_literal_types_to_pandas_types(): FROM {table}; """ - # Set location in both parts of this query so in case of a partial failure, we will always have some data backing a # partition. _WRITE_HIVE_PARTITION_QUERY_FORMATTER = \ @@ -406,7 +405,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): class SchemaType(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _type_models.SchemaType)): - _LITERAL_TYPE_TO_PROTO_ENUM = { _primitives.Integer.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER, _primitives.Float.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT, @@ -591,6 +589,27 @@ def from_python_std(cls, t_value, schema_type=None): return schema elif isinstance(t_value, cls): return t_value + elif isinstance(t_value, _pd.DataFrame): + # Accepts a pandas dataframe and converts to a Schema object + o = cls.create_at_any_location(schema_type=schema_type) + with o as w: + w.write(t_value) + return o + elif isinstance(t_value, list): + # Accepts a list of pandas dataframe and converts to a Schema object + o = cls.create_at_any_location(schema_type=schema_type) + with o as w: + for x in t_value: + if isinstance(x, _pd.DataFrame): + w.write(x) + else: + raise _user_exceptions.FlyteTypeException( + type(t_value), + {str, _six.text_type, Schema}, + received_value=x, + additional_msg="A Schema object can only be create from a pandas DataFrame or a list of pandas DataFrame." + ) + return o else: raise _user_exceptions.FlyteTypeException( type(t_value), @@ -894,8 +913,8 @@ def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False # TODO np.issubdtype is deprecated. Replace it if all( - not _np.issubdtype(dtype, allowed_type) - for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type] + not _np.issubdtype(dtype, allowed_type) + for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type] ): if read: read_or_write_msg = "read data frame object from schema" diff --git a/tests/flytekit/unit/common_tests/types/impl/test_schema.py b/tests/flytekit/unit/common_tests/types/impl/test_schema.py index f7ad839ecb..c066964186 100644 --- a/tests/flytekit/unit/common_tests/types/impl/test_schema.py +++ b/tests/flytekit/unit/common_tests/types/impl/test_schema.py @@ -1,17 +1,20 @@ from __future__ import absolute_import -import collections as _collections +import datetime as _datetime import os as _os -import pytest as _pytest -import pandas as _pd import uuid as _uuid -import datetime as _datetime -from flytekit.common.types.impl import schema as _schema_impl -from flytekit.common.types import primitives as _primitives, blobs as _blobs + +import collections as _collections +import pandas as _pd +import pytest as _pytest +import six.moves as _six_moves + from flytekit.common import utils as _utils +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.types import primitives as _primitives, blobs as _blobs +from flytekit.common.types.impl import schema as _schema_impl from flytekit.models import types as _type_models, literals as _literal_models from flytekit.sdk import test_utils as _test_utils -import six.moves as _six_moves def test_schema_type(): @@ -301,7 +304,57 @@ def test_casting(): def test_from_python_std(): - pass + with _test_utils.LocalTestFileSystem(): + def single_dataframe(): + df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + s = _schema_impl.Schema.from_python_std(t_value=df1, schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + assert s is not None + n = _schema_impl.Schema.fetch(s.uri, schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + with n as reader: + df2 = reader.read() + assert df2.columns.values.all() == df1.columns.values.all() + assert df2['b'].tolist() == df1['b'].tolist() + + def list_of_dataframes(): + df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + df2 = _pd.DataFrame.from_dict({'a': [9, 10, 11, 12], 'b': [13, 14, 15, 16]}) + s = _schema_impl.Schema.from_python_std(t_value=[df1, df2], schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + assert s is not None + n = _schema_impl.Schema.fetch(s.uri, schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + with n as reader: + actual = [] + for df in reader.iter_chunks(): + assert df.columns.values.all() == df1.columns.values.all() + actual.extend(df['b'].tolist()) + b_val = df1['b'].tolist() + b_val.extend(df2['b'].tolist()) + assert actual == b_val + + def mixed_list(): + df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + df2 = [1, 2, 3] + with _pytest.raises(_user_exceptions.FlyteTypeException): + _schema_impl.Schema.from_python_std(t_value=[df1, df2], schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + + def empty_list(): + s = _schema_impl.Schema.from_python_std(t_value=[], schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + assert s is not None + n = _schema_impl.Schema.fetch(s.uri, schema_type=_schema_impl.SchemaType( + [('a', _primitives.Integer), ('b', _primitives.Integer)])) + with n as reader: + df = reader.read() + assert df is None + + single_dataframe() + mixed_list() + empty_list() + list_of_dataframes() def test_promote_from_model_schema_type():