Skip to content

Commit

Permalink
Schema update to support directly setting pandas dataframe (#53)
Browse files Browse the repository at this point in the history
* Schema update to support directly setting pandas dataframe

* update

* Unit testing support

* Support for list of dataframes
  • Loading branch information
Ketan Umare authored Nov 15, 2019
1 parent 75f94b1 commit d0d1421
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
27 changes: 23 additions & 4 deletions flytekit/common/types/impl/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def get_supported_literal_types_to_pandas_types():
FROM {table};
"""


# Set location in both parts of this query so in case of a partial failure, we will always have some data backing a
# partition.
_WRITE_HIVE_PARTITION_QUERY_FORMATTER = \
Expand Down Expand Up @@ -406,7 +405,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class SchemaType(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _type_models.SchemaType)):

_LITERAL_TYPE_TO_PROTO_ENUM = {
_primitives.Integer.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_primitives.Float.to_flyte_literal_type(): _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
Expand Down Expand Up @@ -591,6 +589,27 @@ def from_python_std(cls, t_value, schema_type=None):
return schema
elif isinstance(t_value, cls):
return t_value
elif isinstance(t_value, _pd.DataFrame):
# Accepts a pandas DataFrame and converts it to a Schema object
o = cls.create_at_any_location(schema_type=schema_type)
with o as w:
w.write(t_value)
return o
elif isinstance(t_value, list):
# Accepts a list of pandas DataFrames and converts it to a Schema object
o = cls.create_at_any_location(schema_type=schema_type)
with o as w:
for x in t_value:
if isinstance(x, _pd.DataFrame):
w.write(x)
else:
raise _user_exceptions.FlyteTypeException(
type(t_value),
{str, _six.text_type, Schema},
received_value=x,
additional_msg="A Schema object can only be create from a pandas DataFrame or a list of pandas DataFrame."
)
return o
else:
raise _user_exceptions.FlyteTypeException(
type(t_value),
Expand Down Expand Up @@ -894,8 +913,8 @@ def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False

# TODO np.issubdtype is deprecated. Replace it
if all(
not _np.issubdtype(dtype, allowed_type)
for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type]
not _np.issubdtype(dtype, allowed_type)
for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type]
):
if read:
read_or_write_msg = "read data frame object from schema"
Expand Down
69 changes: 61 additions & 8 deletions tests/flytekit/unit/common_tests/types/impl/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from __future__ import absolute_import

import collections as _collections
import datetime as _datetime
import os as _os
import pytest as _pytest
import pandas as _pd
import uuid as _uuid
import datetime as _datetime
from flytekit.common.types.impl import schema as _schema_impl
from flytekit.common.types import primitives as _primitives, blobs as _blobs

import collections as _collections
import pandas as _pd
import pytest as _pytest
import six.moves as _six_moves

from flytekit.common import utils as _utils
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.types import primitives as _primitives, blobs as _blobs
from flytekit.common.types.impl import schema as _schema_impl
from flytekit.models import types as _type_models, literals as _literal_models
from flytekit.sdk import test_utils as _test_utils
import six.moves as _six_moves


def test_schema_type():
Expand Down Expand Up @@ -301,7 +304,57 @@ def test_casting():


def test_from_python_std():
    """Exercise ``Schema.from_python_std`` with pandas-based inputs.

    Four scenarios are run, all inside a single ``LocalTestFileSystem`` context:

    * a single ``DataFrame`` is written and read back unchanged,
    * a list of ``DataFrame`` objects is written and read back chunk by chunk,
    * a list containing a non-``DataFrame`` element raises ``FlyteTypeException``,
    * an empty list yields a schema whose ``read()`` returns ``None``.
    """

    def _int_schema_type():
        # Build a fresh two-column integer schema type for every use so that no
        # SchemaType instance is shared between the write side and the read side.
        return _schema_impl.SchemaType([('a', _primitives.Integer), ('b', _primitives.Integer)])

    with _test_utils.LocalTestFileSystem():
        def single_dataframe():
            df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})
            s = _schema_impl.Schema.from_python_std(t_value=df1, schema_type=_int_schema_type())
            assert s is not None
            n = _schema_impl.Schema.fetch(s.uri, schema_type=_int_schema_type())
            with n as reader:
                df2 = reader.read()
                # Compare the column labels themselves. Reducing each label array
                # with ``.all()`` first would let differing column names pass.
                assert list(df2.columns) == list(df1.columns)
                assert df2['b'].tolist() == df1['b'].tolist()

        def list_of_dataframes():
            df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})
            df2 = _pd.DataFrame.from_dict({'a': [9, 10, 11, 12], 'b': [13, 14, 15, 16]})
            s = _schema_impl.Schema.from_python_std(t_value=[df1, df2], schema_type=_int_schema_type())
            assert s is not None
            n = _schema_impl.Schema.fetch(s.uri, schema_type=_int_schema_type())
            with n as reader:
                actual = []
                for df in reader.iter_chunks():
                    # Every chunk must expose the same column labels as the inputs.
                    assert list(df.columns) == list(df1.columns)
                    actual.extend(df['b'].tolist())
                # Chunks are expected back in the order the frames were written.
                b_val = df1['b'].tolist()
                b_val.extend(df2['b'].tolist())
                assert actual == b_val

        def mixed_list():
            df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})
            df2 = [1, 2, 3]  # deliberately not a DataFrame
            with _pytest.raises(_user_exceptions.FlyteTypeException):
                _schema_impl.Schema.from_python_std(t_value=[df1, df2], schema_type=_int_schema_type())

        def empty_list():
            s = _schema_impl.Schema.from_python_std(t_value=[], schema_type=_int_schema_type())
            assert s is not None
            n = _schema_impl.Schema.fetch(s.uri, schema_type=_int_schema_type())
            with n as reader:
                df = reader.read()
                # Nothing was written, so the reader has no data frame to return.
                assert df is None

        single_dataframe()
        mixed_list()
        empty_list()
        list_of_dataframes()


def test_promote_from_model_schema_type():
Expand Down

0 comments on commit d0d1421

Please sign in to comment.