From b677bdac5260195fe74038be786167c658b3ecd2 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 28 Aug 2024 13:29:04 -0400 Subject: [PATCH 01/17] Refactor to support workflow init --- temporalio/worker/_workflow_instance.py | 54 +++++++++++++++++-------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f897e02e..af0ddd56 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -201,6 +201,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._payload_converter = det.payload_converter_class() self._failure_converter = det.failure_converter_class() self._defn = det.defn + self._workflow_input: Optional[ExecuteWorkflowInput] = None self._info = det.info self._extern_functions = det.extern_functions self._disable_eager_activity_execution = det.disable_eager_activity_execution @@ -318,8 +319,9 @@ def get_thread_id(self) -> Optional[int]: return self._current_thread_id #### Activation functions #### - # These are in alphabetical order and besides "activate", all other calls - # are "_apply_" + the job field name. + # These are in alphabetical order and besides "activate", and + # "_make_workflow_input", all other calls are "_apply_" + the job field + # name. def activate( self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation @@ -342,6 +344,7 @@ def activate( try: # Split into job sets with patches, then signals + updates, then # non-queries, then queries + start_job = None job_sets: List[ List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob] ] = [[], [], [], []] @@ -351,10 +354,15 @@ def activate( elif job.HasField("signal_workflow") or job.HasField("do_update"): job_sets[1].append(job) elif not job.HasField("query_workflow"): + if job.HasField("start_workflow"): + start_job = job.start_workflow job_sets[2].append(job) else: job_sets[3].append(job) + if start_job: + self._workflow_input = self._make_workflow_input(start_job) + # Apply every job set, running after each set for index, job_set in enumerate(job_sets): if not job_set: @@ -863,34 +871,41 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None: return raise + if not self._workflow_input: + raise RuntimeError( + "Expected workflow input to be set. This is an SDK Python bug." + ) + self._primary_task = self.create_task( + self._run_top_level_workflow_function(run_workflow(self._workflow_input)), + name="run", + ) + + def _apply_update_random_seed( + self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed + ) -> None: + self._random.seed(job.randomness_seed) + + def _make_workflow_input( + self, start_job: temporalio.bridge.proto.workflow_activation.StartWorkflow + ) -> ExecuteWorkflowInput: # Set arg types, using raw values for dynamic arg_types = self._defn.arg_types if not self._defn.name: # Dynamic is just the raw value for each input value - arg_types = [temporalio.common.RawValue] * len(job.arguments) - args = self._convert_payloads(job.arguments, arg_types) + arg_types = [temporalio.common.RawValue] * len(start_job.arguments) + args = self._convert_payloads(start_job.arguments, arg_types) # Put args in a list if dynamic if not self._defn.name: args = [args] - # Schedule it - input = ExecuteWorkflowInput( + return ExecuteWorkflowInput( type=self._defn.cls, # TODO(cretz): Remove cast when https://github.com/python/mypy/issues/5485 fixed run_fn=cast(Callable[..., Awaitable[Any]], self._defn.run_fn), args=args, - headers=job.headers, - ) - self._primary_task = self.create_task( - self._run_top_level_workflow_function(run_workflow(input)), - name="run", + headers=start_job.headers, ) - def _apply_update_random_seed( - self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed - ) -> None: - self._random.seed(job.randomness_seed) - #### _Runtime direct workflow call overrides #### # These are in alphabetical order and all start with "workflow_". @@ -1617,6 +1632,11 @@ def _convert_payloads( except Exception as err: raise RuntimeError("Failed decoding arguments") from err + def _instantiate_workflow_object(self) -> Any: + if not self._workflow_input: + raise RuntimeError("Expected workflow input. This is a Python SDK bug.") + return self._defn.cls() + def _is_workflow_failure_exception(self, err: BaseException) -> bool: # An exception is a failure instead of a task fail if it's already a # failure error or if it is an instance of any of the failure types in @@ -1752,7 +1772,7 @@ def _run_once(self, *, check_conditions: bool) -> None: # We instantiate the workflow class _inside_ here because __init__ # needs to run with this event loop set if not self._object: - self._object = self._defn.cls() + self._object = self._instantiate_workflow_object() # Run while there is anything ready while self._ready: From 0b335935577e45a08c00f75951544191de5c9349 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Sat, 24 Aug 2024 13:51:20 -0400 Subject: [PATCH 02/17] Implement workflow __init__ support --- temporalio/worker/_workflow_instance.py | 5 +- temporalio/workflow.py | 92 +++++++++- tests/test_workflow.py | 225 +++++++++++++++++++++++- tests/worker/test_workflow.py | 21 +++ 4 files changed, 335 insertions(+), 8 deletions(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index af0ddd56..b7696858 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1635,7 +1635,10 @@ def _convert_payloads( def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: raise RuntimeError("Expected workflow input. This is a Python SDK bug.") - return self._defn.cls() + if self._defn.init_fn_takes_workflow_input: + return self._defn.cls(*self._workflow_input.args) + else: + return self._defn.cls() def _is_workflow_failure_exception(self, err: BaseException) -> bool: # An exception is a failure instead of a task fail if it's already a diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 6ee25ad2..fdff99d9 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -36,6 +36,7 @@ TypeVar, Union, cast, + get_origin, overload, ) @@ -1252,6 +1253,7 @@ class _Definition: name: Optional[str] cls: Type run_fn: Callable[..., Awaitable] + init_fn_takes_workflow_input: bool signals: Mapping[Optional[str], _SignalDefinition] queries: Mapping[Optional[str], _QueryDefinition] updates: Mapping[Optional[str], _UpdateDefinition] @@ -1306,14 +1308,14 @@ def _apply_to_class( raise ValueError("Class already contains workflow definition") issues: List[str] = [] - # Collect run fn and all signal/query/update fns - members = inspect.getmembers(cls) + # Collect init, run, and all signal/query/update fns + init_fn: Optional[Callable[..., None]] = None run_fn: Optional[Callable[..., Awaitable[Any]]] = None seen_run_attr = False signals: Dict[Optional[str], _SignalDefinition] = {} queries: Dict[Optional[str], _QueryDefinition] = {} updates: Dict[Optional[str], _UpdateDefinition] = {} - for name, member in members: + for name, member in inspect.getmembers(cls): if hasattr(member, "__temporal_workflow_run"): seen_run_attr = True if not _is_unbound_method_on_cls(member, cls): @@ -1354,6 +1356,8 @@ def _apply_to_class( ) else: queries[query_defn.name] = query_defn + elif name == "__init__": + init_fn = member elif isinstance(member, UpdateMethodMultiParam): update_defn = member._defn if update_defn.name in updates: @@ -1404,11 +1408,21 @@ def _apply_to_class( f"@workflow.update defined on {base_member.__qualname__} but not on the override" ) + init_fn_takes_workflow_input = False + if not seen_run_attr: issues.append("Missing @workflow.run method") - if len(issues) == 1: - raise ValueError(f"Invalid workflow class: {issues[0]}") - elif issues: + if not init_fn: + issues.append("Missing __init__ method") + elif run_fn: + init_fn_takes_workflow_input, init_fn_issue = ( + _get_init_fn_takes_workflow_input(init_fn, run_fn) + ) + if init_fn_issue: + issues.append(init_fn_issue) + if issues: + if len(issues) == 1: + raise ValueError(f"Invalid workflow class: {issues[0]}") raise ValueError( f"Invalid workflow class for {len(issues)} reasons: {', '.join(issues)}" ) @@ -1418,6 +1432,7 @@ def _apply_to_class( name=workflow_name, cls=cls, run_fn=run_fn, + init_fn_takes_workflow_input=init_fn_takes_workflow_input, signals=signals, queries=queries, updates=updates, @@ -1444,6 +1459,71 @@ def __post_init__(self) -> None: object.__setattr__(self, "ret_type", ret_type) +def _get_init_fn_takes_workflow_input( + init_fn: Callable[..., None], + run_fn: Callable[..., Awaitable[Any]], +) -> Tuple[bool, Optional[str]]: + """ + Return (True, None) if the Workflow __init__ method accepts Workflow input. + + If __init__ has a signature that features user-defined parameters, and if + these match those of the @workflow.run method, and if no validity issue is + encountered, then return (True, None). Otherwise return False, along with + any validity issue. + """ + init_params = inspect.signature(init_fn).parameters + if init_params == inspect.signature(object.__init__).parameters: + # No user __init__, or user __init__ with (self, *args, **kwargs) signature + return False, None + elif set(init_params) == {"self"}: + # User __init__ with (self) signature + return False, None + + init_param_types, _ = temporalio.common._type_hints_from_func(init_fn) + run_param_types, _ = temporalio.common._type_hints_from_func(run_fn) + if init_param_types and run_param_types: + return _get_init_fn_takes_workflow_input_from_type_annotations( + init_param_types, run_param_types + ) + + run_params = inspect.signature(run_fn).parameters + if len(init_params) != len(run_params): + return ( + False, + ( + f"Number of __init__ method parameters ({len(init_params)}) must equal " + f"number of @workflow.run method parameters ({len(run_params)})" + ), + ) + # TODO: positionals / kwargs + return True, None + + +def _get_init_fn_takes_workflow_input_from_type_annotations( + init_param_types: List[Type], + run_param_types: List[Type], +) -> Tuple[bool, Optional[str]]: + if len(init_param_types) != len(run_param_types): + return ( + False, + ( + f"Number of __init__ method parameters ({len(init_param_types)}) must equal " + f"number of @workflow.run method parameters ({len(run_param_types)})" + ), + ) + else: + for t1, t2 in zip(init_param_types, run_param_types): + # Just check that the types are the same in a naive sense; do not + # support any notion of subtype compatibility for now. + if get_origin(t1) != get_origin(t2): + return ( + False, + f"__init__ method param type {t1} does not match corresponding @workflow.run method param type {t2}", + ) + # TODO: positionals / kwargs + return True, None + + # Async safe version of partial def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]: # Curry instance on the definition function since that represents an diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 510fa18d..2082d7df 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,4 +1,5 @@ -from typing import Any, Sequence +from dataclasses import dataclass +from typing import Protocol, Sequence import pytest @@ -75,6 +76,7 @@ def test_workflow_defn_good(): name="workflow-custom", cls=GoodDefn, run_fn=GoodDefn.run, + init_fn_takes_workflow_input=False, signals={ "signal1": workflow._SignalDefinition( name="signal1", fn=GoodDefn.signal1, is_method=True @@ -342,3 +344,224 @@ def test_workflow_defn_dynamic_handler_warnings(): # We want to make sure they are reporting the right stacklevel warnings[0].filename.endswith("test_workflow.py") warnings[1].filename.endswith("test_workflow.py") + + +# +# workflow init tests +# + + +class WorkflowInitGoodNoInitZeroParamRun: + @workflow.run + async def run(self) -> None: + pass + + +class WorkflowInitGoodNoInitOneParamRun: + @workflow.run + async def run(self, a: int) -> None: + pass + + +class WorkflowInitGoodNoArgInitZeroParamRun: + def __init__(self) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + +class WorkflowInitGoodNoArgInitOneParamRun: + def __init__(self) -> None: + pass + + @workflow.run + async def run(self, a: int) -> None: + pass + + +@dataclass +class MyDataClass: + a: int + b: str + + +class DataClassTypedWorkflowProto(Protocol): + @workflow.run + async def run(self, arg: MyDataClass) -> None: + pass + + +class WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun: + # TODO: if they include the slash it will be allowed as it is + # indistinguishable from the default __init__ inherited from object. But if + # they don't it will not be. + def __init__(self, /, *args, **kwargs) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + +@pytest.mark.parametrize( + "cls", + [ + WorkflowInitGoodNoInitZeroParamRun, + WorkflowInitGoodNoInitOneParamRun, + WorkflowInitGoodNoArgInitZeroParamRun, + WorkflowInitGoodNoArgInitOneParamRun, + DataClassTypedWorkflowProto, + WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun, + ], +) +def test_workflow_init_good_does_not_take_workflow_input(cls): + assert not workflow.defn( + cls + ).__temporal_workflow_definition.init_fn_takes_workflow_input + + +class WorkflowInitGoodOneParamTyped: + def __init__(self, a: int) -> None: + pass + + @workflow.run + async def run(self, aa: int) -> None: + pass + + +class WorkflowInitGoodTwoParamsTyped: + def __init__(self, a: int, b: str) -> None: + pass + + @workflow.run + async def run(self, aa: int, bb: str) -> None: + pass + + +class WorkflowInitGoodOneParamUntyped: + def __init__(self, a) -> None: + pass + + @workflow.run + async def run(self, aa) -> None: + pass + + +class WorkflowInitGoodTwoParamsUntyped: + def __init__(self, a, b) -> None: + pass + + @workflow.run + async def run(self, aa, bb) -> None: + pass + + +class WorkflowInitGoodOneParamNoInitType: + def __init__(self, a) -> None: + pass + + @workflow.run + async def run(self, aa: int) -> None: + pass + + +class WorkflowInitGoodOneParamNoRunType: + def __init__(self, a: int) -> None: + pass + + @workflow.run + async def run(self, aa) -> None: + pass + + +class WorkflowInitGoodTwoParamsMixedTyping: + def __init__(self, a, b: str) -> None: + pass + + @workflow.run + async def run(self, aa: str, bb) -> None: + pass + + +@pytest.mark.parametrize( + "cls", + [ + WorkflowInitGoodOneParamTyped, + WorkflowInitGoodTwoParamsTyped, + WorkflowInitGoodOneParamUntyped, + WorkflowInitGoodTwoParamsUntyped, + WorkflowInitGoodTwoParamsMixedTyping, + WorkflowInitGoodOneParamNoInitType, + WorkflowInitGoodOneParamNoRunType, + ], +) +def test_workflow_init_good_takes_workflow_input(cls): + assert workflow.defn( + cls + ).__temporal_workflow_definition.init_fn_takes_workflow_input + + +class WorkflowInitBadStarArgsStarStarKwargs: + # TODO: if they include the slash it will be allowed as it is + # indistinguishable from the default __init__ inherited from object. But if + # they don't it will not be. + def __init__(self, *args, **kwargs) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + +class WorkflowInitBadExtraInitParamUntyped: + def __init__(self, a) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + +class WorkflowInitBadMismatchedParamUntyped: + def __init__(self, a) -> None: + pass + + @workflow.run + async def run(self, aa, bb) -> None: + pass + + +class WorkflowInitBadExtraInitParamTyped: + def __init__(self, a: int) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + +class WorkflowInitBadMismatchedParamTyped: + def __init__(self, a: int) -> None: + pass + + @workflow.run + async def run(self, aa: int, bb: str) -> None: + pass + + +@pytest.mark.parametrize( + "cls", + [ + WorkflowInitBadStarArgsStarStarKwargs, + WorkflowInitBadExtraInitParamUntyped, + WorkflowInitBadMismatchedParamUntyped, + WorkflowInitBadExtraInitParamTyped, + WorkflowInitBadMismatchedParamTyped, + ], +) +def test_workflow_init_bad_takes_workflow_input(cls): + with pytest.raises(ValueError) as err: + workflow.defn(cls) + print(err) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 9f007a22..2665c653 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -6032,3 +6032,24 @@ async def test_activity_retry_delay(client: Client): err.cause.cause.next_retry_delay == ActivitiesWithRetryDelayWorkflow.next_retry_delay ) + + +@workflow.defn +class WorkflowInitWorkflow: + def __init__(self, arg: str) -> None: + self.init_arg = arg + + @workflow.run + async def run(self, _: str): + return f"hello, {self.init_arg}" + + +async def test_workflow_init(client: Client): + async with new_worker(client, WorkflowInitWorkflow) as worker: + workflow_result = await client.execute_workflow( + WorkflowInitWorkflow.run, + "world", + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + assert workflow_result == "hello, world" From a68911e1fca29a4b6ccba3b5c12cfd05e934f462 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 2 Sep 2024 21:38:49 -0400 Subject: [PATCH 03/17] Handle non-slash form __init__(self, *args, **kwargs) --- temporalio/workflow.py | 10 ++++++++++ tests/test_workflow.py | 30 ++++++++++++------------------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index fdff99d9..83e25ca4 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1478,6 +1478,16 @@ def _get_init_fn_takes_workflow_input( elif set(init_params) == {"self"}: # User __init__ with (self) signature return False, None + else: + # This (in 3.12) differs from the default signature inherited from + # object.__init__, which is (self, /, *args, **kwargs). We allow both as + # valid __init__ signatures (that do not take workflow input). + class StarArgsStarStarKwargs: + def __init__(self, *args, **kwargs) -> None: + pass + + if init_params == inspect.signature(StarArgsStarStarKwargs.__init__).parameters: + return False, None init_param_types, _ = temporalio.common._type_hints_from_func(init_fn) run_param_types, _ = temporalio.common._type_hints_from_func(run_fn) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 2082d7df..9d32e6f8 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -393,10 +393,7 @@ async def run(self, arg: MyDataClass) -> None: pass -class WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun: - # TODO: if they include the slash it will be allowed as it is - # indistinguishable from the default __init__ inherited from object. But if - # they don't it will not be. +class WorkflowInitGoodSlashStarArgsStarStarKwargs: def __init__(self, /, *args, **kwargs) -> None: pass @@ -405,6 +402,15 @@ async def run(self) -> None: pass +class WorkflowInitGoodStarArgsStarStarKwargs: + def __init__(self, *args, **kwargs) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + @pytest.mark.parametrize( "cls", [ @@ -413,7 +419,8 @@ async def run(self) -> None: WorkflowInitGoodNoArgInitZeroParamRun, WorkflowInitGoodNoArgInitOneParamRun, DataClassTypedWorkflowProto, - WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun, + WorkflowInitGoodSlashStarArgsStarStarKwargs, + WorkflowInitGoodStarArgsStarStarKwargs, ], ) def test_workflow_init_good_does_not_take_workflow_input(cls): @@ -503,18 +510,6 @@ def test_workflow_init_good_takes_workflow_input(cls): ).__temporal_workflow_definition.init_fn_takes_workflow_input -class WorkflowInitBadStarArgsStarStarKwargs: - # TODO: if they include the slash it will be allowed as it is - # indistinguishable from the default __init__ inherited from object. But if - # they don't it will not be. - def __init__(self, *args, **kwargs) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - class WorkflowInitBadExtraInitParamUntyped: def __init__(self, a) -> None: pass @@ -554,7 +549,6 @@ async def run(self, aa: int, bb: str) -> None: @pytest.mark.parametrize( "cls", [ - WorkflowInitBadStarArgsStarStarKwargs, WorkflowInitBadExtraInitParamUntyped, WorkflowInitBadMismatchedParamUntyped, WorkflowInitBadExtraInitParamTyped, From 3573c0b20453fe869f2056639446df61859e5654 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 4 Sep 2024 21:40:59 -0400 Subject: [PATCH 04/17] Simplify and fix implementation --- temporalio/workflow.py | 138 +++++++++++------------ tests/test_workflow.py | 204 ++++++++++++++++++++++++++++------ tests/worker/test_workflow.py | 41 +++++++ 3 files changed, 274 insertions(+), 109 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 83e25ca4..584c245f 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -36,7 +36,6 @@ TypeVar, Union, cast, - get_origin, overload, ) @@ -1415,10 +1414,10 @@ def _apply_to_class( if not init_fn: issues.append("Missing __init__ method") elif run_fn: - init_fn_takes_workflow_input, init_fn_issue = ( - _get_init_fn_takes_workflow_input(init_fn, run_fn) + init_fn_takes_workflow_input, init_fn_issue = _init_fn_takes_workflow_input( # type: ignore[assignment] + init_fn, run_fn ) - if init_fn_issue: + if init_fn_issue is not None: issues.append(init_fn_issue) if issues: if len(issues) == 1: @@ -1428,6 +1427,7 @@ def _apply_to_class( ) assert run_fn + assert init_fn_takes_workflow_input is not None defn = _Definition( name=workflow_name, cls=cls, @@ -1459,79 +1459,73 @@ def __post_init__(self) -> None: object.__setattr__(self, "ret_type", ret_type) -def _get_init_fn_takes_workflow_input( +def _init_fn_takes_workflow_input( init_fn: Callable[..., None], run_fn: Callable[..., Awaitable[Any]], -) -> Tuple[bool, Optional[str]]: - """ - Return (True, None) if the Workflow __init__ method accepts Workflow input. - - If __init__ has a signature that features user-defined parameters, and if - these match those of the @workflow.run method, and if no validity issue is - encountered, then return (True, None). Otherwise return False, along with - any validity issue. - """ - init_params = inspect.signature(init_fn).parameters - if init_params == inspect.signature(object.__init__).parameters: - # No user __init__, or user __init__ with (self, *args, **kwargs) signature - return False, None - elif set(init_params) == {"self"}: - # User __init__ with (self) signature +) -> Union[Tuple[bool, None], Tuple[None, str]]: + """Return (True, None) if Workflow input args should be passed to Workflow __init__.""" + # If the workflow class can be instantiated as cls(), i.e. without passing + # workflow input args, then we do that. Otherwise, if the parameters of + # __init__ exactly match those of the @workflow.run method, then we pass the + # workflow input args to __init__ when instantiating the workflow class. + # Otherwise, the workflow definition is invalid. + + if _unbound_method_can_be_called_without_args_when_bound(init_fn): + # The workflow cls can be instantiated as cls() return False, None else: - # This (in 3.12) differs from the default signature inherited from - # object.__init__, which is (self, /, *args, **kwargs). We allow both as - # valid __init__ signatures (that do not take workflow input). - class StarArgsStarStarKwargs: - def __init__(self, *args, **kwargs) -> None: - pass - - if init_params == inspect.signature(StarArgsStarStarKwargs.__init__).parameters: - return False, None - - init_param_types, _ = temporalio.common._type_hints_from_func(init_fn) - run_param_types, _ = temporalio.common._type_hints_from_func(run_fn) - if init_param_types and run_param_types: - return _get_init_fn_takes_workflow_input_from_type_annotations( - init_param_types, run_param_types - ) - run_params = inspect.signature(run_fn).parameters - if len(init_params) != len(run_params): - return ( - False, - ( - f"Number of __init__ method parameters ({len(init_params)}) must equal " - f"number of @workflow.run method parameters ({len(run_params)})" - ), - ) - # TODO: positionals / kwargs - return True, None - - -def _get_init_fn_takes_workflow_input_from_type_annotations( - init_param_types: List[Type], - run_param_types: List[Type], -) -> Tuple[bool, Optional[str]]: - if len(init_param_types) != len(run_param_types): - return ( - False, - ( - f"Number of __init__ method parameters ({len(init_param_types)}) must equal " - f"number of @workflow.run method parameters ({len(run_param_types)})" - ), - ) - else: - for t1, t2 in zip(init_param_types, run_param_types): - # Just check that the types are the same in a naive sense; do not - # support any notion of subtype compatibility for now. - if get_origin(t1) != get_origin(t2): - return ( - False, - f"__init__ method param type {t1} does not match corresponding @workflow.run method param type {t2}", - ) - # TODO: positionals / kwargs - return True, None + def get_params(fn: Callable) -> List[inspect.Parameter]: + # Ignore name when comparing parameters (remaining fields are kind, + # default, and annotation). + return [ + p.replace(name="x") for p in inspect.signature(fn).parameters.values() + ] + + # We require that any type annotations present match exactly; i.e. we do + # not support any notion of subtype compatibility. + if get_params(init_fn) == get_params(run_fn): + # __init__ requires some args and has the same parameters as @workflow.run + return True, None + else: + return ( + None, + "__init__ parameters do not match @workflow.run method parameters", + ) + + +def _unbound_method_can_be_called_without_args_when_bound( + fn: Callable[..., Any], +) -> bool: + """Return True if the unbound method fn can be called without arguments when bound.""" + # An unbound method can be called without arguments when bound if the + # following are both true: + # + # - The first parameter is POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD (this is + # the parameter conventionally named 'self') + # + # - All other POSITIONAL_OR_KEYWORD or KEYWORD_ONLY parameters have default + # values. + params = iter(inspect.signature(fn).parameters.values()) + self_param = next(params, None) + if not self_param or self_param.kind not in [ + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + ]: + raise ValueError("Not an unbound method. This is a Python SDK bug.") + for p in params: + if p.kind == inspect.Parameter.POSITIONAL_ONLY: + return False + elif ( + p.kind + in [ + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ] + and p.default is inspect.Parameter.empty + ): + return False + return True # Async safe version of partial diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 9d32e6f8..86bd49ff 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Protocol, Sequence +from typing import Any, Callable, List, Protocol, Sequence import pytest @@ -351,19 +351,68 @@ def test_workflow_defn_dynamic_handler_warnings(): # -class WorkflowInitGoodNoInitZeroParamRun: +class CanBeCalledWithoutArgs: + def a(self): + pass + + def b(self, arg=1): + pass + + def c(self, /, arg=1): + pass + + def d(self=1): # type: ignore + pass + + def e(self=1, /, arg1=1, *, arg2=2): # type: ignore + pass + + +class CannotBeCalledWithoutArgs: + def a(self, arg): + pass + + def b(self, /, arg): + pass + + def c(self, arg1, arg2=2): + pass + + +@pytest.mark.parametrize( + "fn,expected", + [ + (CanBeCalledWithoutArgs.a, True), + (CanBeCalledWithoutArgs.b, True), + (CanBeCalledWithoutArgs.c, True), + (CanBeCalledWithoutArgs.d, True), + (CanBeCalledWithoutArgs.e, True), + (CannotBeCalledWithoutArgs.a, False), + (CannotBeCalledWithoutArgs.b, False), + (CannotBeCalledWithoutArgs.c, False), + ], +) +def test_unbound_method_can_be_called_without_args_when_bound( + fn: Callable[..., Any], expected: bool +): + assert ( + workflow._unbound_method_can_be_called_without_args_when_bound(fn) == expected + ) + + +class NormalInitNoInit: @workflow.run async def run(self) -> None: pass -class WorkflowInitGoodNoInitOneParamRun: +class NormalInitNoInitOneParamRun: @workflow.run async def run(self, a: int) -> None: pass -class WorkflowInitGoodNoArgInitZeroParamRun: +class NormalInitNoArgInitZeroParamRun: def __init__(self) -> None: pass @@ -372,7 +421,7 @@ async def run(self) -> None: pass -class WorkflowInitGoodNoArgInitOneParamRun: +class NormalInitNoArgInitOneParamRun: def __init__(self) -> None: pass @@ -387,13 +436,13 @@ class MyDataClass: b: str -class DataClassTypedWorkflowProto(Protocol): +class NormalInitDataClassProtocol(Protocol): @workflow.run async def run(self, arg: MyDataClass) -> None: pass -class WorkflowInitGoodSlashStarArgsStarStarKwargs: +class NormalInitSlashStarArgsStarStarKwargs: def __init__(self, /, *args, **kwargs) -> None: pass @@ -402,7 +451,7 @@ async def run(self) -> None: pass -class WorkflowInitGoodStarArgsStarStarKwargs: +class NormalInitStarArgsStarStarKwargs: def __init__(self, *args, **kwargs) -> None: pass @@ -411,16 +460,36 @@ async def run(self) -> None: pass +class NormalInitStarDefault: + def __init__(self, *, arg=1) -> None: + pass + + @workflow.run + async def run(self, arg) -> None: + pass + + +class NormalInitTypedDefault: + def __init__(self, a: int = 1) -> None: + pass + + @workflow.run + async def run(self, aa: int) -> None: + pass + + @pytest.mark.parametrize( "cls", [ - WorkflowInitGoodNoInitZeroParamRun, - WorkflowInitGoodNoInitOneParamRun, - WorkflowInitGoodNoArgInitZeroParamRun, - WorkflowInitGoodNoArgInitOneParamRun, - DataClassTypedWorkflowProto, - WorkflowInitGoodSlashStarArgsStarStarKwargs, - WorkflowInitGoodStarArgsStarStarKwargs, + NormalInitNoInit, + NormalInitNoInitOneParamRun, + NormalInitNoArgInitZeroParamRun, + NormalInitNoArgInitOneParamRun, + NormalInitDataClassProtocol, + NormalInitSlashStarArgsStarStarKwargs, + NormalInitStarArgsStarStarKwargs, + NormalInitStarDefault, + NormalInitTypedDefault, ], ) def test_workflow_init_good_does_not_take_workflow_input(cls): @@ -429,7 +498,7 @@ def test_workflow_init_good_does_not_take_workflow_input(cls): ).__temporal_workflow_definition.init_fn_takes_workflow_input -class WorkflowInitGoodOneParamTyped: +class WorkflowInitOneParamTyped: def __init__(self, a: int) -> None: pass @@ -438,7 +507,7 @@ async def run(self, aa: int) -> None: pass -class WorkflowInitGoodTwoParamsTyped: +class WorkflowInitTwoParamsTyped: def __init__(self, a: int, b: str) -> None: pass @@ -447,7 +516,7 @@ async def run(self, aa: int, bb: str) -> None: pass -class WorkflowInitGoodOneParamUntyped: +class WorkflowInitOneParamUntyped: def __init__(self, a) -> None: pass @@ -456,7 +525,7 @@ async def run(self, aa) -> None: pass -class WorkflowInitGoodTwoParamsUntyped: +class WorkflowInitTwoParamsUntyped: def __init__(self, a, b) -> None: pass @@ -465,43 +534,43 @@ async def run(self, aa, bb) -> None: pass -class WorkflowInitGoodOneParamNoInitType: - def __init__(self, a) -> None: +class WorkflowInitSlashStarArgsStarStarKwargs: + def __init__(self, /, a, *args, **kwargs) -> None: pass @workflow.run - async def run(self, aa: int) -> None: + async def run(self, /, a, *args, **kwargs) -> None: pass -class WorkflowInitGoodOneParamNoRunType: - def __init__(self, a: int) -> None: +class WorkflowInitStarArgsStarStarKwargs: + def __init__(self, *args, a, **kwargs) -> None: pass @workflow.run - async def run(self, aa) -> None: + async def run(self, *args, a, **kwargs) -> None: pass -class WorkflowInitGoodTwoParamsMixedTyping: - def __init__(self, a, b: str) -> None: +class WorkflowInitStarDefault: + def __init__(self, a, *, arg=1) -> None: pass @workflow.run - async def run(self, aa: str, bb) -> None: + async def run(self, a, *, arg=1) -> None: pass @pytest.mark.parametrize( "cls", [ - WorkflowInitGoodOneParamTyped, - WorkflowInitGoodTwoParamsTyped, - WorkflowInitGoodOneParamUntyped, - WorkflowInitGoodTwoParamsUntyped, - WorkflowInitGoodTwoParamsMixedTyping, - WorkflowInitGoodOneParamNoInitType, - WorkflowInitGoodOneParamNoRunType, + WorkflowInitOneParamTyped, + WorkflowInitTwoParamsTyped, + WorkflowInitOneParamUntyped, + WorkflowInitTwoParamsUntyped, + WorkflowInitSlashStarArgsStarStarKwargs, + WorkflowInitStarArgsStarStarKwargs, + WorkflowInitStarDefault, ], ) def test_workflow_init_good_takes_workflow_input(cls): @@ -546,6 +615,62 @@ async def run(self, aa: int, bb: str) -> None: pass +class WorkflowInitBadOneParamNoInitType: + def __init__(self, a) -> None: + pass + + @workflow.run + async def run(self, aa: int) -> None: + pass + + +class WorkflowInitBadGenericSubtype: + # The types must match exactly; we do not support any notion of subtype + # compatibility. + def __init__(self, a: List) -> None: + pass + + @workflow.run + async def run(self, aa: List[int]) -> None: + pass + + +class WorkflowInitBadOneParamNoRunType: + def __init__(self, a: int) -> None: + pass + + @workflow.run + async def run(self, aa) -> None: + pass + + +class WorkflowInitBadMissingDefault: + def __init__(self, a: int, b: int) -> None: + pass + + @workflow.run + async def run(self, aa: int, b: int = 1) -> None: + pass + + +class WorkflowInitBadInconsistentDefaults: + def __init__(self, a: int, b: int = 1) -> None: + pass + + @workflow.run + async def run(self, aa: int, b: int = 2) -> None: + pass + + +class WorkflowInitBadTwoParamsMixedTyping: + def __init__(self, a, b: str) -> None: + pass + + @workflow.run + async def run(self, aa: str, bb) -> None: + pass + + @pytest.mark.parametrize( "cls", [ @@ -553,9 +678,14 @@ async def run(self, aa: int, bb: str) -> None: WorkflowInitBadMismatchedParamUntyped, WorkflowInitBadExtraInitParamTyped, WorkflowInitBadMismatchedParamTyped, + WorkflowInitBadTwoParamsMixedTyping, + WorkflowInitBadOneParamNoInitType, + WorkflowInitBadOneParamNoRunType, + WorkflowInitBadGenericSubtype, + WorkflowInitBadMissingDefault, + WorkflowInitBadInconsistentDefaults, ], ) def test_workflow_init_bad_takes_workflow_input(cls): - with pytest.raises(ValueError) as err: + with pytest.raises(ValueError): workflow.defn(cls) - print(err) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 2665c653..7abbf4f3 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -6053,3 +6053,44 @@ async def test_workflow_init(client: Client): task_queue=worker.task_queue, ) assert workflow_result == "hello, world" + + +@workflow.defn +class WorkflowInitUpdateInFirstWFTWorkflow: + def __init__(self, arg: str) -> None: + self.init_arg = arg + + @workflow.update + async def my_update(self) -> str: + return self.init_arg + + @workflow.run + async def run(self, _: str): + self.init_arg = "value set in run method" + return self.init_arg + + +async def test_update_in_first_wft_sees_workflow_init(client: Client): + # Before running the worker, start a workflow, send the update, and wait + # until update is admitted. + task_queue = "task-queue" + update_id = "update-id" + wf_handle = await client.start_workflow( + WorkflowInitUpdateInFirstWFTWorkflow.run, + "workflow input value", + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + update_task = asyncio.create_task( + wf_handle.execute_update( + WorkflowInitUpdateInFirstWFTWorkflow.my_update, id=update_id + ) + ) + await assert_eq_eventually( + True, lambda: workflow_update_exists(client, wf_handle.id, update_id) + ) + async with new_worker( + client, WorkflowInitUpdateInFirstWFTWorkflow, task_queue=task_queue + ): + assert await update_task == "workflow input value" + assert await wf_handle.result() == "value set in run method" From 5fda934d75c4b9fac93f6123c0913c0aba5e217a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 6 Sep 2024 10:01:53 -0400 Subject: [PATCH 05/17] Add (fatally) failing test case --- tests/test_workflow.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 86bd49ff..43c7af93 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -442,6 +442,25 @@ async def run(self, arg: MyDataClass) -> None: pass +# Although the class is abstract, a user may decorate it with @workflow.defn, +# for example in order to set the name as the same as the child class, so that +# the client codebase need only import the interface. +class NormalInitAbstractBaseClass: + def __init__(self, arg_supplied_by_child_cls) -> None: + pass + + @workflow.run + async def run(self) -> None: ... + + +class NormalInitChildClass(NormalInitAbstractBaseClass): + def __init__(self) -> None: + super().__init__(arg_supplied_by_child_cls=None) + + @workflow.run + async def run(self) -> None: ... + + class NormalInitSlashStarArgsStarStarKwargs: def __init__(self, /, *args, **kwargs) -> None: pass @@ -490,6 +509,8 @@ async def run(self, aa: int) -> None: NormalInitStarArgsStarStarKwargs, NormalInitStarDefault, NormalInitTypedDefault, + NormalInitAbstractBaseClass, + NormalInitChildClass, ], ) def test_workflow_init_good_does_not_take_workflow_input(cls): From 4f7a0e018746a0e95ad00528e3d20fe176365fbb Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 5 Sep 2024 21:23:18 -0400 Subject: [PATCH 06/17] Defer decision until workflow-start time --- temporalio/worker/_workflow_instance.py | 76 ++++++++++++++++++++- temporalio/workflow.py | 87 +------------------------ tests/test_workflow.py | 28 ++++---- 3 files changed, 93 insertions(+), 98 deletions(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index b7696858..1d4c3fb9 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1635,7 +1635,12 @@ def _convert_payloads( def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: raise RuntimeError("Expected workflow input. This is a Python SDK bug.") - if self._defn.init_fn_takes_workflow_input: + takes_workflow_input, err = _init_fn_takes_workflow_input( + self._defn.cls.__init__, self._workflow_input.run_fn + ) + if takes_workflow_input is None: + raise RuntimeError(f"Cannot instantiate workflow: {err}") + elif takes_workflow_input: return self._defn.cls(*self._workflow_input.args) else: return self._defn.cls() @@ -2852,3 +2857,72 @@ def _make_unfinished_signal_handler_message( [{"name": name, "count": count} for name, count in names.most_common()] ) ) + + +def _init_fn_takes_workflow_input( + init_fn: Callable[..., None], + run_fn: Callable[..., Awaitable[Any]], +) -> Union[Tuple[bool, None], Tuple[None, str]]: + """Return (True, None) if Workflow input args should be passed to Workflow __init__.""" + # If the workflow class can be instantiated as cls(), i.e. without passing + # workflow input args, then we do that. Otherwise, if the parameters of + # __init__ exactly match those of the @workflow.run method, then we pass the + # workflow input args to __init__ when instantiating the workflow class. + # Otherwise, the workflow definition is invalid. + + if _unbound_method_can_be_called_without_args_when_bound(init_fn): + # The workflow cls can be instantiated as cls() + return False, None + else: + + def get_params(fn: Callable) -> List[inspect.Parameter]: + # Ignore name when comparing parameters (remaining fields are kind, + # default, and annotation). + return [ + p.replace(name="x") for p in inspect.signature(fn).parameters.values() + ] + + # We require that any type annotations present match exactly; i.e. we do + # not support any notion of subtype compatibility. + if get_params(init_fn) == get_params(run_fn): + # __init__ requires some args and has the same parameters as @workflow.run + return True, None + else: + return ( + None, + "__init__ parameters do not match @workflow.run method parameters", + ) + + +def _unbound_method_can_be_called_without_args_when_bound( + fn: Callable[..., Any], +) -> bool: + """Return True if the unbound method fn can be called without arguments when bound.""" + # An unbound method can be called without arguments when bound if the + # following are both true: + # + # - The first parameter is POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD (this is + # the parameter conventionally named 'self') + # + # - All other POSITIONAL_OR_KEYWORD or KEYWORD_ONLY parameters have default + # values. + params = iter(inspect.signature(fn).parameters.values()) + self_param = next(params, None) + if not self_param or self_param.kind not in [ + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.POSITIONAL_ONLY, + ]: + raise ValueError("Not an unbound method. This is a Python SDK bug.") + for p in params: + if p.kind == inspect.Parameter.POSITIONAL_ONLY: + return False + elif ( + p.kind + in [ + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ] + and p.default is inspect.Parameter.empty + ): + return False + return True diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 584c245f..92101eb8 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1252,7 +1252,6 @@ class _Definition: name: Optional[str] cls: Type run_fn: Callable[..., Awaitable] - init_fn_takes_workflow_input: bool signals: Mapping[Optional[str], _SignalDefinition] queries: Mapping[Optional[str], _QueryDefinition] updates: Mapping[Optional[str], _UpdateDefinition] @@ -1307,8 +1306,7 @@ def _apply_to_class( raise ValueError("Class already contains workflow definition") issues: List[str] = [] - # Collect init, run, and all signal/query/update fns - init_fn: Optional[Callable[..., None]] = None + # Collect run fn and all signal/query/update fns run_fn: Optional[Callable[..., Awaitable[Any]]] = None seen_run_attr = False signals: Dict[Optional[str], _SignalDefinition] = {} @@ -1355,8 +1353,6 @@ def _apply_to_class( ) else: queries[query_defn.name] = query_defn - elif name == "__init__": - init_fn = member elif isinstance(member, UpdateMethodMultiParam): update_defn = member._defn if update_defn.name in updates: @@ -1407,18 +1403,8 @@ def _apply_to_class( f"@workflow.update defined on {base_member.__qualname__} but not on the override" ) - init_fn_takes_workflow_input = False - if not seen_run_attr: issues.append("Missing @workflow.run method") - if not init_fn: - issues.append("Missing __init__ method") - elif run_fn: - init_fn_takes_workflow_input, init_fn_issue = _init_fn_takes_workflow_input( # type: ignore[assignment] - init_fn, run_fn - ) - if init_fn_issue is not None: - issues.append(init_fn_issue) if issues: if len(issues) == 1: raise ValueError(f"Invalid workflow class: {issues[0]}") @@ -1427,12 +1413,10 @@ def _apply_to_class( ) assert run_fn - assert init_fn_takes_workflow_input is not None defn = _Definition( name=workflow_name, cls=cls, run_fn=run_fn, - init_fn_takes_workflow_input=init_fn_takes_workflow_input, signals=signals, queries=queries, updates=updates, @@ -1459,75 +1443,6 @@ def __post_init__(self) -> None: object.__setattr__(self, "ret_type", ret_type) -def _init_fn_takes_workflow_input( - init_fn: Callable[..., None], - run_fn: Callable[..., Awaitable[Any]], -) -> Union[Tuple[bool, None], Tuple[None, str]]: - """Return (True, None) if Workflow input args should be passed to Workflow __init__.""" - # If the workflow class can be instantiated as cls(), i.e. without passing - # workflow input args, then we do that. Otherwise, if the parameters of - # __init__ exactly match those of the @workflow.run method, then we pass the - # workflow input args to __init__ when instantiating the workflow class. - # Otherwise, the workflow definition is invalid. - - if _unbound_method_can_be_called_without_args_when_bound(init_fn): - # The workflow cls can be instantiated as cls() - return False, None - else: - - def get_params(fn: Callable) -> List[inspect.Parameter]: - # Ignore name when comparing parameters (remaining fields are kind, - # default, and annotation). - return [ - p.replace(name="x") for p in inspect.signature(fn).parameters.values() - ] - - # We require that any type annotations present match exactly; i.e. we do - # not support any notion of subtype compatibility. - if get_params(init_fn) == get_params(run_fn): - # __init__ requires some args and has the same parameters as @workflow.run - return True, None - else: - return ( - None, - "__init__ parameters do not match @workflow.run method parameters", - ) - - -def _unbound_method_can_be_called_without_args_when_bound( - fn: Callable[..., Any], -) -> bool: - """Return True if the unbound method fn can be called without arguments when bound.""" - # An unbound method can be called without arguments when bound if the - # following are both true: - # - # - The first parameter is POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD (this is - # the parameter conventionally named 'self') - # - # - All other POSITIONAL_OR_KEYWORD or KEYWORD_ONLY parameters have default - # values. - params = iter(inspect.signature(fn).parameters.values()) - self_param = next(params, None) - if not self_param or self_param.kind not in [ - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.POSITIONAL_ONLY, - ]: - raise ValueError("Not an unbound method. This is a Python SDK bug.") - for p in params: - if p.kind == inspect.Parameter.POSITIONAL_ONLY: - return False - elif ( - p.kind - in [ - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ] - and p.default is inspect.Parameter.empty - ): - return False - return True - - # Async safe version of partial def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]: # Curry instance on the definition function since that represents an diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 43c7af93..eff46268 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -5,6 +5,7 @@ from temporalio import workflow from temporalio.common import RawValue +from temporalio.worker import _workflow_instance class GoodDefnBase: @@ -76,7 +77,6 @@ def test_workflow_defn_good(): name="workflow-custom", cls=GoodDefn, run_fn=GoodDefn.run, - init_fn_takes_workflow_input=False, signals={ "signal1": workflow._SignalDefinition( name="signal1", fn=GoodDefn.signal1, is_method=True @@ -396,7 +396,8 @@ def test_unbound_method_can_be_called_without_args_when_bound( fn: Callable[..., Any], expected: bool ): assert ( - workflow._unbound_method_can_be_called_without_args_when_bound(fn) == expected + _workflow_instance._unbound_method_can_be_called_without_args_when_bound(fn) + == expected ) @@ -509,14 +510,16 @@ async def run(self, aa: int) -> None: NormalInitStarArgsStarStarKwargs, NormalInitStarDefault, NormalInitTypedDefault, - NormalInitAbstractBaseClass, + # The base class is abstract, so will never be encountered by the worker + # during workflow task processing. NormalInitChildClass, ], ) def test_workflow_init_good_does_not_take_workflow_input(cls): - assert not workflow.defn( - cls - ).__temporal_workflow_definition.init_fn_takes_workflow_input + takes_workflow_input, _ = _workflow_instance._init_fn_takes_workflow_input( + cls.__init__, workflow.defn(cls).__temporal_workflow_definition.run_fn + ) + assert takes_workflow_input is False class WorkflowInitOneParamTyped: @@ -595,9 +598,10 @@ async def run(self, a, *, arg=1) -> None: ], ) def test_workflow_init_good_takes_workflow_input(cls): - assert workflow.defn( - cls - ).__temporal_workflow_definition.init_fn_takes_workflow_input + takes_workflow_input, _ = _workflow_instance._init_fn_takes_workflow_input( + cls.__init__, workflow.defn(cls).__temporal_workflow_definition.run_fn + ) + assert takes_workflow_input is True class WorkflowInitBadExtraInitParamUntyped: @@ -708,5 +712,7 @@ async def run(self, aa: str, bb) -> None: ], ) def test_workflow_init_bad_takes_workflow_input(cls): - with pytest.raises(ValueError): - workflow.defn(cls) + takes_workflow_input, err = _workflow_instance._init_fn_takes_workflow_input( + cls.__init__, workflow.defn(cls).__temporal_workflow_definition.run_fn + ) + assert takes_workflow_input is None and err is not None From e9e2af4d69ceddc4881e7aad51ebe018971265c1 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 17 Sep 2024 09:37:36 -0400 Subject: [PATCH 07/17] s/set/used/ --- temporalio/workflow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 92101eb8..16ce7976 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -146,7 +146,7 @@ def decorator(cls: ClassType) -> ClassType: def run(fn: CallableAsyncType) -> CallableAsyncType: """Decorator for the workflow run method. - This must be set on one and only one async method defined on the same class + This must be used on one and only one async method defined on the same class as ``@workflow.defn``. This can be defined on a base class method but must then be explicitly overridden and defined on the workflow class. @@ -238,7 +238,7 @@ def signal( ): """Decorator for a workflow signal method. - This is set on any async or non-async method that you wish to be called upon + This is used on any async or non-async method that you wish to be called upon receiving a signal. If a function overrides one with this decorator, it too must be decorated. @@ -309,7 +309,7 @@ def query( ): """Decorator for a workflow query method. - This is set on any non-async method that expects to handle a query. If a + This is used on any non-async method that expects to handle a query. If a function overrides one with this decorator, it too must be decorated. Query methods can only have positional parameters. Best practice for @@ -983,7 +983,7 @@ def update( ): """Decorator for a workflow update handler method. - This is set on any async or non-async method that you wish to be called upon + This is used on any async or non-async method that you wish to be called upon receiving an update. If a function overrides one with this decorator, it too must be decorated. From a739d50dbb49f63214edb048688893bb2527d2c9 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 17 Sep 2024 09:38:09 -0400 Subject: [PATCH 08/17] decorator --- temporalio/worker/_workflow_instance.py | 76 +---- temporalio/workflow.py | 28 ++ tests/test_workflow.py | 376 +----------------------- tests/worker/test_workflow.py | 3 +- 4 files changed, 32 insertions(+), 451 deletions(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1d4c3fb9..1ca70a23 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1635,12 +1635,7 @@ def _convert_payloads( def _instantiate_workflow_object(self) -> Any: if not self._workflow_input: raise RuntimeError("Expected workflow input. This is a Python SDK bug.") - takes_workflow_input, err = _init_fn_takes_workflow_input( - self._defn.cls.__init__, self._workflow_input.run_fn - ) - if takes_workflow_input is None: - raise RuntimeError(f"Cannot instantiate workflow: {err}") - elif takes_workflow_input: + if hasattr(self._defn.cls.__init__, "__temporal_workflow_init"): return self._defn.cls(*self._workflow_input.args) else: return self._defn.cls() @@ -2857,72 +2852,3 @@ def _make_unfinished_signal_handler_message( [{"name": name, "count": count} for name, count in names.most_common()] ) ) - - -def _init_fn_takes_workflow_input( - init_fn: Callable[..., None], - run_fn: Callable[..., Awaitable[Any]], -) -> Union[Tuple[bool, None], Tuple[None, str]]: - """Return (True, None) if Workflow input args should be passed to Workflow __init__.""" - # If the workflow class can be instantiated as cls(), i.e. without passing - # workflow input args, then we do that. Otherwise, if the parameters of - # __init__ exactly match those of the @workflow.run method, then we pass the - # workflow input args to __init__ when instantiating the workflow class. - # Otherwise, the workflow definition is invalid. - - if _unbound_method_can_be_called_without_args_when_bound(init_fn): - # The workflow cls can be instantiated as cls() - return False, None - else: - - def get_params(fn: Callable) -> List[inspect.Parameter]: - # Ignore name when comparing parameters (remaining fields are kind, - # default, and annotation). - return [ - p.replace(name="x") for p in inspect.signature(fn).parameters.values() - ] - - # We require that any type annotations present match exactly; i.e. we do - # not support any notion of subtype compatibility. - if get_params(init_fn) == get_params(run_fn): - # __init__ requires some args and has the same parameters as @workflow.run - return True, None - else: - return ( - None, - "__init__ parameters do not match @workflow.run method parameters", - ) - - -def _unbound_method_can_be_called_without_args_when_bound( - fn: Callable[..., Any], -) -> bool: - """Return True if the unbound method fn can be called without arguments when bound.""" - # An unbound method can be called without arguments when bound if the - # following are both true: - # - # - The first parameter is POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD (this is - # the parameter conventionally named 'self') - # - # - All other POSITIONAL_OR_KEYWORD or KEYWORD_ONLY parameters have default - # values. - params = iter(inspect.signature(fn).parameters.values()) - self_param = next(params, None) - if not self_param or self_param.kind not in [ - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.POSITIONAL_ONLY, - ]: - raise ValueError("Not an unbound method. This is a Python SDK bug.") - for p in params: - if p.kind == inspect.Parameter.POSITIONAL_ONLY: - return False - elif ( - p.kind - in [ - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ] - and p.default is inspect.Parameter.empty - ): - return False - return True diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 16ce7976..55009712 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -143,6 +143,34 @@ def decorator(cls: ClassType) -> ClassType: return decorator +def init( + init_fn: CallableType, +) -> CallableType: + """Decorator for the workflow init method. + + This may be used on the __init__ method of the workflow class to specify + that it accepts the same workflow input arguments as the ``@workflow.run`` + method. It may not be used on any other method. + + If used, the workflow will be instantiated as + ``MyWorkflow(**workflow_input_args)``. If not used, the workflow will be + instantiated as ``MyWorkflow()``. + + Note that the ``@workflow.run`` method is always called as + ``my_workflow.my_run_method(**workflow_input_args)``. If you use the + ``@workflow.init`` decorator, your __init__ method and your + ``@workflow.run`` method will typically have exactly the same parameters. + + Args: + init_fn: The __init__function to decorate. + """ + if init_fn.__name__ != "__init__": + raise ValueError("@workflow.init may only be used on the __init__ method") + + setattr(init_fn, "__temporal_workflow_init", True) + return init_fn + + def run(fn: CallableAsyncType) -> CallableAsyncType: """Decorator for the workflow run method. diff --git a/tests/test_workflow.py b/tests/test_workflow.py index eff46268..510fa18d 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,11 +1,9 @@ -from dataclasses import dataclass -from typing import Any, Callable, List, Protocol, Sequence +from typing import Any, Sequence import pytest from temporalio import workflow from temporalio.common import RawValue -from temporalio.worker import _workflow_instance class GoodDefnBase: @@ -344,375 +342,3 @@ def test_workflow_defn_dynamic_handler_warnings(): # We want to make sure they are reporting the right stacklevel warnings[0].filename.endswith("test_workflow.py") warnings[1].filename.endswith("test_workflow.py") - - -# -# workflow init tests -# - - -class CanBeCalledWithoutArgs: - def a(self): - pass - - def b(self, arg=1): - pass - - def c(self, /, arg=1): - pass - - def d(self=1): # type: ignore - pass - - def e(self=1, /, arg1=1, *, arg2=2): # type: ignore - pass - - -class CannotBeCalledWithoutArgs: - def a(self, arg): - pass - - def b(self, /, arg): - pass - - def c(self, arg1, arg2=2): - pass - - -@pytest.mark.parametrize( - "fn,expected", - [ - (CanBeCalledWithoutArgs.a, True), - (CanBeCalledWithoutArgs.b, True), - (CanBeCalledWithoutArgs.c, True), - (CanBeCalledWithoutArgs.d, True), - (CanBeCalledWithoutArgs.e, True), - (CannotBeCalledWithoutArgs.a, False), - (CannotBeCalledWithoutArgs.b, False), - (CannotBeCalledWithoutArgs.c, False), - ], -) -def test_unbound_method_can_be_called_without_args_when_bound( - fn: Callable[..., Any], expected: bool -): - assert ( - _workflow_instance._unbound_method_can_be_called_without_args_when_bound(fn) - == expected - ) - - -class NormalInitNoInit: - @workflow.run - async def run(self) -> None: - pass - - -class NormalInitNoInitOneParamRun: - @workflow.run - async def run(self, a: int) -> None: - pass - - -class NormalInitNoArgInitZeroParamRun: - def __init__(self) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - -class NormalInitNoArgInitOneParamRun: - def __init__(self) -> None: - pass - - @workflow.run - async def run(self, a: int) -> None: - pass - - -@dataclass -class MyDataClass: - a: int - b: str - - -class NormalInitDataClassProtocol(Protocol): - @workflow.run - async def run(self, arg: MyDataClass) -> None: - pass - - -# Although the class is abstract, a user may decorate it with @workflow.defn, -# for example in order to set the name as the same as the child class, so that -# the client codebase need only import the interface. -class NormalInitAbstractBaseClass: - def __init__(self, arg_supplied_by_child_cls) -> None: - pass - - @workflow.run - async def run(self) -> None: ... - - -class NormalInitChildClass(NormalInitAbstractBaseClass): - def __init__(self) -> None: - super().__init__(arg_supplied_by_child_cls=None) - - @workflow.run - async def run(self) -> None: ... - - -class NormalInitSlashStarArgsStarStarKwargs: - def __init__(self, /, *args, **kwargs) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - -class NormalInitStarArgsStarStarKwargs: - def __init__(self, *args, **kwargs) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - -class NormalInitStarDefault: - def __init__(self, *, arg=1) -> None: - pass - - @workflow.run - async def run(self, arg) -> None: - pass - - -class NormalInitTypedDefault: - def __init__(self, a: int = 1) -> None: - pass - - @workflow.run - async def run(self, aa: int) -> None: - pass - - -@pytest.mark.parametrize( - "cls", - [ - NormalInitNoInit, - NormalInitNoInitOneParamRun, - NormalInitNoArgInitZeroParamRun, - NormalInitNoArgInitOneParamRun, - NormalInitDataClassProtocol, - NormalInitSlashStarArgsStarStarKwargs, - NormalInitStarArgsStarStarKwargs, - NormalInitStarDefault, - NormalInitTypedDefault, - # The base class is abstract, so will never be encountered by the worker - # during workflow task processing. - NormalInitChildClass, - ], -) -def test_workflow_init_good_does_not_take_workflow_input(cls): - takes_workflow_input, _ = _workflow_instance._init_fn_takes_workflow_input( - cls.__init__, workflow.defn(cls).__temporal_workflow_definition.run_fn - ) - assert takes_workflow_input is False - - -class WorkflowInitOneParamTyped: - def __init__(self, a: int) -> None: - pass - - @workflow.run - async def run(self, aa: int) -> None: - pass - - -class WorkflowInitTwoParamsTyped: - def __init__(self, a: int, b: str) -> None: - pass - - @workflow.run - async def run(self, aa: int, bb: str) -> None: - pass - - -class WorkflowInitOneParamUntyped: - def __init__(self, a) -> None: - pass - - @workflow.run - async def run(self, aa) -> None: - pass - - -class WorkflowInitTwoParamsUntyped: - def __init__(self, a, b) -> None: - pass - - @workflow.run - async def run(self, aa, bb) -> None: - pass - - -class WorkflowInitSlashStarArgsStarStarKwargs: - def __init__(self, /, a, *args, **kwargs) -> None: - pass - - @workflow.run - async def run(self, /, a, *args, **kwargs) -> None: - pass - - -class WorkflowInitStarArgsStarStarKwargs: - def __init__(self, *args, a, **kwargs) -> None: - pass - - @workflow.run - async def run(self, *args, a, **kwargs) -> None: - pass - - -class WorkflowInitStarDefault: - def __init__(self, a, *, arg=1) -> None: - pass - - @workflow.run - async def run(self, a, *, arg=1) -> None: - pass - - -@pytest.mark.parametrize( - "cls", - [ - WorkflowInitOneParamTyped, - WorkflowInitTwoParamsTyped, - WorkflowInitOneParamUntyped, - WorkflowInitTwoParamsUntyped, - WorkflowInitSlashStarArgsStarStarKwargs, - WorkflowInitStarArgsStarStarKwargs, - WorkflowInitStarDefault, - ], -) -def test_workflow_init_good_takes_workflow_input(cls): - takes_workflow_input, _ = _workflow_instance._init_fn_takes_workflow_input( - cls.__init__, workflow.defn(cls).__temporal_workflow_definition.run_fn - ) - assert takes_workflow_input is True - - -class WorkflowInitBadExtraInitParamUntyped: - def __init__(self, a) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - -class WorkflowInitBadMismatchedParamUntyped: - def __init__(self, a) -> None: - pass - - @workflow.run - async def run(self, aa, bb) -> None: - pass - - -class WorkflowInitBadExtraInitParamTyped: - def __init__(self, a: int) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - -class WorkflowInitBadMismatchedParamTyped: - def __init__(self, a: int) -> None: - pass - - @workflow.run - async def run(self, aa: int, bb: str) -> None: - pass - - -class WorkflowInitBadOneParamNoInitType: - def __init__(self, a) -> None: - pass - - @workflow.run - async def run(self, aa: int) -> None: - pass - - -class WorkflowInitBadGenericSubtype: - # The types must match exactly; we do not support any notion of subtype - # compatibility. - def __init__(self, a: List) -> None: - pass - - @workflow.run - async def run(self, aa: List[int]) -> None: - pass - - -class WorkflowInitBadOneParamNoRunType: - def __init__(self, a: int) -> None: - pass - - @workflow.run - async def run(self, aa) -> None: - pass - - -class WorkflowInitBadMissingDefault: - def __init__(self, a: int, b: int) -> None: - pass - - @workflow.run - async def run(self, aa: int, b: int = 1) -> None: - pass - - -class WorkflowInitBadInconsistentDefaults: - def __init__(self, a: int, b: int = 1) -> None: - pass - - @workflow.run - async def run(self, aa: int, b: int = 2) -> None: - pass - - -class WorkflowInitBadTwoParamsMixedTyping: - def __init__(self, a, b: str) -> None: - pass - - @workflow.run - async def run(self, aa: str, bb) -> None: - pass - - -@pytest.mark.parametrize( - "cls", - [ - WorkflowInitBadExtraInitParamUntyped, - WorkflowInitBadMismatchedParamUntyped, - WorkflowInitBadExtraInitParamTyped, - WorkflowInitBadMismatchedParamTyped, - WorkflowInitBadTwoParamsMixedTyping, - WorkflowInitBadOneParamNoInitType, - WorkflowInitBadOneParamNoRunType, - WorkflowInitBadGenericSubtype, - WorkflowInitBadMissingDefault, - WorkflowInitBadInconsistentDefaults, - ], -) -def test_workflow_init_bad_takes_workflow_input(cls): - takes_workflow_input, err = _workflow_instance._init_fn_takes_workflow_input( - cls.__init__, workflow.defn(cls).__temporal_workflow_definition.run_fn - ) - assert takes_workflow_input is None and err is not None diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 7abbf4f3..fe78fd5a 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -6057,7 +6057,8 @@ async def test_workflow_init(client: Client): @workflow.defn class WorkflowInitUpdateInFirstWFTWorkflow: - def __init__(self, arg: str) -> None: + @workflow.init + def __init__(self, arg: str = "value from parameter default") -> None: self.init_arg = arg @workflow.update From 16395b4ddbd8b242dd5451607dfded79178c8146 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 17 Sep 2024 10:09:18 -0400 Subject: [PATCH 09/17] Update test --- tests/worker/test_workflow.py | 194 ++++++++++++++++++++++++++++------ 1 file changed, 161 insertions(+), 33 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index fe78fd5a..56c4d2d4 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -6035,63 +6035,191 @@ async def test_activity_retry_delay(client: Client): @workflow.defn -class WorkflowInitWorkflow: - def __init__(self, arg: str) -> None: - self.init_arg = arg +class WorkflowWithoutInit: + value = "from class attribute" + _expected_update_result = "from class attribute" + + @workflow.update + async def my_update(self) -> str: + return self.value @workflow.run - async def run(self, _: str): - return f"hello, {self.init_arg}" + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value -async def test_workflow_init(client: Client): - async with new_worker(client, WorkflowInitWorkflow) as worker: - workflow_result = await client.execute_workflow( - WorkflowInitWorkflow.run, - "world", - id=str(uuid.uuid4()), - task_queue=worker.task_queue, - ) - assert workflow_result == "hello, world" +@workflow.defn +class WorkflowWithWorkflowInit: + _expected_update_result = "workflow input value" + + @workflow.init + def __init__(self, arg: str = "from parameter default") -> None: + self.value = arg + + @workflow.update + async def my_update(self) -> str: + return self.value + + @workflow.run + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value @workflow.defn -class WorkflowInitUpdateInFirstWFTWorkflow: +class WorkflowWithNonWorkflowInitInit: + _expected_update_result = "from parameter default" + + def __init__(self, arg: str = "from parameter default") -> None: + self.value = arg + + @workflow.update + async def my_update(self) -> str: + return self.value + + @workflow.run + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value + + +@workflow.defn(name="MyWorkflow") +class WorkflowWithWorkflowInitBaseDecorated: + use_workflow_init = True + @workflow.init - def __init__(self, arg: str = "value from parameter default") -> None: - self.init_arg = arg + def __init__( + self, required_param_that_will_be_supplied_by_child_init_method + ) -> None: + self.value = required_param_that_will_be_supplied_by_child_init_method + + if use_workflow_init: + __init__ = workflow.init(__init__) + + @workflow.run + async def run(self, _: str): ... + + @workflow.update + async def my_update(self) -> str: ... + + +class WorkflowWithWorkflowInitBaseUndecorated(WorkflowWithWorkflowInitBaseDecorated): + # The base class does not need the @workflow.init decorator + use_workflow_init = False + + +@workflow.defn(name="MyWorkflow") +class WorkflowWithWorkflowInitChild(WorkflowWithWorkflowInitBaseDecorated): + use_workflow_init = True + _expected_update_result = "workflow input value" + + def __init__(self, arg: str = "from parameter default") -> None: + super().__init__("from child __init__") + self.value = arg + + if use_workflow_init: + __init__ = workflow.init(__init__) + + @workflow.run + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value @workflow.update async def my_update(self) -> str: - return self.init_arg + return self.value + + +@workflow.defn(name="MyWorkflow") +class WorkflowWithWorkflowInitChildNoWorkflowInit( + WorkflowWithWorkflowInitBaseDecorated +): + use_workflow_init = False + _expected_update_result = "from parameter default" + + def __init__(self, arg: str = "from parameter default") -> None: + super().__init__("from child __init__") + self.value = arg + + if use_workflow_init: + __init__ = workflow.init(__init__) @workflow.run - async def run(self, _: str): - self.init_arg = "value set in run method" - return self.init_arg + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value + + @workflow.update + async def my_update(self) -> str: + return self.value + +@pytest.mark.parametrize( + ["client_cls", "worker_cls"], + [ + (WorkflowWithoutInit, WorkflowWithoutInit), + (WorkflowWithNonWorkflowInitInit, WorkflowWithNonWorkflowInitInit), + (WorkflowWithWorkflowInit, WorkflowWithWorkflowInit), + (WorkflowWithWorkflowInitBaseDecorated, WorkflowWithWorkflowInitChild), + (WorkflowWithWorkflowInitBaseUndecorated, WorkflowWithWorkflowInitChild), + ( + WorkflowWithWorkflowInitBaseUndecorated, + WorkflowWithWorkflowInitChildNoWorkflowInit, + ), + ], +) +async def test_update_in_first_wft_sees_workflow_init( + client: Client, client_cls: Type, worker_cls: Type +): + """ + Test how @workflow.init affects what an update in the first WFT sees. -async def test_update_in_first_wft_sees_workflow_init(client: Client): - # Before running the worker, start a workflow, send the update, and wait - # until update is admitted. + Such an update is guaranteed to start executing before the main workflow + coroutine. The update should see the side effects of the __init__ method if + and only if @workflow.init is in effect. + """ + # This test must ensure that the update is in the first WFT. To do so, + # before running the worker, we start the workflow, send the update, and + # wait until the update is admitted. task_queue = "task-queue" update_id = "update-id" wf_handle = await client.start_workflow( - WorkflowInitUpdateInFirstWFTWorkflow.run, + client_cls.run, "workflow input value", id=str(uuid.uuid4()), task_queue=task_queue, ) update_task = asyncio.create_task( - wf_handle.execute_update( - WorkflowInitUpdateInFirstWFTWorkflow.my_update, id=update_id - ) + wf_handle.execute_update(client_cls.my_update, id=update_id) ) await assert_eq_eventually( True, lambda: workflow_update_exists(client, wf_handle.id, update_id) ) - async with new_worker( - client, WorkflowInitUpdateInFirstWFTWorkflow, task_queue=task_queue - ): - assert await update_task == "workflow input value" - assert await wf_handle.result() == "value set in run method" + # When the worker starts polling it will receive a first WFT containing the + # update, in addition to the start_workflow job. + async with new_worker(client, worker_cls, task_queue=task_queue): + assert await update_task == worker_cls._expected_update_result + assert await wf_handle.result() == "set in run method" + + +@workflow.defn +class WorkflowRunSeesWorkflowInitWorkflow: + @workflow.init + def __init__(self, arg: str) -> None: + self.value = arg + + @workflow.run + async def run(self, _: str): + return f"hello, {self.value}" + + +async def test_workflow_run_sees_workflow_init(client: Client): + async with new_worker(client, WorkflowRunSeesWorkflowInitWorkflow) as worker: + workflow_result = await client.execute_workflow( + WorkflowRunSeesWorkflowInitWorkflow.run, + "world", + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + assert workflow_result == "hello, world" From 84958a6eddb65925ed4a8b1a9e89374155f0b231 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 18 Sep 2024 11:01:18 -0400 Subject: [PATCH 10/17] Tweak docstring --- temporalio/workflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 55009712..8f1a2c44 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -157,9 +157,9 @@ def init( instantiated as ``MyWorkflow()``. Note that the ``@workflow.run`` method is always called as - ``my_workflow.my_run_method(**workflow_input_args)``. If you use the - ``@workflow.init`` decorator, your __init__ method and your - ``@workflow.run`` method will typically have exactly the same parameters. + ``my_workflow.my_run_method(**workflow_input_args)``. Therefore, if you use + the ``@workflow.init`` decorator, the parameter list of your __init__ and + ``@workflow.run`` methods will usually be identical. Args: init_fn: The __init__function to decorate. From 1f81d84d384b048e29d16324df816c141716f4a8 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Wed, 18 Sep 2024 12:01:02 -0400 Subject: [PATCH 11/17] Change update method inheritance for mypy --- tests/worker/test_workflow.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 56c4d2d4..ee20120d 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -6101,7 +6101,8 @@ def __init__( async def run(self, _: str): ... @workflow.update - async def my_update(self) -> str: ... + async def my_update(self) -> str: + return self.value class WorkflowWithWorkflowInitBaseUndecorated(WorkflowWithWorkflowInitBaseDecorated): @@ -6126,10 +6127,6 @@ async def run(self, _: str) -> str: self.value = "set in run method" return self.value - @workflow.update - async def my_update(self) -> str: - return self.value - @workflow.defn(name="MyWorkflow") class WorkflowWithWorkflowInitChildNoWorkflowInit( @@ -6150,10 +6147,6 @@ async def run(self, _: str) -> str: self.value = "set in run method" return self.value - @workflow.update - async def my_update(self) -> str: - return self.value - @pytest.mark.parametrize( ["client_cls", "worker_cls"], From d8b5dffa61c1dc7960c1db5c5bfed54ea234c94a Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Sep 2024 11:34:45 -0400 Subject: [PATCH 12/17] Require identical parameters --- temporalio/workflow.py | 21 ++++++++++ tests/worker/test_workflow.py | 76 ++--------------------------------- 2 files changed, 24 insertions(+), 73 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 8f1a2c44..4f6c7b7f 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1335,6 +1335,7 @@ def _apply_to_class( issues: List[str] = [] # Collect run fn and all signal/query/update fns + init_fn: Optional[Callable[..., None]] = None run_fn: Optional[Callable[..., Awaitable[Any]]] = None seen_run_attr = False signals: Dict[Optional[str], _SignalDefinition] = {} @@ -1381,6 +1382,8 @@ def _apply_to_class( ) else: queries[query_defn.name] = query_defn + elif name == "__init__" and hasattr(member, "__temporal_workflow_init"): + init_fn = member elif isinstance(member, UpdateMethodMultiParam): update_defn = member._defn if update_defn.name in updates: @@ -1433,6 +1436,11 @@ def _apply_to_class( if not seen_run_attr: issues.append("Missing @workflow.run method") + if init_fn and run_fn: + if not _parameters_identical_up_to_naming(init_fn, run_fn): + issues.append( + "@workflow.init and @workflow.run method parameters do not match" + ) if issues: if len(issues) == 1: raise ValueError(f"Invalid workflow class: {issues[0]}") @@ -1471,6 +1479,19 @@ def __post_init__(self) -> None: object.__setattr__(self, "ret_type", ret_type) +def _parameters_identical_up_to_naming(fn1: Callable, fn2: Callable) -> bool: + """Return True if the functions have identical parameter lists, ignoring parameter names.""" + + def params(fn: Callable) -> List[inspect.Parameter]: + # Ignore name when comparing parameters (remaining fields are kind, + # default, and annotation). + return [p.replace(name="x") for p in inspect.signature(fn).parameters.values()] + + # We require that any type annotations present match exactly; i.e. we do + # not support any notion of subtype compatibility. + return params(fn1) == params(fn2) + + # Async safe version of partial def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]: # Curry instance on the definition function since that represents an diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index ee20120d..15afe8c4 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3902,7 +3902,7 @@ def matches_metric_line( return False # Must have labels (don't escape for this test) for k, v in at_least_labels.items(): - if not f'{k}="{v}"' in line: + if f'{k}="{v}"' not in line: return False return line.endswith(f" {value}") @@ -4856,7 +4856,7 @@ async def assert_scenario( update_scenario: Optional[FailureTypesScenario] = None, ) -> None: logging.debug( - f"Asserting scenario %s", + "Asserting scenario %s", { "workflow": workflow, "expect_task_fail": expect_task_fail, @@ -6054,7 +6054,7 @@ class WorkflowWithWorkflowInit: _expected_update_result = "workflow input value" @workflow.init - def __init__(self, arg: str = "from parameter default") -> None: + def __init__(self, arg: str) -> None: self.value = arg @workflow.update @@ -6084,82 +6084,12 @@ async def run(self, _: str) -> str: return self.value -@workflow.defn(name="MyWorkflow") -class WorkflowWithWorkflowInitBaseDecorated: - use_workflow_init = True - - @workflow.init - def __init__( - self, required_param_that_will_be_supplied_by_child_init_method - ) -> None: - self.value = required_param_that_will_be_supplied_by_child_init_method - - if use_workflow_init: - __init__ = workflow.init(__init__) - - @workflow.run - async def run(self, _: str): ... - - @workflow.update - async def my_update(self) -> str: - return self.value - - -class WorkflowWithWorkflowInitBaseUndecorated(WorkflowWithWorkflowInitBaseDecorated): - # The base class does not need the @workflow.init decorator - use_workflow_init = False - - -@workflow.defn(name="MyWorkflow") -class WorkflowWithWorkflowInitChild(WorkflowWithWorkflowInitBaseDecorated): - use_workflow_init = True - _expected_update_result = "workflow input value" - - def __init__(self, arg: str = "from parameter default") -> None: - super().__init__("from child __init__") - self.value = arg - - if use_workflow_init: - __init__ = workflow.init(__init__) - - @workflow.run - async def run(self, _: str) -> str: - self.value = "set in run method" - return self.value - - -@workflow.defn(name="MyWorkflow") -class WorkflowWithWorkflowInitChildNoWorkflowInit( - WorkflowWithWorkflowInitBaseDecorated -): - use_workflow_init = False - _expected_update_result = "from parameter default" - - def __init__(self, arg: str = "from parameter default") -> None: - super().__init__("from child __init__") - self.value = arg - - if use_workflow_init: - __init__ = workflow.init(__init__) - - @workflow.run - async def run(self, _: str) -> str: - self.value = "set in run method" - return self.value - - @pytest.mark.parametrize( ["client_cls", "worker_cls"], [ (WorkflowWithoutInit, WorkflowWithoutInit), (WorkflowWithNonWorkflowInitInit, WorkflowWithNonWorkflowInitInit), (WorkflowWithWorkflowInit, WorkflowWithWorkflowInit), - (WorkflowWithWorkflowInitBaseDecorated, WorkflowWithWorkflowInitChild), - (WorkflowWithWorkflowInitBaseUndecorated, WorkflowWithWorkflowInitChild), - ( - WorkflowWithWorkflowInitBaseUndecorated, - WorkflowWithWorkflowInitChildNoWorkflowInit, - ), ], ) async def test_update_in_first_wft_sees_workflow_init( From 0557b28e41e2b8998e394a19b227567c9907c4ae Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Sep 2024 17:03:43 -0400 Subject: [PATCH 13/17] Update docstring --- temporalio/workflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 4f6c7b7f..11aa60a4 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -157,9 +157,9 @@ def init( instantiated as ``MyWorkflow()``. Note that the ``@workflow.run`` method is always called as - ``my_workflow.my_run_method(**workflow_input_args)``. Therefore, if you use - the ``@workflow.init`` decorator, the parameter list of your __init__ and - ``@workflow.run`` methods will usually be identical. + ``my_workflow.my_run_method(**workflow_input_args)``. If you use the + ``@workflow.init`` decorator, the parameter list of your __init__ and + ``@workflow.run`` methods must be identical. Args: init_fn: The __init__function to decorate. From 0c108884990304955307dd0bd59ae2fe61aa9add Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Mon, 23 Sep 2024 18:05:49 -0400 Subject: [PATCH 14/17] Add unit test of _parameters_identical_up_to_naming --- tests/test_workflow.py | 53 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 510fa18d..c724ba66 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +import inspect +import itertools +from typing import Sequence import pytest @@ -342,3 +344,52 @@ def test_workflow_defn_dynamic_handler_warnings(): # We want to make sure they are reporting the right stacklevel warnings[0].filename.endswith("test_workflow.py") warnings[1].filename.endswith("test_workflow.py") + + +class _TestParametersIdenticalUpToNaming: + def a1(self, a): + pass + + def a2(self, b): + pass + + def b1(self, a: int): + pass + + def b2(self, b: int) -> str: + return "" + + def c1(self, a1: int, a2: str) -> str: + return "" + + def c2(self, b1: int, b2: str) -> int: + return 0 + + def d1(self, a1, a2: str) -> None: + pass + + def d2(self, b1, b2: str) -> str: + return "" + + def e1(self, a1, a2: str = "") -> None: + return None + + def e2(self, b1, b2: str = "") -> str: + return "" + + def f1(self, a1, a2: str = "a") -> None: + return None + + +def test_parameters_identical_up_to_naming(): + fns = [ + f + for _, f in inspect.getmembers(_TestParametersIdenticalUpToNaming) + if inspect.isfunction(f) + ] + for f1, f2 in itertools.combinations(fns, 2): + name1, name2 = f1.__name__, f2.__name__ + expect_equal = name1[0] == name2[0] + assert ( + workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal) + ), f"expected {name1} and {name2} parameters{" " if expect_equal else " not "}to compare equal" From 95f06301d6d1b0a8bf291b667c50251a2630ca9c Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Sep 2024 11:11:02 -0400 Subject: [PATCH 15/17] Document @workflow.init --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index aa4202c4..34eed9c3 100644 --- a/README.md +++ b/README.md @@ -575,6 +575,11 @@ Here are the decorators that can be applied: * The method's arguments are the workflow's arguments * The first parameter must be `self`, followed by positional arguments. Best practice is to only take a single argument that is an object/dataclass of fields that can be added to as needed. +* `@workflow.init` - Specifies that the `__init__` method accepts the workflow's arguments. + * If present, may only be applied to the `__init__` method, the parameters of which must then be identical to those of + the `@workflow.run` method. + * The purpose of this decorator is to allow operations involving workflow arguments to be performed in the `__init__` + method, before any signal or update handler has a chance to execute. * `@workflow.signal` - Defines a method as a signal * Can be defined on an `async` or non-`async` function at any hierarchy depth, but if decorated method is overridden, the override must also be decorated From d87d7b119696a1c3f4ba7bb5cee265d2a4ed514e Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Sep 2024 11:39:17 -0400 Subject: [PATCH 16/17] Test that @workflow.init may only be used on __init__ --- tests/test_workflow.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index c724ba66..24763caa 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -393,3 +393,19 @@ def test_parameters_identical_up_to_naming(): assert ( workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal) ), f"expected {name1} and {name2} parameters{" " if expect_equal else " not "}to compare equal" + + +@workflow.defn +class BadWorkflowInit: + def not__init__(self): + pass + + @workflow.run + async def run(self): + pass + + +def test_workflow_init_not__init__(): + with pytest.raises(ValueError) as err: + workflow.init(BadWorkflowInit.not__init__) + assert "@workflow.init may only be used on the __init__ method" in str(err.value) From fcf2d7b70082431be9516df4ea6f5a23d9ba0e6f Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Tue, 24 Sep 2024 12:11:27 -0400 Subject: [PATCH 17/17] Lint --- tests/test_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 24763caa..d4a5b45e 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -392,7 +392,7 @@ def test_parameters_identical_up_to_naming(): expect_equal = name1[0] == name2[0] assert ( workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal) - ), f"expected {name1} and {name2} parameters{" " if expect_equal else " not "}to compare equal" + ), f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal" @workflow.defn