Skip to content

Commit

Permalink
Merge pull request #2 from lyft/fix-hive-unit-test
Browse files Browse the repository at this point in the history
Implement Hive Unit Test Behavior
  • Loading branch information
matthewphsmith authored Aug 23, 2019
2 parents 5e97f30 + b1d31a1 commit 8274b2c
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 17 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.1.5'
__version__ = '0.1.6'
66 changes: 51 additions & 15 deletions flytekit/engines/unit/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from datetime import datetime as _datetime
from six import moves as _six_moves

from google.protobuf.json_format import ParseDict as _ParseDict
from flyteidl.plugins import qubole_pb2 as _qubole_pb2
from flytekit.common import constants as _sdk_constants, utils as _common_utils
from flytekit.common.exceptions import user as _user_exceptions, system as _system_exception
from flytekit.common.types import helpers as _type_helpers
from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration
from flytekit.engines import common as _common_engine
from flytekit.engines.unit.mock_stats import MockStats
from flytekit.interfaces.data import data_proxy as _data_proxy
from flytekit.models import literals as _literals, array_job as _array_job
from flytekit.models import literals as _literals, array_job as _array_job, qubole as _qubole_models
from flytekit.models.core.identifier import WorkflowExecutionIdentifier


Expand All @@ -32,9 +34,12 @@ def get_task(self, sdk_task):
return ReturnOutputsTask(sdk_task)
elif sdk_task.type in {
_sdk_constants.SdkTaskType.DYNAMIC_TASK,
_sdk_constants.SdkTaskType.BATCH_HIVE_TASK
}:
return DynamicTask(sdk_task)
elif sdk_task.type in {
_sdk_constants.SdkTaskType.BATCH_HIVE_TASK,
}:
return HiveTask(sdk_task)
else:
raise _user_exceptions.FlyteAssertion(
"Unit tests are not currently supported for tasks of type: {}".format(
Expand Down Expand Up @@ -76,20 +81,20 @@ def execute(self, inputs, context=None):
Just execute the function and return the outputs as a user-readable dictionary.
:param flytekit.models.literals.LiteralMap inputs:
:param context:
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
"""
with _TemporaryConfiguration(
_os.path.join(_os.path.dirname(__file__), 'unit.config'),
internal_overrides={'image': 'unit_image'}
):
with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory:
with _data_proxy.LocalWorkingDirectoryContext(working_directory):
return self._execute_user_code(inputs)
return self._transform_for_user_output(self._execute_user_code(inputs))

def _execute_user_code(self, inputs):
"""
:param flytekit.models.literals.LiteralMap inputs:
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
"""
with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
return self.sdk_task.execute(
Expand All @@ -107,24 +112,32 @@ def _execute_user_code(self, inputs):
inputs
)

def _transform_for_user_output(self, outputs):
    """
    Take whatever is returned from the task execution and convert to a reasonable output for the behavior of this
    task's unit test.

    This base implementation is a pass-through: it returns ``outputs`` exactly as received. Subclasses override
    it to shape the value handed back to the unit test author.
    :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs: The raw results of executing the task.
    :rtype: T
    """
    # No conversion at this level; the outputs are handed back untouched.
    return outputs

def register(self, identifier, version):
    """
    Registration is not supported for unit test tasks; this always raises.
    :param identifier: Unused.
    :param version: Unused.
    :raises flytekit.common.exceptions.user.FlyteAssertion: Unconditionally.
    """
    raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.")


class ReturnOutputsTask(UnitTestEngineTask):
def execute(self, inputs, context=None):
def _transform_for_user_output(self, outputs):
"""
Just execute the function and return the outputs as a user-readable dictionary.
:param flytekit.models.literals.LiteralMap inputs:
:param context:
:rtype: dict[Text, T]
Just return the outputs as a user-readable dictionary.
:param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
:rtype: T
"""
outputs = super(ReturnOutputsTask, self).execute(inputs)[_sdk_constants.OUTPUT_FILE_NAME]
literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME]
return {
name: _type_helpers.get_sdk_type_from_literal_type(
variable.type
).promote_from_model(
outputs.literals[name]
literal_map.literals[name]
).to_python_std()
for name, variable in _six.iteritems(self.sdk_task.interface.outputs)
}
Expand All @@ -135,7 +148,7 @@ class DynamicTask(ReturnOutputsTask):
def _execute_user_code(self, inputs):
"""
:param flytekit.models.literals.LiteralMap inputs:
:rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
"""
results = super(DynamicTask, self)._execute_user_code(inputs)
if _sdk_constants.FUTURES_FILE_NAME in results:
Expand All @@ -151,7 +164,7 @@ def _execute_user_code(self, inputs):
# TODO: futures.outputs should have the Schema instances.
# After schema is implemented, fill out random data into the random locations
# then check output in test function
# From Haytham even though we recommend people use typed schemas, they might not always do so...
# Even though we recommend people use typed schemas, they might not always do so...
# in which case it'll be impossible to predict the actual schema, we should support a
# way for unit test authors to provide fake data regardless
sub_task_output = None
Expand Down Expand Up @@ -201,7 +214,7 @@ def fulfil_bindings(binding_data, fulfilled_promises):
fulfilled_promises
:param _interface.BindingData binding_data:
:param dict[Text, T] fulfilled_promises:
:param dict[Text,T] fulfilled_promises:
:rtype:
"""
if binding_data.scalar:
Expand All @@ -228,3 +241,26 @@ def fulfil_bindings(binding_data, fulfilled_promises):
k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in
_six.iteritems(binding_data.map.bindings)
}))


