Skip to content

Commit

Permalink
Handle non-slash form __init__(self, *args, **kwargs)
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Sep 5, 2024
1 parent bfabdbf commit c47012f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
10 changes: 10 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,16 @@ def _get_init_fn_takes_workflow_input(
elif set(init_params) == {"self"}:
# User __init__ with (self) signature
return False, None
else:
# This (in 3.12) differs from the default signature inherited from
# object.__init__, which is (self, /, *args, **kwargs). We allow both as
# valid __init__ signatures (that do not take workflow input).
class StarArgsStarStarKwargs:
def __init__(self, *args, **kwargs) -> None:
pass

if init_params == inspect.signature(StarArgsStarStarKwargs.__init__).parameters:
return False, None

init_param_types, _ = temporalio.common._type_hints_from_func(init_fn)
run_param_types, _ = temporalio.common._type_hints_from_func(run_fn)
Expand Down
30 changes: 12 additions & 18 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,7 @@ async def run(self, arg: MyDataClass) -> None:
pass


class WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun:
# TODO: if they include the slash it will be allowed as it is
# indistinguishable from the default __init__ inherited from object. But if
# they don't it will not be.
class WorkflowInitGoodSlashStarArgsStarStarKwargs:
def __init__(self, /, *args, **kwargs) -> None:
pass

Expand All @@ -405,6 +402,15 @@ async def run(self) -> None:
pass


class WorkflowInitGoodStarArgsStarStarKwargs:
def __init__(self, *args, **kwargs) -> None:
pass

@workflow.run
async def run(self) -> None:
pass


@pytest.mark.parametrize(
"cls",
[
Expand All @@ -413,7 +419,8 @@ async def run(self) -> None:
WorkflowInitGoodNoArgInitZeroParamRun,
WorkflowInitGoodNoArgInitOneParamRun,
DataClassTypedWorkflowProto,
WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun,
WorkflowInitGoodSlashStarArgsStarStarKwargs,
WorkflowInitGoodStarArgsStarStarKwargs,
],
)
def test_workflow_init_good_does_not_take_workflow_input(cls):
Expand Down Expand Up @@ -503,18 +510,6 @@ def test_workflow_init_good_takes_workflow_input(cls):
).__temporal_workflow_definition.init_fn_takes_workflow_input


class WorkflowInitBadStarArgsStarStarKwargs:
# TODO: if they include the slash it will be allowed as it is
# indistinguishable from the default __init__ inherited from object. But if
# they don't it will not be.
def __init__(self, *args, **kwargs) -> None:
pass

@workflow.run
async def run(self) -> None:
pass


class WorkflowInitBadExtraInitParamUntyped:
def __init__(self, a) -> None:
pass
Expand Down Expand Up @@ -554,7 +549,6 @@ async def run(self, aa: int, bb: str) -> None:
@pytest.mark.parametrize(
"cls",
[
WorkflowInitBadStarArgsStarStarKwargs,
WorkflowInitBadExtraInitParamUntyped,
WorkflowInitBadMismatchedParamUntyped,
WorkflowInitBadExtraInitParamTyped,
Expand Down

0 comments on commit c47012f

Please sign in to comment.