Skip to content

Commit

Permalink
Address various review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sushisource committed Oct 23, 2023
1 parent cc0871d commit de98531
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 69 deletions.
99 changes: 41 additions & 58 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ async def terminate(
@overload
async def execute_update(
self,
update: temporalio.workflow.UpdateMethodMultiArg[[SelfType], LocalReturnType],
update: temporalio.workflow.UpdateMethodMultiParam[[SelfType], LocalReturnType],
*,
id: Optional[str] = None,
rpc_metadata: Mapping[str, str] = {},
Expand All @@ -1631,7 +1631,7 @@ async def execute_update(
@overload
async def execute_update(
self,
update: temporalio.workflow.UpdateMethodMultiArg[
update: temporalio.workflow.UpdateMethodMultiParam[
[SelfType, ParamType], LocalReturnType
],
arg: ParamType,
Expand All @@ -1645,7 +1645,7 @@ async def execute_update(
@overload
async def execute_update(
self,
update: temporalio.workflow.UpdateMethodMultiArg[
update: temporalio.workflow.UpdateMethodMultiParam[
MultiParamSpec, LocalReturnType
],
*,
Expand Down Expand Up @@ -1784,14 +1784,9 @@ async def _start_update(
) -> WorkflowUpdateHandle:
update_name: str
ret_type = result_type
if isinstance(update, temporalio.workflow.UpdateMethodMultiArg):
if isinstance(update, temporalio.workflow.UpdateMethodMultiParam):
defn = update._defn
if not defn:
raise RuntimeError(
f"Update definition not found on {update.__qualname__}, "
"is it decorated with @workflow.update?"
)
elif not defn.name:
if not defn.name:
raise RuntimeError("Cannot invoke dynamic update definition")
# TODO(cretz): Check count/type of args at runtime?
update_name = defn.name
Expand All @@ -1801,9 +1796,9 @@ async def _start_update(

return await self._client._impl.start_workflow_update(
UpdateWorkflowInput(
workflow_id=self._id,
id=self._id,
run_id=self._run_id,
update_id=id or "",
update_id=id,
update=update_name,
args=temporalio.common._arg_or_args(arg, args),
headers={},
Expand Down Expand Up @@ -3878,7 +3873,7 @@ def __init__(
name: str,
workflow_id: str,
*,
run_id: Optional[str] = None,
workflow_run_id: Optional[str] = None,
result_type: Optional[Type] = None,
):
"""Create a workflow update handle.
Expand All @@ -3890,29 +3885,29 @@ def __init__(
self._id = id
self._name = name
self._workflow_id = workflow_id
self._run_id = run_id
self._workflow_run_id = workflow_run_id
self._result_type = result_type
self._known_result: Optional[temporalio.api.update.v1.Outcome] = None

@property
def id(self) -> str:
"""ID of this Update request"""
"""ID of this Update request."""
return self._id

@property
def name(self) -> str:
"""The name of the Update being invoked"""
"""The name of the Update being invoked."""
return self._name

@property
def workflow_id(self) -> str:
"""The ID of the Workflow targeted by this Update"""
"""The ID of the Workflow targeted by this Update."""
return self._workflow_id

@property
def run_id(self) -> Optional[str]:
"""If specified, the specific run of the Workflow targeted by this Update"""
return self._run_id
def workflow_run_id(self) -> Optional[str]:
"""If specified, the specific run of the Workflow targeted by this Update."""
return self._workflow_run_id

async def result(
self,
Expand All @@ -3934,7 +3929,6 @@ async def result(
TimeoutError: The specified timeout was reached when waiting for the update result.
RPCError: Update result could not be fetched for some other reason.
"""
outcome: temporalio.api.update.v1.Outcome
if self._known_result is not None:
outcome = self._known_result
return await _update_outcome_to_result(
Expand All @@ -3944,23 +3938,20 @@ async def result(
self._client.data_converter,
self._result_type,
)
else:
return await self._client._impl.poll_workflow_update(
PollUpdateWorkflowInput(
self.workflow_id,
self.run_id,
self.id,
self.name,
timeout,
{},
self._result_type,
rpc_metadata,
rpc_timeout,
)
)

def _set_known_result(self, result: temporalio.api.update.v1.Outcome) -> None:
self._known_result = result
return await self._client._impl.poll_workflow_update(
PollUpdateWorkflowInput(
self.workflow_id,
self.workflow_run_id,
self.id,
self.name,
timeout,
{},
self._result_type,
rpc_metadata,
rpc_timeout,
)
)


class WorkflowFailureError(temporalio.exceptions.TemporalError):
Expand Down Expand Up @@ -4023,11 +4014,9 @@ def message(self) -> str:
class WorkflowUpdateFailedError(temporalio.exceptions.TemporalError):
"""Error that occurs when an update fails."""

def __init__(self, update_id: str, update_name: str, cause: BaseException) -> None:
def __init__(self, cause: BaseException) -> None:
"""Create workflow update failed error."""
super().__init__("Workflow update failed")
self._update_id = update_id
self._update_name = update_name
self.__cause__ = cause

@property
Expand Down Expand Up @@ -4171,9 +4160,9 @@ class TerminateWorkflowInput:
class UpdateWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.start_workflow_update`."""

workflow_id: str
id: str
run_id: Optional[str]
update_id: str
update_id: Optional[str]
update: str
args: Sequence[Any]
wait_for_stage: Optional[
Expand Down Expand Up @@ -4787,12 +4776,12 @@ async def start_workflow_update(
req = temporalio.api.workflowservice.v1.UpdateWorkflowExecutionRequest(
namespace=self._client.namespace,
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=input.workflow_id,
workflow_id=input.id,
run_id=input.run_id or "",
),
request=temporalio.api.update.v1.Request(
meta=temporalio.api.update.v1.Meta(
update_id=input.update_id,
update_id=input.update_id or "",
identity=self._client.identity,
),
input=temporalio.api.update.v1.Input(
Expand All @@ -4814,23 +4803,19 @@ async def start_workflow_update(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
)
except RPCError as err:
# If the status is INVALID_ARGUMENT, we can assume it's an update
# failed error
if err.status == RPCStatusCode.INVALID_ARGUMENT:
raise WorkflowUpdateFailedError(input.workflow_id, input.update, err)
else:
raise
raise

determined_id = resp.update_ref.update_id
update_handle = WorkflowUpdateHandle(
client=self._client,
id=input.update_id,
id=determined_id,
name=input.update,
workflow_id=input.workflow_id,
run_id=input.run_id,
workflow_id=input.id,
workflow_run_id=input.run_id,
result_type=input.ret_type,
)
if resp.HasField("outcome"):
update_handle._set_known_result(resp.outcome)
update_handle._known_result = resp.outcome

return update_handle

Expand Down Expand Up @@ -4869,8 +4854,8 @@ async def poll_loop():
input.ret_type,
)
except RPCError as err:
if err.status == RPCStatusCode.DEADLINE_EXCEEDED:
continue
if err.status != RPCStatusCode.DEADLINE_EXCEEDED:
raise

# Wait for at most the *overall* timeout
return await asyncio.wait_for(
Expand Down Expand Up @@ -5415,8 +5400,6 @@ async def _update_outcome_to_result(
) -> Any:
if outcome.HasField("failure"):
raise WorkflowUpdateFailedError(
id,
name,
await converter.decode_failure(outcome.failure.cause),
)
if not outcome.success.payloads:
Expand Down
2 changes: 1 addition & 1 deletion temporalio/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def start_workflow_update(
) -> temporalio.client.WorkflowUpdateHandle:
with self.root._start_as_current_span(
f"StartWorkflowUpdate:{input.update}",
attributes={"temporalWorkflowID": input.workflow_id},
attributes={"temporalWorkflowID": input.id},
input=input,
kind=opentelemetry.trace.SpanKind.CLIENT,
):
Expand Down
21 changes: 12 additions & 9 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def time_ns() -> int:

# Needs to be defined here to avoid a circular import
@runtime_checkable
class UpdateMethodMultiArg(Protocol[MultiParamSpec, ProtocolReturnType]):
class UpdateMethodMultiParam(Protocol[MultiParamSpec, ProtocolReturnType]):
"""Decorated workflow update functions implement this."""

_defn: temporalio.workflow._UpdateDefinition
Expand All @@ -784,22 +784,24 @@ def __call__(
"""Generic callable type callback."""
...

def validator(self, vfunc: Callable[MultiParamSpec, None]) -> None:
def validator(
self, vfunc: Callable[MultiParamSpec, None]
) -> Callable[MultiParamSpec, None]:
"""Use to decorate a function to validate the arguments passed to the update handler."""
...


@overload
def update(
fn: Callable[MultiParamSpec, Awaitable[ReturnType]]
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
) -> UpdateMethodMultiParam[MultiParamSpec, ReturnType]:
...


@overload
def update(
fn: Callable[MultiParamSpec, ReturnType]
) -> UpdateMethodMultiArg[MultiParamSpec, ReturnType]:
) -> UpdateMethodMultiParam[MultiParamSpec, ReturnType]:
...


Expand All @@ -808,7 +810,7 @@ def update(
*, name: str
) -> Callable[
[Callable[MultiParamSpec, ReturnType]],
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
UpdateMethodMultiParam[MultiParamSpec, ReturnType],
]:
...

Expand All @@ -818,7 +820,7 @@ def update(
*, dynamic: Literal[True]
) -> Callable[
[Callable[MultiParamSpec, ReturnType]],
UpdateMethodMultiArg[MultiParamSpec, ReturnType],
UpdateMethodMultiParam[MultiParamSpec, ReturnType],
]:
...

Expand Down Expand Up @@ -880,10 +882,11 @@ def with_name(

def _update_validator(
update_def: _UpdateDefinition, fn: Optional[Callable[..., None]] = None
):
) -> Optional[Callable[..., None]]:
"""Decorator for a workflow update validator method."""
if fn is not None:
update_def.set_validator(fn)
return fn


def upsert_search_attributes(attributes: temporalio.common.SearchAttributes) -> None:
Expand Down Expand Up @@ -1187,7 +1190,7 @@ def _apply_to_class(
)
else:
queries[query_defn.name] = query_defn
elif isinstance(member, UpdateMethodMultiArg):
elif isinstance(member, UpdateMethodMultiParam):
update_defn = member._defn
if update_defn.name in updates:
defn_name = update_defn.name or "<dynamic>"
Expand Down Expand Up @@ -1230,7 +1233,7 @@ def _apply_to_class(
issues.append(
f"@workflow.query defined on {base_member.__qualname__} but not on the override"
)
elif isinstance(base_member, UpdateMethodMultiArg):
elif isinstance(base_member, UpdateMethodMultiParam):
update_defn = base_member._defn
if update_defn.name not in updates:
issues.append(
Expand Down
8 changes: 7 additions & 1 deletion tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3592,9 +3592,10 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir
async with new_worker(
client, UpdateHandlersWorkflow, activities=[say_hello]
) as worker:
wf_id = f"update-handlers-workflow-{uuid.uuid4()}"
handle = await client.start_workflow(
UpdateHandlersWorkflow.run,
id=f"update-handlers-workflow-{uuid.uuid4()}",
id=wf_id,
task_queue=worker.task_queue,
)

Expand Down Expand Up @@ -3622,6 +3623,11 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir
UpdateHandlersWorkflow.async_named
)

# Get untyped handle
assert "val3" == await client.get_workflow_handle(wf_id).execute_update(
UpdateHandlersWorkflow.last_event, "val4"
)


async def test_workflow_update_handlers_unhappy(
client: Client, env: WorkflowEnvironment
Expand Down

0 comments on commit de98531

Please sign in to comment.