diff --git a/temporalio/workflow.py b/temporalio/workflow.py index ed96324a..27b3019c 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -36,7 +36,6 @@ TypeVar, Union, cast, - get_origin, overload, ) @@ -1415,10 +1414,10 @@ def _apply_to_class( if not init_fn: issues.append("Missing __init__ method") elif run_fn: - init_fn_takes_workflow_input, init_fn_issue = ( - _get_init_fn_takes_workflow_input(init_fn, run_fn) + init_fn_takes_workflow_input, init_fn_issue = _init_fn_takes_workflow_input( + init_fn, run_fn ) - if init_fn_issue: + if init_fn_issue is not None: issues.append(init_fn_issue) if issues: if len(issues) == 1: @@ -1428,6 +1427,7 @@ def _apply_to_class( ) assert run_fn + assert init_fn_takes_workflow_input is not None defn = _Definition( name=workflow_name, cls=cls, @@ -1459,79 +1459,77 @@ def __post_init__(self) -> None: object.__setattr__(self, "ret_type", ret_type) -def _get_init_fn_takes_workflow_input( +def _init_fn_takes_workflow_input( init_fn: Callable[..., None], run_fn: Callable[..., Awaitable[Any]], -) -> Tuple[bool, Optional[str]]: +) -> Union[Tuple[bool, None], Tuple[None, str]]: """ - Return (True, None) if the Workflow __init__ method accepts Workflow input. - - If __init__ has a signature that features user-defined parameters, and if - these match those of the @workflow.run method, and if no validity issue is - encountered, then return (True, None). Otherwise return False, along with - any validity issue. + Return (True, None) if Workflow input args should be passed to Workflow __init__. """ - init_params = inspect.signature(init_fn).parameters - if init_params == inspect.signature(object.__init__).parameters: - # No user __init__, or user __init__ with (self, *args, **kwargs) signature - return False, None - elif set(init_params) == {"self"}: - # User __init__ with (self) signature + # If the workflow class can be instantiated as cls(), i.e. without passing + # workflow input args, then we do that. Otherwise, if the parameters of + # __init__ exactly match those of the @workflow.run method, then we pass the + # workflow input args to __init__ when instantiating the workflow class. + # Otherwise, the workflow definition is invalid. + + if _unbound_method_can_be_called_without_args_when_bound(init_fn): + # The workflow cls can be instantiated as cls() return False, None else: - # This (in 3.12) differs from the default signature inherited from - # object.__init__, which is (self, /, *args, **kwargs). We allow both as - # valid __init__ signatures (that do not take workflow input). - class StarArgsStarStarKwargs: - def __init__(self, *args, **kwargs) -> None: - pass - - if init_params == inspect.signature(StarArgsStarStarKwargs.__init__).parameters: - return False, None - - init_param_types, _ = temporalio.common._type_hints_from_func(init_fn) - run_param_types, _ = temporalio.common._type_hints_from_func(run_fn) - if init_param_types and run_param_types: - return _get_init_fn_takes_workflow_input_from_type_annotations( - init_param_types, run_param_types - ) - run_params = inspect.signature(run_fn).parameters - if len(init_params) != len(run_params): - return ( - False, - ( - f"Number of __init__ method parameters ({len(init_params)}) must equal " - f"number of @workflow.run method parameters ({len(run_params)})" - ), - ) - # TODO: positionals / kwargs - return True, None - - -def _get_init_fn_takes_workflow_input_from_type_annotations( - init_param_types: List[Type], - run_param_types: List[Type], -) -> Tuple[bool, Optional[str]]: - if len(init_param_types) != len(run_param_types): - return ( - False, - ( - f"Number of __init__ method parameters ({len(init_param_types)}) must equal " - f"number of @workflow.run method parameters ({len(run_param_types)})" - ), - ) - else: - for t1, t2 in zip(init_param_types, run_param_types): - # Just check that the types are the same in a naive sense; do not - # support any notion of subtype compatibility for now. - if get_origin(t1) != get_origin(t2): - return ( - False, - f"__init__ method param type {t1} does not match corresponding @workflow.run method param type {t2}", - ) - # TODO: positionals / kwargs - return True, None + def get_params(fn: Callable) -> List[inspect.Parameter]: + # Ignore name when comparing parameters (remaining fields are kind, + # default, and annotation). + return [ + p.replace(name="x") for p in inspect.signature(fn).parameters.values() + ] + + # We require that any type annotations present match exactly; i.e. we do + # not support any notion of subtype compatibility. + if get_params(init_fn) == get_params(run_fn): + # __init__ requires some args and has the same parameters as @workflow.run + return True, None + else: + return ( + None, + "__init__ parameters do not match @workflow.run method parameters", + ) + + +def _unbound_method_can_be_called_without_args_when_bound( + fn: Callable[..., Any], +) -> bool: + """ + Return True if the unbound method fn can be called without arguments when bound. + """ + # An unbound method can be called without arguments when bound if the + # following are both true: + # + # - The first parameter is POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD (this is + # the parameter conventionally named 'self') + # + # - All other POSITIONAL_OR_KEYWORD or KEYWORD_ONLY parameters have default + # values. + params = iter(inspect.signature(fn).parameters.values()) + self_param = next(params, None) + if not self_param or self_param.kind not in [ + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + ]: + raise ValueError("Not an unbound method. This is a Python SDK bug.") + for p in params: + if p.kind == inspect.Parameter.POSITIONAL_ONLY: + return False + elif ( + p.kind + in [ + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ] + and p.default is inspect.Parameter.empty + ): + return False + return True # Async safe version of partial diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 9d32e6f8..554c11a9 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Protocol, Sequence +from typing import Any, Callable, List, Protocol, Sequence import pytest @@ -351,19 +351,68 @@ def test_workflow_defn_dynamic_handler_warnings(): # -class WorkflowInitGoodNoInitZeroParamRun: +class CanBeCalledWithoutArgs: + def a(self): + pass + + def b(self, arg=1): + pass + + def c(self, /, arg=1): + pass + + def d(self=1): + pass + + def e(self=1, /, arg1=1, *, arg2=2): + pass + + +class CannotBeCalledWithoutArgs: + def a(self, arg): + pass + + def b(self, /, arg): + pass + + def c(self, arg1, arg2=2): + pass + + +@pytest.mark.parametrize( + "fn,expected", + [ + (CanBeCalledWithoutArgs.a, True), + (CanBeCalledWithoutArgs.b, True), + (CanBeCalledWithoutArgs.c, True), + (CanBeCalledWithoutArgs.d, True), + (CanBeCalledWithoutArgs.e, True), + (CannotBeCalledWithoutArgs.a, False), + (CannotBeCalledWithoutArgs.b, False), + (CannotBeCalledWithoutArgs.c, False), + ], +) +def test_unbound_method_can_be_called_without_args_when_bound( + fn: Callable[..., Any], expected: bool +): + assert ( + workflow._unbound_method_can_be_called_without_args_when_bound(fn) == expected + ) + + +class NormalInitNoInit: @workflow.run async def run(self) -> None: pass -class WorkflowInitGoodNoInitOneParamRun: +class NormalInitNoInitOneParamRun: @workflow.run async def run(self, a: int) -> None: pass -class WorkflowInitGoodNoArgInitZeroParamRun: +class NormalInitNoArgInitZeroParamRun: def __init__(self) -> None: pass @@ -372,7 +421,7 @@ async def run(self) -> None: pass -class WorkflowInitGoodNoArgInitOneParamRun: +class NormalInitNoArgInitOneParamRun: def __init__(self) -> None: pass @@ -387,13 +436,13 @@ class MyDataClass: b: str -class DataClassTypedWorkflowProto(Protocol): +class NormalInitDataClassProtocol(Protocol): @workflow.run async def run(self, arg: MyDataClass) -> None: pass -class WorkflowInitGoodSlashStarArgsStarStarKwargs: +class NormalInitSlashStarArgsStarStarKwargs: def __init__(self, /, *args, **kwargs) -> None: pass @@ -402,7 +451,7 @@ async def run(self) -> None: pass -class WorkflowInitGoodStarArgsStarStarKwargs: +class NormalInitStarArgsStarStarKwargs: def __init__(self, *args, **kwargs) -> None: pass @@ -411,16 +460,36 @@ async def run(self) -> None: pass +class NormalInitStarDefault: + def __init__(self, *, arg=1) -> None: + pass + + @workflow.run + async def run(self, arg) -> None: + pass + + +class NormalInitTypedDefault: + def __init__(self, a: int = 1) -> None: + pass + + @workflow.run + async def run(self, aa: int) -> None: + pass + + @pytest.mark.parametrize( "cls", [ - WorkflowInitGoodNoInitZeroParamRun, - WorkflowInitGoodNoInitOneParamRun, - WorkflowInitGoodNoArgInitZeroParamRun, - WorkflowInitGoodNoArgInitOneParamRun, - DataClassTypedWorkflowProto, - WorkflowInitGoodSlashStarArgsStarStarKwargs, - WorkflowInitGoodStarArgsStarStarKwargs, + NormalInitNoInit, + NormalInitNoInitOneParamRun, + NormalInitNoArgInitZeroParamRun, + NormalInitNoArgInitOneParamRun, + NormalInitDataClassProtocol, + NormalInitSlashStarArgsStarStarKwargs, + NormalInitStarArgsStarStarKwargs, + NormalInitStarDefault, + NormalInitTypedDefault, ], ) def test_workflow_init_good_does_not_take_workflow_input(cls): @@ -429,7 +498,7 @@ def test_workflow_init_good_does_not_take_workflow_input(cls): ).__temporal_workflow_definition.init_fn_takes_workflow_input -class WorkflowInitGoodOneParamTyped: +class WorkflowInitOneParamTyped: def __init__(self, a: int) -> None: pass @@ -438,7 +507,7 @@ async def run(self, aa: int) -> None: pass -class WorkflowInitGoodTwoParamsTyped: +class WorkflowInitTwoParamsTyped: def __init__(self, a: int, b: str) -> None: pass @@ -447,7 +516,7 @@ async def run(self, aa: int, bb: str) -> None: pass -class WorkflowInitGoodOneParamUntyped: +class WorkflowInitOneParamUntyped: def __init__(self, a) -> None: pass @@ -456,7 +525,7 @@ async def run(self, aa) -> None: pass -class WorkflowInitGoodTwoParamsUntyped: +class WorkflowInitTwoParamsUntyped: def __init__(self, a, b) -> None: pass @@ -465,43 +534,43 @@ async def run(self, aa, bb) -> None: pass -class WorkflowInitGoodOneParamNoInitType: - def __init__(self, a) -> None: +class WorkflowInitSlashStarArgsStarStarKwargs: + def __init__(self, /, a, *args, **kwargs) -> None: pass @workflow.run - async def run(self, aa: int) -> None: + async def run(self, /, a, *args, **kwargs) -> None: pass -class WorkflowInitGoodOneParamNoRunType: - def __init__(self, a: int) -> None: +class WorkflowInitStarArgsStarStarKwargs: + def __init__(self, *args, a, **kwargs) -> None: pass @workflow.run - async def run(self, aa) -> None: + async def run(self, *args, a, **kwargs) -> None: pass -class WorkflowInitGoodTwoParamsMixedTyping: - def __init__(self, a, b: str) -> None: +class WorkflowInitStarDefault: + def __init__(self, a, *, arg=1) -> None: pass @workflow.run - async def run(self, aa: str, bb) -> None: + async def run(self, a, *, arg=1) -> None: pass @pytest.mark.parametrize( "cls", [ - WorkflowInitGoodOneParamTyped, - WorkflowInitGoodTwoParamsTyped, - WorkflowInitGoodOneParamUntyped, - WorkflowInitGoodTwoParamsUntyped, - WorkflowInitGoodTwoParamsMixedTyping, - WorkflowInitGoodOneParamNoInitType, - WorkflowInitGoodOneParamNoRunType, + WorkflowInitOneParamTyped, + WorkflowInitTwoParamsTyped, + WorkflowInitOneParamUntyped, + WorkflowInitTwoParamsUntyped, + WorkflowInitSlashStarArgsStarStarKwargs, + WorkflowInitStarArgsStarStarKwargs, + WorkflowInitStarDefault, ], ) def test_workflow_init_good_takes_workflow_input(cls): @@ -546,6 +615,62 @@ async def run(self, aa: int, bb: str) -> None: pass +class WorkflowInitBadOneParamNoInitType: + def __init__(self, a) -> None: + pass + + @workflow.run + async def run(self, aa: int) -> None: + pass + + +class WorkflowInitBadGenericSubtype: + # The types must match exactly; we do not support any notion of subtype + # compatibility. + def __init__(self, a: List) -> None: + pass + + @workflow.run + async def run(self, aa: List[int]) -> None: + pass + + +class WorkflowInitBadOneParamNoRunType: + def __init__(self, a: int) -> None: + pass + + @workflow.run + async def run(self, aa) -> None: + pass + + +class WorkflowInitBadMissingDefault: + def __init__(self, a: int, b: int) -> None: + pass + + @workflow.run + async def run(self, aa: int, b: int = 1) -> None: + pass + + +class WorkflowInitBadInconsistentDefaults: + def __init__(self, a: int, b: int = 1) -> None: + pass + + @workflow.run + async def run(self, aa: int, b: int = 2) -> None: + pass + + +class WorkflowInitBadTwoParamsMixedTyping: + def __init__(self, a, b: str) -> None: + pass + + @workflow.run + async def run(self, aa: str, bb) -> None: + pass + + @pytest.mark.parametrize( "cls", [ @@ -553,9 +678,14 @@ async def run(self, aa: int, bb: str) -> None: WorkflowInitBadMismatchedParamUntyped, WorkflowInitBadExtraInitParamTyped, WorkflowInitBadMismatchedParamTyped, + WorkflowInitBadTwoParamsMixedTyping, + WorkflowInitBadOneParamNoInitType, + WorkflowInitBadOneParamNoRunType, + WorkflowInitBadGenericSubtype, + WorkflowInitBadMissingDefault, + WorkflowInitBadInconsistentDefaults, ], ) def test_workflow_init_bad_takes_workflow_input(cls): - with pytest.raises(ValueError) as err: + with pytest.raises(ValueError): workflow.defn(cls) - print(err) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 7bba0083..4d9b37d3 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -6053,3 +6053,44 @@ async def test_workflow_init(client: Client): task_queue=worker.task_queue, ) assert workflow_result == "hello, world" + + +@workflow.defn +class WorkflowInitUpdateInFirstWFTWorkflow: + def __init__(self, arg: str) -> None: + self.init_arg = arg + + @workflow.update + async def my_update(self) -> str: + return self.init_arg + + @workflow.run + async def run(self, _: str): + self.init_arg = "value set in run method" + return self.init_arg + + +async def test_update_in_first_wft_sees_workflow_init(client: Client): + # Before running the worker, start a workflow, send the update, and wait + # until update is admitted. + task_queue = "task-queue" + update_id = "update-id" + wf_handle = await client.start_workflow( + WorkflowInitUpdateInFirstWFTWorkflow.run, + "workflow input value", + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + update_task = asyncio.create_task( + wf_handle.execute_update( + WorkflowInitUpdateInFirstWFTWorkflow.my_update, id=update_id + ) + ) + await assert_eq_eventually( + True, lambda: workflow_update_exists(client, wf_handle.id, update_id) + ) + async with new_worker( + client, WorkflowInitUpdateInFirstWFTWorkflow, task_queue=task_queue + ): + assert await update_task == "workflow input value" + assert await wf_handle.result() == "value set in run method"