Skip to content

Commit

Permalink
Workflow init (#645)
Browse files Browse the repository at this point in the history
* Introduce @workflow.init decorator
  • Loading branch information
dandavison authored Sep 24, 2024
1 parent 09ac120 commit 0995ae0
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 29 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,11 @@ Here are the decorators that can be applied:
* The method's arguments are the workflow's arguments
* The first parameter must be `self`, followed by positional arguments. Best practice is to only take a single
argument that is an object/dataclass of fields that can be added to as needed.
* `@workflow.init` - Specifies that the `__init__` method accepts the workflow's arguments.
* If present, it may only be applied to the `__init__` method, whose parameters must then be identical to those of
  the `@workflow.run` method.
* The purpose of this decorator is to allow operations involving workflow arguments to be performed in the `__init__`
method, before any signal or update handler has a chance to execute.
* `@workflow.signal` - Defines a method as a signal
* Can be defined on an `async` or non-`async` function at any hierarchy depth, but if decorated method is overridden,
the override must also be decorated
Expand Down
57 changes: 40 additions & 17 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
self._payload_converter = det.payload_converter_class()
self._failure_converter = det.failure_converter_class()
self._defn = det.defn
self._workflow_input: Optional[ExecuteWorkflowInput] = None
self._info = det.info
self._extern_functions = det.extern_functions
self._disable_eager_activity_execution = det.disable_eager_activity_execution
Expand Down Expand Up @@ -318,8 +319,9 @@ def get_thread_id(self) -> Optional[int]:
return self._current_thread_id

#### Activation functions ####
# These are in alphabetical order and besides "activate", all other calls
# are "_apply_" + the job field name.
# These are in alphabetical order and besides "activate", and
# "_make_workflow_input", all other calls are "_apply_" + the job field
# name.

def activate(
self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation
Expand All @@ -342,6 +344,7 @@ def activate(
try:
# Split into job sets with patches, then signals + updates, then
# non-queries, then queries
start_job = None
job_sets: List[
List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob]
] = [[], [], [], []]
Expand All @@ -351,10 +354,15 @@ def activate(
elif job.HasField("signal_workflow") or job.HasField("do_update"):
job_sets[1].append(job)
elif not job.HasField("query_workflow"):
if job.HasField("start_workflow"):
start_job = job.start_workflow
job_sets[2].append(job)
else:
job_sets[3].append(job)

if start_job:
self._workflow_input = self._make_workflow_input(start_job)

# Apply every job set, running after each set
for index, job_set in enumerate(job_sets):
if not job_set:
Expand Down Expand Up @@ -863,34 +871,41 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
return
raise

if not self._workflow_input:
raise RuntimeError(
"Expected workflow input to be set. This is an SDK Python bug."
)
self._primary_task = self.create_task(
self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
name="run",
)

def _apply_update_random_seed(
    self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
) -> None:
    """Handle an UpdateRandomSeed activation job by reseeding the
    workflow's deterministic random number generator."""
    self._random.seed(job.randomness_seed)

def _make_workflow_input(
    self, start_job: temporalio.bridge.proto.workflow_activation.StartWorkflow
) -> ExecuteWorkflowInput:
    """Build the ``ExecuteWorkflowInput`` for this run from the start job.

    Args:
        start_job: The ``start_workflow`` activation job carrying the raw
            argument payloads and headers.

    Returns:
        The workflow input with decoded arguments, ready to be passed to
        the run function (and, via ``@workflow.init``, to ``__init__``).
    """
    # Set arg types, using raw values for dynamic
    arg_types = self._defn.arg_types
    if not self._defn.name:
        # Dynamic is just the raw value for each input value
        arg_types = [temporalio.common.RawValue] * len(start_job.arguments)
    args = self._convert_payloads(start_job.arguments, arg_types)
    # Put args in a list if dynamic
    if not self._defn.name:
        args = [args]

    return ExecuteWorkflowInput(
        type=self._defn.cls,
        # TODO(cretz): Remove cast when https://github.com/python/mypy/issues/5485 fixed
        run_fn=cast(Callable[..., Awaitable[Any]], self._defn.run_fn),
        args=args,
        headers=start_job.headers,
    )

def _apply_update_random_seed(
self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
) -> None:
self._random.seed(job.randomness_seed)

#### _Runtime direct workflow call overrides ####
# These are in alphabetical order and all start with "workflow_".

Expand Down Expand Up @@ -1617,6 +1632,14 @@ def _convert_payloads(
except Exception as err:
raise RuntimeError("Failed decoding arguments") from err

def _instantiate_workflow_object(self) -> Any:
    """Construct the workflow class instance.

    Workflow arguments are passed to ``__init__`` only when it was
    decorated with ``@workflow.init`` (detected via the marker attribute).
    """
    if not self._workflow_input:
        raise RuntimeError("Expected workflow input. This is a Python SDK bug.")
    takes_workflow_args = hasattr(
        self._defn.cls.__init__, "__temporal_workflow_init"
    )
    ctor_args = self._workflow_input.args if takes_workflow_args else ()
    return self._defn.cls(*ctor_args)

def _is_workflow_failure_exception(self, err: BaseException) -> bool:
# An exception is a failure instead of a task fail if it's already a
# failure error or if it is an instance of any of the failure types in
Expand Down Expand Up @@ -1752,7 +1775,7 @@ def _run_once(self, *, check_conditions: bool) -> None:
# We instantiate the workflow class _inside_ here because __init__
# needs to run with this event loop set
if not self._object:
self._object = self._defn.cls()
self._object = self._instantiate_workflow_object()

# Run while there is anything ready
while self._ready:
Expand Down
66 changes: 57 additions & 9 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,38 @@ def decorator(cls: ClassType) -> ClassType:
return decorator


def init(
    init_fn: CallableType,
) -> CallableType:
    """Decorator for the workflow init method.

    This may be used on the __init__ method of the workflow class to specify
    that it accepts the same workflow input arguments as the ``@workflow.run``
    method. It may not be used on any other method.

    If used, the workflow will be instantiated as
    ``MyWorkflow(*workflow_input_args)``. If not used, the workflow will be
    instantiated as ``MyWorkflow()``.

    Note that the ``@workflow.run`` method is always called as
    ``my_workflow.my_run_method(*workflow_input_args)``. If you use the
    ``@workflow.init`` decorator, the parameter list of your __init__ and
    ``@workflow.run`` methods must be identical.

    Args:
        init_fn: The __init__ method to decorate.

    Returns:
        The same method, marked so that the SDK passes the workflow's input
        arguments to it at instantiation time.

    Raises:
        ValueError: If the decorated method is not named ``__init__``.
    """
    if init_fn.__name__ != "__init__":
        raise ValueError("@workflow.init may only be used on the __init__ method")

    # Marker attribute checked at workflow instantiation time to decide
    # whether workflow args are passed to __init__.
    setattr(init_fn, "__temporal_workflow_init", True)
    return init_fn


def run(fn: CallableAsyncType) -> CallableAsyncType:
"""Decorator for the workflow run method.
This must be set on one and only one async method defined on the same class
This must be used on one and only one async method defined on the same class
as ``@workflow.defn``. This can be defined on a base class method but must
then be explicitly overridden and defined on the workflow class.
Expand Down Expand Up @@ -238,7 +266,7 @@ def signal(
):
"""Decorator for a workflow signal method.
This is set on any async or non-async method that you wish to be called upon
This is used on any async or non-async method that you wish to be called upon
receiving a signal. If a function overrides one with this decorator, it too
must be decorated.
Expand Down Expand Up @@ -309,7 +337,7 @@ def query(
):
"""Decorator for a workflow query method.
This is set on any non-async method that expects to handle a query. If a
This is used on any non-async method that expects to handle a query. If a
function overrides one with this decorator, it too must be decorated.
Query methods can only have positional parameters. Best practice for
Expand Down Expand Up @@ -983,7 +1011,7 @@ def update(
):
"""Decorator for a workflow update handler method.
This is set on any async or non-async method that you wish to be called upon
This is used on any async or non-async method that you wish to be called upon
receiving an update. If a function overrides one with this decorator, it too
must be decorated.
Expand Down Expand Up @@ -1307,13 +1335,13 @@ def _apply_to_class(
issues: List[str] = []

# Collect run fn and all signal/query/update fns
members = inspect.getmembers(cls)
init_fn: Optional[Callable[..., None]] = None
run_fn: Optional[Callable[..., Awaitable[Any]]] = None
seen_run_attr = False
signals: Dict[Optional[str], _SignalDefinition] = {}
queries: Dict[Optional[str], _QueryDefinition] = {}
updates: Dict[Optional[str], _UpdateDefinition] = {}
for name, member in members:
for name, member in inspect.getmembers(cls):
if hasattr(member, "__temporal_workflow_run"):
seen_run_attr = True
if not _is_unbound_method_on_cls(member, cls):
Expand Down Expand Up @@ -1354,6 +1382,8 @@ def _apply_to_class(
)
else:
queries[query_defn.name] = query_defn
elif name == "__init__" and hasattr(member, "__temporal_workflow_init"):
init_fn = member
elif isinstance(member, UpdateMethodMultiParam):
update_defn = member._defn
if update_defn.name in updates:
Expand Down Expand Up @@ -1406,9 +1436,14 @@ def _apply_to_class(

if not seen_run_attr:
issues.append("Missing @workflow.run method")
if len(issues) == 1:
raise ValueError(f"Invalid workflow class: {issues[0]}")
elif issues:
if init_fn and run_fn:
if not _parameters_identical_up_to_naming(init_fn, run_fn):
issues.append(
"@workflow.init and @workflow.run method parameters do not match"
)
if issues:
if len(issues) == 1:
raise ValueError(f"Invalid workflow class: {issues[0]}")
raise ValueError(
f"Invalid workflow class for {len(issues)} reasons: {', '.join(issues)}"
)
Expand Down Expand Up @@ -1444,6 +1479,19 @@ def __post_init__(self) -> None:
object.__setattr__(self, "ret_type", ret_type)


def _parameters_identical_up_to_naming(fn1: Callable, fn2: Callable) -> bool:
"""Return True if the functions have identical parameter lists, ignoring parameter names."""

def params(fn: Callable) -> List[inspect.Parameter]:
# Ignore name when comparing parameters (remaining fields are kind,
# default, and annotation).
return [p.replace(name="x") for p in inspect.signature(fn).parameters.values()]

# We require that any type annotations present match exactly; i.e. we do
# not support any notion of subtype compatibility.
return params(fn1) == params(fn2)


# Async safe version of partial
def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]:
# Curry instance on the definition function since that represents an
Expand Down
69 changes: 68 additions & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Sequence
import inspect
import itertools
from typing import Sequence

import pytest

Expand Down Expand Up @@ -342,3 +344,68 @@ def test_workflow_defn_dynamic_handler_warnings():
# We want to make sure they are reporting the right stacklevel
warnings[0].filename.endswith("test_workflow.py")
warnings[1].filename.endswith("test_workflow.py")


class _TestParametersIdenticalUpToNaming:
    """Fixture for test_parameters_identical_up_to_naming.

    Two methods are expected to have parameter lists identical up to naming
    if and only if their names share the same first letter (a1/a2, b1/b2,
    ...). Only parameters are compared, so differing return annotations
    (e.g. b1 vs b2) do not matter. f1 differs from the e-group only by its
    default value and has no same-letter partner.
    """

    def a1(self, a):
        pass

    def a2(self, b):
        pass

    def b1(self, a: int):
        pass

    def b2(self, b: int) -> str:
        return ""

    def c1(self, a1: int, a2: str) -> str:
        return ""

    def c2(self, b1: int, b2: str) -> int:
        return 0

    def d1(self, a1, a2: str) -> None:
        pass

    def d2(self, b1, b2: str) -> str:
        return ""

    def e1(self, a1, a2: str = "") -> None:
        return None

    def e2(self, b1, b2: str = "") -> str:
        return ""

    def f1(self, a1, a2: str = "a") -> None:
        return None


def test_parameters_identical_up_to_naming():
    """Fixture methods compare equal iff their names share a first letter."""
    methods = [
        member
        for _, member in inspect.getmembers(_TestParametersIdenticalUpToNaming)
        if inspect.isfunction(member)
    ]
    for fn_a, fn_b in itertools.combinations(methods, 2):
        name_a, name_b = fn_a.__name__, fn_b.__name__
        expect_equal = name_a[0] == name_b[0]
        actual = workflow._parameters_identical_up_to_naming(fn_a, fn_b)
        assert actual == expect_equal, (
            f"expected {name_a} and {name_b} parameters"
            f"{' ' if expect_equal else ' not '}to compare equal"
        )


@workflow.defn
class BadWorkflowInit:
    """Workflow fixture whose misnamed method is used to verify that
    @workflow.init rejects any method not named __init__."""

    def not__init__(self):
        pass

    @workflow.run
    async def run(self):
        pass


def test_workflow_init_not__init__():
    """workflow.init must raise when applied to a method not named __init__."""
    with pytest.raises(ValueError) as excinfo:
        workflow.init(BadWorkflowInit.not__init__)
    message = str(excinfo.value)
    assert "@workflow.init may only be used on the __init__ method" in message
Loading

0 comments on commit 0995ae0

Please sign in to comment.