diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index aade5e29..fa37049f 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -44,6 +44,7 @@ pub struct WorkerConfig { max_activities_per_second: Option, max_task_queue_activities_per_second: Option, graceful_shutdown_period_millis: u64, + use_worker_versioning: bool, } macro_rules! enter_sync { @@ -232,6 +233,7 @@ impl TryFrom for temporal_sdk_core::WorkerConfig { // auto-cancel-activity behavior of shutdown will not occur, so we // always set it even if 0. .graceful_shutdown_period(Duration::from_millis(conf.graceful_shutdown_period_millis)) + .use_worker_versioning(conf.use_worker_versioning) .build() .map_err(|err| PyValueError::new_err(format!("Invalid worker config: {}", err))) } diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index f01745db..73935aa3 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -47,6 +47,7 @@ class WorkerConfig: max_activities_per_second: Optional[float] max_task_queue_activities_per_second: Optional[float] graceful_shutdown_period_millis: int + use_worker_versioning: bool class Worker: diff --git a/temporalio/client.py b/temporalio/client.py index 2f1060e5..81ab4ed3 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -12,13 +12,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from enum import IntEnum +from enum import Enum, IntEnum from typing import ( Any, AsyncIterator, Awaitable, Callable, Dict, + FrozenSet, Generic, Iterable, Mapping, @@ -831,6 +832,108 @@ async def list_schedules( ) ) + async def update_worker_build_id_compatibility( + self, + task_queue: str, + operation: BuildIdOp, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> None: + """Used to add new Build IDs or otherwise update the relative compatibility of Build Ids as + defined on a specific task queue for the Worker Versioning feature. + + For more on this feature, see https://docs.temporal.io/workers#worker-versioning + + .. warning:: + This API is experimental + + Args: + task_queue: The task queue to target. + operation: The operation to perform. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + """ + return await self._impl.update_worker_build_id_compatibility( + UpdateWorkerBuildIdCompatibilityInput( + task_queue, + operation, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + async def get_worker_build_id_compatibility( + self, + task_queue: str, + max_sets: Optional[int] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkerBuildIdVersionSets: + """Get the Build ID compatibility sets for a specific task queue. + + For more on this feature, see https://docs.temporal.io/workers#worker-versioning + + .. warning:: + This API is experimental + + Args: + task_queue: The task queue to target. + max_sets: The maximum number of sets to return. If not specified, all sets will be + returned. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + """ + return await self._impl.get_worker_build_id_compatibility( + GetWorkerBuildIdCompatibilityInput( + task_queue, + max_sets, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + + async def get_worker_task_reachability( + self, + build_ids: Sequence[str], + task_queues: Sequence[str] = [], + reachability_type: Optional[TaskReachabilityType] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkerTaskReachability: + """Determine if some Build IDs for certain Task Queues could have tasks dispatched to them. + + For more on this feature, see https://docs.temporal.io/workers#worker-versioning + + .. warning:: + This API is experimental + + Args: + build_ids: The Build IDs to query the reachability of. At least one must be specified. + task_queues: Task Queues to restrict the query to. If not specified, all Task Queues + will be searched. When requesting a large number of task queues or all task queues + associated with the given Build IDs in a namespace, all Task Queues will be listed + in the response but some of them may not contain reachability information due to a + server enforced limit. When reaching the limit, task queues that reachability + information could not be retrieved for will be marked with a `NotFetched` entry in + {@link BuildIdReachability.taskQueueReachability}. The caller may issue another call + to get the reachability for those task queues. + reachability_type: The kind of reachability this request is concerned with. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + """ + return await self._impl.get_worker_task_reachability( + GetWorkerTaskReachabilityInput( + build_ids, + task_queues, + reachability_type, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + ) + ) + class ClientConfig(TypedDict, total=False): """TypedDict of config originally passed to :py:meth:`Client`.""" @@ -3866,6 +3969,37 @@ class UpdateScheduleInput: rpc_timeout: Optional[timedelta] +@dataclass +class UpdateWorkerBuildIdCompatibilityInput: + """Input for :py:meth:`OutboundInterceptor.update_worker_build_id_compatibility`.""" + + task_queue: str + operation: BuildIdOp + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + +@dataclass +class GetWorkerBuildIdCompatibilityInput: + """Input for :py:meth:`OutboundInterceptor.get_worker_build_id_compatibility`.""" + + task_queue: str + max_sets: Optional[int] + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + +@dataclass +class GetWorkerTaskReachabilityInput: + """Input for :py:meth:`OutboundInterceptor.get_worker_build_id_reachability`.""" + + build_ids: Sequence[str] + task_queues: Sequence[str] + reachability: Optional[TaskReachabilityType] + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + @dataclass class Interceptor: """Interceptor for clients. @@ -4005,6 +4139,24 @@ async def update_schedule(self, input: UpdateScheduleInput) -> None: """Called for every :py:meth:`ScheduleHandle.update` call.""" await self.next.update_schedule(input) + async def update_worker_build_id_compatibility( + self, input: UpdateWorkerBuildIdCompatibilityInput + ) -> None: + """Called for every :py:meth:`Client.update_worker_build_id_compatibility` call.""" + await self.next.update_worker_build_id_compatibility(input) + + async def get_worker_build_id_compatibility( + self, input: GetWorkerBuildIdCompatibilityInput + ) -> WorkerBuildIdVersionSets: + """Called for every :py:meth:`Client.get_worker_build_id_compatibility` call.""" + return await self.next.get_worker_build_id_compatibility(input) + + async def get_worker_task_reachability( + self, input: GetWorkerTaskReachabilityInput + ) -> WorkerTaskReachability: + """Called for every :py:meth:`Client.get_worker_build_id_reachability` call.""" + return await self.next.get_worker_task_reachability(input) + class _ClientImpl(OutboundInterceptor): def __init__(self, client: Client) -> None: @@ -4612,6 +4764,42 @@ async def update_schedule(self, input: UpdateScheduleInput) -> None: timeout=input.rpc_timeout, ) + async def update_worker_build_id_compatibility( + self, input: UpdateWorkerBuildIdCompatibilityInput + ) -> None: + req = input.operation._as_partial_proto() + req.namespace = self._client.namespace + req.task_queue = input.task_queue + await self._client.workflow_service.update_worker_build_id_compatibility( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + ) + + async def get_worker_build_id_compatibility( + self, input: GetWorkerBuildIdCompatibilityInput + ) -> WorkerBuildIdVersionSets: + req = temporalio.api.workflowservice.v1.GetWorkerBuildIdCompatibilityRequest( + namespace=self._client.namespace, + task_queue=input.task_queue, + max_sets=input.max_sets or 0, + ) + resp = await self._client.workflow_service.get_worker_build_id_compatibility( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + ) + return WorkerBuildIdVersionSets._from_proto(resp) + + async def get_worker_task_reachability( + self, input: GetWorkerTaskReachabilityInput + ) -> WorkerTaskReachability: + req = temporalio.api.workflowservice.v1.GetWorkerTaskReachabilityRequest( + namespace=self._client.namespace, + build_ids=input.build_ids, + task_queues=input.task_queues, + ) + resp = await self._client.workflow_service.get_worker_task_reachability( + req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout + ) + return WorkerTaskReachability._from_proto(resp) + def _history_from_json( history: Union[str, Dict[str, Any]] @@ -4734,3 +4922,251 @@ def _fix_history_enum(prefix: str, parent: Dict[str, Any], *attrs: str) -> None: for child_item in child: if isinstance(child_item, dict): _fix_history_enum(prefix, child_item, *attrs[1:]) + + +@dataclass(frozen=True) +class WorkerBuildIdVersionSets: + """Represents the sets of compatible Build ID versions associated with some Task Queue, as + fetched by :py:meth:`Client.get_worker_build_id_compatibility`. + """ + + version_sets: Sequence[BuildIdVersionSet] + """All version sets that were fetched for this task queue.""" + + def default_set(self) -> BuildIdVersionSet: + """Returns the default version set for this task queue.""" + return self.version_sets[-1] + + def default_build_id(self) -> str: + """Returns the default Build ID for this task queue.""" + return self.default_set().default() + + @staticmethod + def _from_proto( + resp: temporalio.api.workflowservice.v1.GetWorkerBuildIdCompatibilityResponse, + ) -> WorkerBuildIdVersionSets: + return WorkerBuildIdVersionSets( + version_sets=[ + BuildIdVersionSet(mvs.build_ids) for mvs in resp.major_version_sets + ] + ) + + +@dataclass(frozen=True) +class BuildIdVersionSet: + """A set of Build IDs which are compatible with each other.""" + + build_ids: Sequence[str] + """All Build IDs contained in the set.""" + + def default(self) -> str: + """Returns the default Build ID for this set.""" + return self.build_ids[-1] + + +class BuildIdOp(ABC): + """Base class for Build ID operations as used by + :py:meth:`Client.update_worker_build_id_compatibility`. + """ + + @abstractmethod + def _as_partial_proto( + self, + ) -> temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest: + """Returns a partial request with the operation populated. Caller must populate + non-operation fields. This is done b/c there's no good way to assign a non-primitive message + as the operation after initializing the request. + """ + ... + + +@dataclass(frozen=True) +class BuildIdOpAddNewDefault(BuildIdOp): + """Adds a new Build Id into a new set, which will be used as the default set for + the queue. This means all new workflows will start on this Build Id. + """ + + build_id: str + + def _as_partial_proto( + self, + ) -> temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest: + return ( + temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest( + add_new_build_id_in_new_default_set=self.build_id + ) + ) + + +@dataclass(frozen=True) +class BuildIdOpAddNewCompatible(BuildIdOp): + """Adds a new Build Id into an existing compatible set. The newly added ID becomes + the default for that compatible set, and thus new workflow tasks for workflows which have been + executing on workers in that set will now start on this new Build Id. + """ + + build_id: str + """The Build Id to add to the compatible set.""" + + existing_compatible_build_id: str + """A Build Id which must already be defined on the task queue, and is used to find the + compatible set to add the new id to. + """ + + promote_set: bool = False + """If set to true, the targeted set will also be promoted to become the overall default set for + the queue.""" + + def _as_partial_proto( + self, + ) -> temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest: + return temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest( + add_new_compatible_build_id=temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest.AddNewCompatibleVersion( + new_build_id=self.build_id, + existing_compatible_build_id=self.existing_compatible_build_id, + make_set_default=self.promote_set, + ) + ) + + +@dataclass(frozen=True) +class BuildIdOpPromoteSetByBuildId(BuildIdOp): + """Promotes a set of compatible Build Ids to become the current default set for the task queue. + Any Build Id in the set may be used to target it. + """ + + build_id: str + """A Build Id which must already be defined on the task queue, and is used to find the + compatible set to promote.""" + + def _as_partial_proto( + self, + ) -> temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest: + return ( + temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest( + promote_set_by_build_id=self.build_id + ) + ) + + +@dataclass(frozen=True) +class BuildIdOpPromoteBuildIdWithinSet(BuildIdOp): + """Promotes a Build Id within an existing set to become the default ID for that set.""" + + build_id: str + + def _as_partial_proto( + self, + ) -> temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest: + return ( + temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest( + promote_build_id_within_set=self.build_id + ) + ) + + +@dataclass(frozen=True) +class BuildIdOpMergeSets(BuildIdOp): + """Merges two sets into one set, thus declaring all the Build Ids in both as compatible with one + another. The default of the primary set is maintained as the merged set's overall default. + """ + + primary_build_id: str + """A Build Id which and is used to find the primary set to be merged.""" + + secondary_build_id: str + """A Build Id which and is used to find the secondary set to be merged.""" + + def _as_partial_proto( + self, + ) -> temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest: + return temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest( + merge_sets=temporalio.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest.MergeSets( + primary_set_build_id=self.primary_build_id, + secondary_set_build_id=self.secondary_build_id, + ) + ) + + +@dataclass(frozen=True) +class WorkerTaskReachability: + """Contains information about the reachability of some Build IDs""" + + build_id_reachability: Mapping[str, BuildIdReachability] + """Maps Build IDs to information about their reachability""" + + @staticmethod + def _from_proto( + resp: temporalio.api.workflowservice.v1.GetWorkerTaskReachabilityResponse, + ) -> WorkerTaskReachability: + mapping = dict() + for bid_reach in resp.build_id_reachability: + tq_mapping = dict() + unretrieved = set() + for tq_reach in bid_reach.task_queue_reachability: + if tq_reach.reachability == [ + temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_UNSPECIFIED + ]: + unretrieved.add(tq_reach.task_queue) + continue + tq_mapping[tq_reach.task_queue] = [ + TaskReachabilityType._from_proto(r) for r in tq_reach.reachability + ] + + mapping[bid_reach.build_id] = BuildIdReachability( + task_queue_reachability=tq_mapping, + unretrieved_task_queues=frozenset(unretrieved), + ) + + return WorkerTaskReachability(build_id_reachability=mapping) + + +@dataclass(frozen=True) +class BuildIdReachability: + """Contains information about the reachability of a specific Build ID""" + + task_queue_reachability: Mapping[str, Sequence[TaskReachabilityType]] + """Maps Task Queue names to the reachability status of the Build ID on that queue. If the value + is an empty list, the Build ID is not reachable on that queue. + """ + + unretrieved_task_queues: FrozenSet[str] + """If any Task Queues could not be retrieved because the server limits the number that can be + queried at once, they will be listed here. + """ + + +class TaskReachabilityType(Enum): + """Enumerates how a task might reach certain kinds of workflows""" + + NEW_WORKFLOWS = 1 + EXISTING_WORKFLOWS = 2 + OPEN_WORKFLOWS = 3 + CLOSED_WORKFLOWS = 4 + + @staticmethod + def _from_proto( + reachability: temporalio.api.enums.v1.TaskReachability.ValueType, + ) -> TaskReachabilityType: + if ( + reachability + == temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_NEW_WORKFLOWS + ): + return TaskReachabilityType.NEW_WORKFLOWS + elif ( + reachability + == temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_EXISTING_WORKFLOWS + ): + return TaskReachabilityType.EXISTING_WORKFLOWS + elif ( + reachability + == temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_OPEN_WORKFLOWS + ): + return TaskReachabilityType.OPEN_WORKFLOWS + elif ( + reachability + == temporalio.api.enums.v1.TaskReachability.TASK_REACHABILITY_CLOSED_WORKFLOWS + ): + return TaskReachabilityType.CLOSED_WORKFLOWS + else: + raise ValueError(f"Cannot convert reachability type: {reachability}") diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index a6531e95..44350e66 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -176,6 +176,7 @@ async def workflow_replay_iterator( max_activities_per_second=None, max_task_queue_activities_per_second=None, graceful_shutdown_period_millis=0, + use_worker_versioning=False, ), ) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 47234ece..15af12ce 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -75,6 +75,7 @@ def __init__( debug_mode: bool = False, disable_eager_activity_execution: bool = False, on_fatal_error: Optional[Callable[[BaseException], Awaitable[None]]] = None, + use_worker_versioning: bool = False, ) -> None: """Create a worker to process workflows and/or activities. @@ -177,9 +178,18 @@ def __init__( on_fatal_error: An async function that can handle a failure before the worker shutdown commences. This cannot stop the shutdown and any exception raised is logged and ignored. + use_worker_versioning: If true, the `build_id` argument must be specified, and this + worker opts into the worker versioning feature. This ensures it only receives + workflow tasks for workflows which it claims to be compatible with. + + For more information, see https://docs.temporal.io/workers#worker-versioning """ if not activities and not workflows: raise ValueError("At least one activity or workflow must be specified") + if use_worker_versioning and not build_id: + raise ValueError( + "build_id must be specified when use_worker_versioning is True" + ) # Prepend applicable client interceptors to the given ones client_config = client.config() @@ -238,6 +248,7 @@ def __init__( debug_mode=debug_mode, disable_eager_activity_execution=disable_eager_activity_execution, on_fatal_error=on_fatal_error, + use_worker_versioning=use_worker_versioning, ) self._started = False self._shutdown_event = asyncio.Event() @@ -324,6 +335,7 @@ def __init__( graceful_shutdown_period_millis=int( 1000 * graceful_shutdown_timeout.total_seconds() ), + use_worker_versioning=use_worker_versioning, ), ) @@ -558,6 +570,7 @@ class WorkerConfig(TypedDict, total=False): debug_mode: bool disable_eager_activity_execution: bool on_fatal_error: Optional[Callable[[BaseException], Awaitable[None]]] + use_worker_versioning: bool _default_build_id: Optional[str] = None diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index b8140d24..d15dcf88 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -4,7 +4,8 @@ from datetime import timedelta from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar -from temporalio.client import Client +from temporalio.client import BuildIdOpAddNewDefault, Client +from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker, WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -16,6 +17,7 @@ def new_worker( task_queue: Optional[str] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), max_cached_workflows: int = 1000, + **kwargs, ) -> Worker: return Worker( client, @@ -24,6 +26,7 @@ def new_worker( activities=activities, workflow_runner=workflow_runner, max_cached_workflows=max_cached_workflows, + **kwargs, ) @@ -47,3 +50,16 @@ async def assert_eq_eventually( assert ( expected == last_value ), "timed out waiting for equal, asserted against last value" + + +async def worker_versioning_enabled(client: Client) -> bool: + tq = f"worker-versioning-init-test-{uuid.uuid4()}" + try: + await client.update_worker_build_id_compatibility( + tq, BuildIdOpAddNewDefault("testver") + ) + return True + except RPCError as e: + if e.status in [RPCStatusCode.PERMISSION_DENIED, RPCStatusCode.UNIMPLEMENTED]: + return False + raise diff --git a/tests/test_client.py b/tests/test_client.py index 6f523714..913f3aef 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,6 +27,11 @@ from temporalio.api.history.v1 import History from temporalio.api.workflowservice.v1 import GetSystemInfoRequest from temporalio.client import ( + BuildIdOpAddNewCompatible, + BuildIdOpAddNewDefault, + BuildIdOpMergeSets, + BuildIdOpPromoteBuildIdWithinSet, + BuildIdOpPromoteSetByBuildId, CancelWorkflowInput, Client, Interceptor, @@ -50,6 +55,7 @@ ScheduleUpdateInput, SignalWorkflowInput, StartWorkflowInput, + TaskReachabilityType, TerminateWorkflowInput, WorkflowContinuedAsNewError, WorkflowExecutionStatus, @@ -63,7 +69,7 @@ from temporalio.converter import DataConverter from temporalio.exceptions import WorkflowAlreadyStartedError from temporalio.testing import WorkflowEnvironment -from tests.helpers import assert_eq_eventually, new_worker +from tests.helpers import assert_eq_eventually, new_worker, worker_versioning_enabled from tests.helpers.worker import ( ExternalWorker, KSAction, @@ -984,3 +990,51 @@ async def schedule_count() -> int: return len([d async for d in await client.list_schedules()]) await assert_eq_eventually(0, schedule_count) + + +async def test_build_id_interactions(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Java test server does not support worker versioning") + if not await worker_versioning_enabled(client): + pytest.skip("This server does not have worker versioning enabled") + + tq = "test-build-id-interactions_" + str(uuid.uuid4()) + + await client.update_worker_build_id_compatibility(tq, BuildIdOpAddNewDefault("1.0")) + await client.update_worker_build_id_compatibility( + tq, BuildIdOpAddNewCompatible("1.1", "1.0") + ) + sets = await client.get_worker_build_id_compatibility(tq) + assert sets.default_build_id() == "1.1" + assert sets.default_set().build_ids[0] == "1.0" + + await client.update_worker_build_id_compatibility( + tq, BuildIdOpPromoteBuildIdWithinSet("1.0") + ) + sets = await client.get_worker_build_id_compatibility(tq) + assert sets.default_build_id() == "1.0" + + await client.update_worker_build_id_compatibility(tq, BuildIdOpAddNewDefault("2.0")) + sets = await client.get_worker_build_id_compatibility(tq) + assert sets.default_build_id() == "2.0" + + await client.update_worker_build_id_compatibility( + tq, BuildIdOpPromoteSetByBuildId("1.0") + ) + sets = await client.get_worker_build_id_compatibility(tq) + assert sets.default_build_id() == "1.0" + + await client.update_worker_build_id_compatibility( + tq, BuildIdOpMergeSets(primary_build_id="2.0", secondary_build_id="1.0") + ) + sets = await client.get_worker_build_id_compatibility(tq) + assert sets.default_build_id() == "2.0" + + reachability = await client.get_worker_task_reachability( + build_ids=["2.0", "1.0", "1.1"] + ) + assert reachability.build_id_reachability["2.0"].task_queue_reachability[tq] == [ + TaskReachabilityType.NEW_WORKFLOWS + ] + assert reachability.build_id_reachability["1.0"].task_queue_reachability[tq] == [] + assert reachability.build_id_reachability["1.1"].task_queue_reachability[tq] == [] diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index a2a7f52f..e2cb7303 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -8,8 +8,10 @@ import temporalio.worker._worker from temporalio import activity, workflow -from temporalio.client import Client +from temporalio.client import BuildIdOpAddNewDefault, Client +from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker +from tests.helpers import new_worker, worker_versioning_enabled def test_load_default_worker_binary_id(): @@ -126,6 +128,67 @@ async def test_worker_cancel_run(client: Client): assert not worker.is_running and worker.is_shutdown +@workflow.defn +class WaitOnSignalWorkflow: + def __init__(self) -> None: + self._last_signal = "" + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._last_signal == "finish") + + @workflow.signal + def my_signal(self, value: str) -> None: + self._last_signal = value + workflow.logger.info(f"Signal: {value}") + + +async def test_worker_versioning(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Java test server does not support worker versioning") + if not await worker_versioning_enabled(client): + pytest.skip("This server does not have worker versioning enabled") + + task_queue = f"worker-versioning-{uuid.uuid4()}" + await client.update_worker_build_id_compatibility( + task_queue, BuildIdOpAddNewDefault("1.0") + ) + + async with new_worker( + client, + WaitOnSignalWorkflow, + task_queue=task_queue, + build_id="1.0", + use_worker_versioning=True, + ): + wf1 = await client.start_workflow( + WaitOnSignalWorkflow.run, + id=f"worker-versioning-1-{uuid.uuid4()}", + task_queue=task_queue, + ) + # Sleep for a beat, otherwise it's possible for new workflow to start on 2.0 + await asyncio.sleep(0.1) + await client.update_worker_build_id_compatibility( + task_queue, BuildIdOpAddNewDefault("2.0") + ) + wf2 = await client.start_workflow( + WaitOnSignalWorkflow.run, + id=f"worker-versioning-2-{uuid.uuid4()}", + task_queue=task_queue, + ) + async with new_worker( + client, + WaitOnSignalWorkflow, + task_queue=task_queue, + build_id="2.0", + use_worker_versioning=True, + ): + await wf1.signal(WaitOnSignalWorkflow.my_signal, "finish") + await wf2.signal(WaitOnSignalWorkflow.my_signal, "finish") + await wf1.result() + await wf2.result() + + def create_worker( client: Client, on_fatal_error: Optional[Callable[[BaseException], Awaitable[None]]] = None,