diff --git a/README.md b/README.md index aa4202c4..34eed9c3 100644 --- a/README.md +++ b/README.md @@ -575,6 +575,11 @@ Here are the decorators that can be applied: * The method's arguments are the workflow's arguments * The first parameter must be `self`, followed by positional arguments. Best practice is to only take a single argument that is an object/dataclass of fields that can be added to as needed. +* `@workflow.init` - Specifies that the `__init__` method accepts the workflow's arguments. + * If present, may only be applied to the `__init__` method, the parameters of which must then be identical to those of + the `@workflow.run` method. + * The purpose of this decorator is to allow operations involving workflow arguments to be performed in the `__init__` + method, before any signal or update handler has a chance to execute. * `@workflow.signal` - Defines a method as a signal * Can be defined on an `async` or non-`async` function at any hierarchy depth, but if decorated method is overridden, the override must also be decorated diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index f897e02e..1ca70a23 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -201,6 +201,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._payload_converter = det.payload_converter_class() self._failure_converter = det.failure_converter_class() self._defn = det.defn + self._workflow_input: Optional[ExecuteWorkflowInput] = None self._info = det.info self._extern_functions = det.extern_functions self._disable_eager_activity_execution = det.disable_eager_activity_execution @@ -318,8 +319,9 @@ def get_thread_id(self) -> Optional[int]: return self._current_thread_id #### Activation functions #### - # These are in alphabetical order and besides "activate", all other calls - # are "_apply_" + the job field name. 
+ # These are in alphabetical order and besides "activate", and + # "_make_workflow_input", all other calls are "_apply_" + the job field + # name. def activate( self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation @@ -342,6 +344,7 @@ def activate( try: # Split into job sets with patches, then signals + updates, then # non-queries, then queries + start_job = None job_sets: List[ List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob] ] = [[], [], [], []] @@ -351,10 +354,15 @@ def activate( elif job.HasField("signal_workflow") or job.HasField("do_update"): job_sets[1].append(job) elif not job.HasField("query_workflow"): + if job.HasField("start_workflow"): + start_job = job.start_workflow job_sets[2].append(job) else: job_sets[3].append(job) + if start_job: + self._workflow_input = self._make_workflow_input(start_job) + # Apply every job set, running after each set for index, job_set in enumerate(job_sets): if not job_set: @@ -863,34 +871,41 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None: return raise + if not self._workflow_input: + raise RuntimeError( + "Expected workflow input to be set. This is an SDK Python bug." 
+ ) + self._primary_task = self.create_task( + self._run_top_level_workflow_function(run_workflow(self._workflow_input)), + name="run", + ) + + def _apply_update_random_seed( + self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed + ) -> None: + self._random.seed(job.randomness_seed) + + def _make_workflow_input( + self, start_job: temporalio.bridge.proto.workflow_activation.StartWorkflow + ) -> ExecuteWorkflowInput: # Set arg types, using raw values for dynamic arg_types = self._defn.arg_types if not self._defn.name: # Dynamic is just the raw value for each input value - arg_types = [temporalio.common.RawValue] * len(job.arguments) - args = self._convert_payloads(job.arguments, arg_types) + arg_types = [temporalio.common.RawValue] * len(start_job.arguments) + args = self._convert_payloads(start_job.arguments, arg_types) # Put args in a list if dynamic if not self._defn.name: args = [args] - # Schedule it - input = ExecuteWorkflowInput( + return ExecuteWorkflowInput( type=self._defn.cls, # TODO(cretz): Remove cast when https://github.com/python/mypy/issues/5485 fixed run_fn=cast(Callable[..., Awaitable[Any]], self._defn.run_fn), args=args, - headers=job.headers, - ) - self._primary_task = self.create_task( - self._run_top_level_workflow_function(run_workflow(input)), - name="run", + headers=start_job.headers, ) - def _apply_update_random_seed( - self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed - ) -> None: - self._random.seed(job.randomness_seed) - #### _Runtime direct workflow call overrides #### # These are in alphabetical order and all start with "workflow_". @@ -1617,6 +1632,14 @@ def _convert_payloads( except Exception as err: raise RuntimeError("Failed decoding arguments") from err + def _instantiate_workflow_object(self) -> Any: + if not self._workflow_input: + raise RuntimeError("Expected workflow input. 
This is a Python SDK bug.") + if hasattr(self._defn.cls.__init__, "__temporal_workflow_init"): + return self._defn.cls(*self._workflow_input.args) + else: + return self._defn.cls() + def _is_workflow_failure_exception(self, err: BaseException) -> bool: # An exception is a failure instead of a task fail if it's already a # failure error or if it is an instance of any of the failure types in @@ -1752,7 +1775,7 @@ def _run_once(self, *, check_conditions: bool) -> None: # We instantiate the workflow class _inside_ here because __init__ # needs to run with this event loop set if not self._object: - self._object = self._defn.cls() + self._object = self._instantiate_workflow_object() # Run while there is anything ready while self._ready: diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 6ee25ad2..11aa60a4 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -143,10 +143,38 @@ def decorator(cls: ClassType) -> ClassType: return decorator +def init( + init_fn: CallableType, +) -> CallableType: + """Decorator for the workflow init method. + + This may be used on the __init__ method of the workflow class to specify + that it accepts the same workflow input arguments as the ``@workflow.run`` + method. It may not be used on any other method. + + If used, the workflow will be instantiated as + ``MyWorkflow(*workflow_input_args)``. If not used, the workflow will be + instantiated as ``MyWorkflow()``. + + Note that the ``@workflow.run`` method is always called as + ``my_workflow.my_run_method(*workflow_input_args)``. If you use the + ``@workflow.init`` decorator, the parameter list of your __init__ and + ``@workflow.run`` methods must be identical. + + Args: + init_fn: The __init__ function to decorate. 
+ """ + if init_fn.__name__ != "__init__": + raise ValueError("@workflow.init may only be used on the __init__ method") + + setattr(init_fn, "__temporal_workflow_init", True) + return init_fn + + def run(fn: CallableAsyncType) -> CallableAsyncType: """Decorator for the workflow run method. - This must be set on one and only one async method defined on the same class + This must be used on one and only one async method defined on the same class as ``@workflow.defn``. This can be defined on a base class method but must then be explicitly overridden and defined on the workflow class. @@ -238,7 +266,7 @@ def signal( ): """Decorator for a workflow signal method. - This is set on any async or non-async method that you wish to be called upon + This is used on any async or non-async method that you wish to be called upon receiving a signal. If a function overrides one with this decorator, it too must be decorated. @@ -309,7 +337,7 @@ def query( ): """Decorator for a workflow query method. - This is set on any non-async method that expects to handle a query. If a + This is used on any non-async method that expects to handle a query. If a function overrides one with this decorator, it too must be decorated. Query methods can only have positional parameters. Best practice for @@ -983,7 +1011,7 @@ def update( ): """Decorator for a workflow update handler method. - This is set on any async or non-async method that you wish to be called upon + This is used on any async or non-async method that you wish to be called upon receiving an update. If a function overrides one with this decorator, it too must be decorated. 
@@ -1307,13 +1335,13 @@ def _apply_to_class( issues: List[str] = [] # Collect run fn and all signal/query/update fns - members = inspect.getmembers(cls) + init_fn: Optional[Callable[..., None]] = None run_fn: Optional[Callable[..., Awaitable[Any]]] = None seen_run_attr = False signals: Dict[Optional[str], _SignalDefinition] = {} queries: Dict[Optional[str], _QueryDefinition] = {} updates: Dict[Optional[str], _UpdateDefinition] = {} - for name, member in members: + for name, member in inspect.getmembers(cls): if hasattr(member, "__temporal_workflow_run"): seen_run_attr = True if not _is_unbound_method_on_cls(member, cls): @@ -1354,6 +1382,8 @@ def _apply_to_class( ) else: queries[query_defn.name] = query_defn + elif name == "__init__" and hasattr(member, "__temporal_workflow_init"): + init_fn = member elif isinstance(member, UpdateMethodMultiParam): update_defn = member._defn if update_defn.name in updates: @@ -1406,9 +1436,14 @@ def _apply_to_class( if not seen_run_attr: issues.append("Missing @workflow.run method") - if len(issues) == 1: - raise ValueError(f"Invalid workflow class: {issues[0]}") - elif issues: + if init_fn and run_fn: + if not _parameters_identical_up_to_naming(init_fn, run_fn): + issues.append( + "@workflow.init and @workflow.run method parameters do not match" + ) + if issues: + if len(issues) == 1: + raise ValueError(f"Invalid workflow class: {issues[0]}") raise ValueError( f"Invalid workflow class for {len(issues)} reasons: {', '.join(issues)}" ) @@ -1444,6 +1479,19 @@ def __post_init__(self) -> None: object.__setattr__(self, "ret_type", ret_type) +def _parameters_identical_up_to_naming(fn1: Callable, fn2: Callable) -> bool: + """Return True if the functions have identical parameter lists, ignoring parameter names.""" + + def params(fn: Callable) -> List[inspect.Parameter]: + # Ignore name when comparing parameters (remaining fields are kind, + # default, and annotation). 
+ return [p.replace(name="x") for p in inspect.signature(fn).parameters.values()] + + # We require that any type annotations present match exactly; i.e. we do + # not support any notion of subtype compatibility. + return params(fn1) == params(fn2) + + # Async safe version of partial def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]: # Curry instance on the definition function since that represents an diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 510fa18d..d4a5b45e 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +import inspect +import itertools +from typing import Sequence import pytest @@ -342,3 +344,68 @@ def test_workflow_defn_dynamic_handler_warnings(): # We want to make sure they are reporting the right stacklevel warnings[0].filename.endswith("test_workflow.py") warnings[1].filename.endswith("test_workflow.py") + + +class _TestParametersIdenticalUpToNaming: + def a1(self, a): + pass + + def a2(self, b): + pass + + def b1(self, a: int): + pass + + def b2(self, b: int) -> str: + return "" + + def c1(self, a1: int, a2: str) -> str: + return "" + + def c2(self, b1: int, b2: str) -> int: + return 0 + + def d1(self, a1, a2: str) -> None: + pass + + def d2(self, b1, b2: str) -> str: + return "" + + def e1(self, a1, a2: str = "") -> None: + return None + + def e2(self, b1, b2: str = "") -> str: + return "" + + def f1(self, a1, a2: str = "a") -> None: + return None + + +def test_parameters_identical_up_to_naming(): + fns = [ + f + for _, f in inspect.getmembers(_TestParametersIdenticalUpToNaming) + if inspect.isfunction(f) + ] + for f1, f2 in itertools.combinations(fns, 2): + name1, name2 = f1.__name__, f2.__name__ + expect_equal = name1[0] == name2[0] + assert ( + workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal) + ), f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal" + + +@workflow.defn +class 
BadWorkflowInit: + def not__init__(self): + pass + + @workflow.run + async def run(self): + pass + + +def test_workflow_init_not__init__(): + with pytest.raises(ValueError) as err: + workflow.init(BadWorkflowInit.not__init__) + assert "@workflow.init may only be used on the __init__ method" in str(err.value) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 9f007a22..15afe8c4 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3902,7 +3902,7 @@ def matches_metric_line( return False # Must have labels (don't escape for this test) for k, v in at_least_labels.items(): - if not f'{k}="{v}"' in line: + if f'{k}="{v}"' not in line: return False return line.endswith(f" {value}") @@ -4856,7 +4856,7 @@ async def assert_scenario( update_scenario: Optional[FailureTypesScenario] = None, ) -> None: logging.debug( - f"Asserting scenario %s", + "Asserting scenario %s", { "workflow": workflow, "expect_task_fail": expect_task_fail, @@ -6032,3 +6032,117 @@ async def test_activity_retry_delay(client: Client): err.cause.cause.next_retry_delay == ActivitiesWithRetryDelayWorkflow.next_retry_delay ) + + +@workflow.defn +class WorkflowWithoutInit: + value = "from class attribute" + _expected_update_result = "from class attribute" + + @workflow.update + async def my_update(self) -> str: + return self.value + + @workflow.run + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value + + +@workflow.defn +class WorkflowWithWorkflowInit: + _expected_update_result = "workflow input value" + + @workflow.init + def __init__(self, arg: str) -> None: + self.value = arg + + @workflow.update + async def my_update(self) -> str: + return self.value + + @workflow.run + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value + + +@workflow.defn +class WorkflowWithNonWorkflowInitInit: + _expected_update_result = "from parameter default" + + def __init__(self, arg: str = "from 
parameter default") -> None: + self.value = arg + + @workflow.update + async def my_update(self) -> str: + return self.value + + @workflow.run + async def run(self, _: str) -> str: + self.value = "set in run method" + return self.value + + +@pytest.mark.parametrize( + ["client_cls", "worker_cls"], + [ + (WorkflowWithoutInit, WorkflowWithoutInit), + (WorkflowWithNonWorkflowInitInit, WorkflowWithNonWorkflowInitInit), + (WorkflowWithWorkflowInit, WorkflowWithWorkflowInit), + ], +) +async def test_update_in_first_wft_sees_workflow_init( + client: Client, client_cls: Type, worker_cls: Type +): + """ + Test how @workflow.init affects what an update in the first WFT sees. + + Such an update is guaranteed to start executing before the main workflow + coroutine. The update should see the side effects of the __init__ method if + and only if @workflow.init is in effect. + """ + # This test must ensure that the update is in the first WFT. To do so, + # before running the worker, we start the workflow, send the update, and + # wait until the update is admitted. + task_queue = "task-queue" + update_id = "update-id" + wf_handle = await client.start_workflow( + client_cls.run, + "workflow input value", + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + update_task = asyncio.create_task( + wf_handle.execute_update(client_cls.my_update, id=update_id) + ) + await assert_eq_eventually( + True, lambda: workflow_update_exists(client, wf_handle.id, update_id) + ) + # When the worker starts polling it will receive a first WFT containing the + # update, in addition to the start_workflow job. 
+ async with new_worker(client, worker_cls, task_queue=task_queue): + assert await update_task == worker_cls._expected_update_result + assert await wf_handle.result() == "set in run method" + + +@workflow.defn +class WorkflowRunSeesWorkflowInitWorkflow: + @workflow.init + def __init__(self, arg: str) -> None: + self.value = arg + + @workflow.run + async def run(self, _: str): + return f"hello, {self.value}" + + +async def test_workflow_run_sees_workflow_init(client: Client): + async with new_worker(client, WorkflowRunSeesWorkflowInitWorkflow) as worker: + workflow_result = await client.execute_workflow( + WorkflowRunSeesWorkflowInitWorkflow.run, + "world", + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + assert workflow_result == "hello, world"