Skip to content

Commit

Permalink
Simplify and fix implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Sep 5, 2024
1 parent c47012f commit eb5b02c
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 107 deletions.
138 changes: 68 additions & 70 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
TypeVar,
Union,
cast,
get_origin,
overload,
)

Expand Down Expand Up @@ -1415,10 +1414,10 @@ def _apply_to_class(
if not init_fn:
issues.append("Missing __init__ method")
elif run_fn:
init_fn_takes_workflow_input, init_fn_issue = (
_get_init_fn_takes_workflow_input(init_fn, run_fn)
init_fn_takes_workflow_input, init_fn_issue = _init_fn_takes_workflow_input(
init_fn, run_fn
)
if init_fn_issue:
if init_fn_issue is not None:
issues.append(init_fn_issue)
if issues:
if len(issues) == 1:
Expand All @@ -1428,6 +1427,7 @@ def _apply_to_class(
)

assert run_fn
assert init_fn_takes_workflow_input is not None
defn = _Definition(
name=workflow_name,
cls=cls,
Expand Down Expand Up @@ -1459,79 +1459,77 @@ def __post_init__(self) -> None:
object.__setattr__(self, "ret_type", ret_type)


def _get_init_fn_takes_workflow_input(
def _init_fn_takes_workflow_input(
init_fn: Callable[..., None],
run_fn: Callable[..., Awaitable[Any]],
) -> Tuple[bool, Optional[str]]:
) -> Union[Tuple[bool, None], Tuple[None, str]]:
"""
Return (True, None) if the Workflow __init__ method accepts Workflow input.
If __init__ has a signature that features user-defined parameters, and if
these match those of the @workflow.run method, and if no validity issue is
encountered, then return (True, None). Otherwise return False, along with
any validity issue.
Return (True, None) if Workflow input args should be passed to Workflow __init__.
"""
init_params = inspect.signature(init_fn).parameters
if init_params == inspect.signature(object.__init__).parameters:
# No user __init__, or user __init__ with (self, *args, **kwargs) signature
return False, None
elif set(init_params) == {"self"}:
# User __init__ with (self) signature
# If the workflow class can be instantiated as cls(), i.e. without passing
# workflow input args, then we do that. Otherwise, if the parameters of
# __init__ exactly match those of the @workflow.run method, then we pass the
# workflow input args to __init__ when instantiating the workflow class.
# Otherwise, the workflow definition is invalid.

if _unbound_method_can_be_called_without_args_when_bound(init_fn):
# The workflow cls can be instantiated as cls()
return False, None
else:
# This (in 3.12) differs from the default signature inherited from
# object.__init__, which is (self, /, *args, **kwargs). We allow both as
# valid __init__ signatures (that do not take workflow input).
class StarArgsStarStarKwargs:
def __init__(self, *args, **kwargs) -> None:
pass

if init_params == inspect.signature(StarArgsStarStarKwargs.__init__).parameters:
return False, None

init_param_types, _ = temporalio.common._type_hints_from_func(init_fn)
run_param_types, _ = temporalio.common._type_hints_from_func(run_fn)
if init_param_types and run_param_types:
return _get_init_fn_takes_workflow_input_from_type_annotations(
init_param_types, run_param_types
)

run_params = inspect.signature(run_fn).parameters
if len(init_params) != len(run_params):
return (
False,
(
f"Number of __init__ method parameters ({len(init_params)}) must equal "
f"number of @workflow.run method parameters ({len(run_params)})"
),
)
# TODO: positionals / kwargs
return True, None


def _get_init_fn_takes_workflow_input_from_type_annotations(
init_param_types: List[Type],
run_param_types: List[Type],
) -> Tuple[bool, Optional[str]]:
if len(init_param_types) != len(run_param_types):
return (
False,
(
f"Number of __init__ method parameters ({len(init_param_types)}) must equal "
f"number of @workflow.run method parameters ({len(run_param_types)})"
),
)
else:
for t1, t2 in zip(init_param_types, run_param_types):
# Just check that the types are the same in a naive sense; do not
# support any notion of subtype compatibility for now.
if get_origin(t1) != get_origin(t2):
return (
False,
f"__init__ method param type {t1} does not match corresponding @workflow.run method param type {t2}",
)
# TODO: positionals / kwargs
return True, None
def get_params(fn: Callable) -> List[inspect.Parameter]:
# Ignore name when comparing parameters (remaining fields are kind,
# default, and annotation).
return [
p.replace(name="x") for p in inspect.signature(fn).parameters.values()
]

# We require that any type annotations present match exactly; i.e. we do
# not support any notion of subtype compatibility.
if get_params(init_fn) == get_params(run_fn):
# __init__ requires some args and has the same parameters as @workflow.run
return True, None
else:
return (
None,
"__init__ parameters do not match @workflow.run method parameters",
)


def _unbound_method_can_be_called_without_args_when_bound(
fn: Callable[..., Any],
) -> bool:
"""
Return True if the unbound method fn can be called without arguments when bound.
"""
# An unbound method can be called without arguments when bound if the
# following are both true:
#
# - The first parameter is POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD (this is
# the parameter conventionally named 'self')
#
# - All other POSITIONAL_OR_KEYWORD or KEYWORD_ONLY parameters have default
# values.
params = iter(inspect.signature(fn).parameters.values())
self_param = next(params, None)
if not self_param or self_param.kind not in [
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
]:
raise ValueError("Not an unbound method. This is a Python SDK bug.")
for p in params:
if p.kind == inspect.Parameter.POSITIONAL_ONLY:
return False
elif (
p.kind
in [
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
]
and p.default is inspect.Parameter.empty
):
return False
return True


# Async safe version of partial
Expand Down
Loading

0 comments on commit eb5b02c

Please sign in to comment.