class HiveTask(DynamicTask):
    def _transform_for_user_output(self, outputs):
        """
        Reduce the raw task outputs to the list of Hive query strings the task produced.
        :param dict[Text,flytekit.models.common.FlyteIdlEntity] outputs:
        :rtype: list[Text]
        """
        futures = outputs.get(_sdk_constants.FUTURES_FILE_NAME)
        if not futures:
            # Nothing was recorded under the futures key, so there are no queries to report.
            return []

        # Decode every task's custom payload into a QuboleHiveJob model, keyed by the task's name.
        hive_jobs_by_name = {}
        for task_template in futures.tasks:
            task_name = task_template.id.name
            hive_job_pb = _ParseDict(task_template.custom, _qubole_pb2.QuboleHiveJob())
            hive_jobs_by_name[task_name] = _qubole_models.QuboleHiveJob.from_flyte_idl(hive_job_pb)

        # Only the task referenced by the first node is inspected.
        referenced_name = futures.nodes[0].task_node.reference_id.name
        hive_job = hive_jobs_by_name[referenced_name]
        return [hive_query.query for hive_query in hive_job.query_collection.queries]
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ max-complexity=16
[tool:pytest]
norecursedirs = common workflows spark
log_cli = true
log_cli_level = 100
log_cli_level = 20

[pep8]
max-line-length = 120
Expand Down
44 changes: 44 additions & 0 deletions tests/flytekit/unit/use_scenarios/unit_testing/hive_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import absolute_import
from flytekit.sdk.tasks import hive_task
import pytest


def test_no_queries():
    """A hive task whose body returns nothing unit tests to an empty list."""
    @hive_task
    def test_hive_task(wf_params):
        pass

    generated_queries = test_hive_task.unit_test()
    assert generated_queries == []


def test_empty_list_queries():
    """A hive task that returns an empty list unit tests to an empty list."""
    @hive_task
    def test_hive_task(wf_params):
        return []

    generated_queries = test_hive_task.unit_test()
    assert generated_queries == []


def test_one_query():
    """A hive task that returns a single query string unit tests to a one-element list."""
    @hive_task
    def test_hive_task(wf_params):
        return "abc"

    expected_queries = ["abc"]
    generated_queries = test_hive_task.unit_test()
    assert generated_queries == expected_queries


def test_multiple_queries():
    """A hive task that returns several query strings unit tests to the same list, in order."""
    @hive_task
    def test_hive_task(wf_params):
        return ["abc", "cde"]

    expected_queries = ["abc", "cde"]
    generated_queries = test_hive_task.unit_test()
    assert generated_queries == expected_queries


def test_raise_exception():
    """An exception raised inside the hive task body propagates out of unit_test()."""
    error_message = "Floating point error for some reason."

    @hive_task
    def test_hive_task(wf_params):
        raise FloatingPointError(error_message)

    with pytest.raises(FloatingPointError):
        test_hive_task.unit_test()

0 comments on commit 8274b2c

Please sign in to comment.