diff --git a/temporalio/workflow.py b/temporalio/workflow.py index a650ebd1..ed96324a 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1478,6 +1478,16 @@ def _get_init_fn_takes_workflow_input( elif set(init_params) == {"self"}: # User __init__ with (self) signature return False, None + else: + # This (in 3.12) differs from the default signature inherited from + # object.__init__, which is (self, /, *args, **kwargs). We allow both as + # valid __init__ signatures (that do not take workflow input). + class StarArgsStarStarKwargs: + def __init__(self, *args, **kwargs) -> None: + pass + + if init_params == inspect.signature(StarArgsStarStarKwargs.__init__).parameters: + return False, None init_param_types, _ = temporalio.common._type_hints_from_func(init_fn) run_param_types, _ = temporalio.common._type_hints_from_func(run_fn) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 2082d7df..9d32e6f8 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -393,10 +393,7 @@ async def run(self, arg: MyDataClass) -> None: pass -class WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun: - # TODO: if they include the slash it will be allowed as it is - # indistinguishable from the default __init__ inherited from object. But if - # they don't it will not be. +class WorkflowInitGoodSlashStarArgsStarStarKwargs: def __init__(self, /, *args, **kwargs) -> None: pass @@ -405,6 +402,15 @@ async def run(self) -> None: pass +class WorkflowInitGoodStarArgsStarStarKwargs: + def __init__(self, *args, **kwargs) -> None: + pass + + @workflow.run + async def run(self) -> None: + pass + + @pytest.mark.parametrize( "cls", [ @@ -413,7 +419,8 @@ async def run(self) -> None: WorkflowInitGoodNoArgInitZeroParamRun, WorkflowInitGoodNoArgInitOneParamRun, DataClassTypedWorkflowProto, - WorkflowInitGoodSlashStarArgsStarStarKwargsInitZeroParamRun, + WorkflowInitGoodSlashStarArgsStarStarKwargs, + WorkflowInitGoodStarArgsStarStarKwargs, ], ) def test_workflow_init_good_does_not_take_workflow_input(cls): @@ -503,18 +510,6 @@ def test_workflow_init_good_takes_workflow_input(cls): ).__temporal_workflow_definition.init_fn_takes_workflow_input -class WorkflowInitBadStarArgsStarStarKwargs: - # TODO: if they include the slash it will be allowed as it is - # indistinguishable from the default __init__ inherited from object. But if - # they don't it will not be. - def __init__(self, *args, **kwargs) -> None: - pass - - @workflow.run - async def run(self) -> None: - pass - - class WorkflowInitBadExtraInitParamUntyped: def __init__(self, a) -> None: pass @@ -554,7 +549,6 @@ async def run(self, aa: int, bb: str) -> None: @pytest.mark.parametrize( "cls", [ - WorkflowInitBadStarArgsStarStarKwargs, WorkflowInitBadExtraInitParamUntyped, WorkflowInitBadMismatchedParamUntyped, WorkflowInitBadExtraInitParamTyped,