Commit 38f93d9
fix: ensure task monitor in taskqueue functions properly during rollout restart (#919)
luke-lombardi authored Feb 3, 2025
1 parent 845be2e commit 38f93d9
Showing 2 changed files with 84 additions and 56 deletions.
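
At a high level, the Python-side change below makes the task monitor treat a monitor stream that ends without an error (which is what the client sees when the gateway goes through a rollout restart) as a signal to reconnect, while cancellation, completion, timeout, and repeated connection failures still end monitoring. The following is a condensed, self-contained sketch of that control flow; watch_stream, monitor, and the plain string events are stand-ins for illustration only, not the SDK's actual gRPC stubs or API:

# Condensed sketch of the restart-on-clean-end pattern introduced by this commit.
# Plain iterables and strings stand in for the gRPC monitor stream and responses.

def watch_stream(stream) -> bool:
    """Return True if the stream ended with no terminal event (reconnect),
    False if the task completed, was cancelled, or timed out (stop)."""
    for event in stream:
        if event in ("complete", "cancelled", "timed_out"):
            return False
    # The stream closed cleanly, e.g. the gateway restarted during a rollout.
    return True


def monitor(streams) -> None:
    for stream in streams:  # each iteration models one (re)connection attempt
        if not watch_stream(stream):
            return  # terminal event observed; stop monitoring
        # otherwise fall through and reconnect with the next stream


# First connection drops cleanly (rollout restart); the second reports completion.
monitor([iter(()), iter(("complete",))])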
16 changes: 15 additions & 1 deletion pkg/abstractions/taskqueue/taskqueue.go
@@ -427,6 +427,9 @@ func (tq *RedisTaskQueue) TaskQueueMonitor(req *pb.TaskQueueMonitorRequest, stre
 
     for {
         select {
+        case <-tq.ctx.Done():
+            return
+
         case <-timeoutChan:
             err := timeoutCallback()
             if err != nil {
@@ -456,8 +459,19 @@ func (tq *RedisTaskQueue) TaskQueueMonitor(req *pb.TaskQueueMonitorRequest, stre
 
     for {
         select {
+        case <-tq.ctx.Done():
+            return nil
+
         case <-stream.Context().Done():
-            tq.rdb.Del(context.Background(), Keys.taskQueueTaskRunningLock(authInfo.Workspace.Name, req.StubId, req.ContainerId, task.ExternalId))
+            task, err := tq.backendRepo.GetTask(ctx, req.TaskId)
+            if err != nil {
+                return err
+            }
+
+            if task.Status.IsCompleted() {
+                tq.rdb.Del(context.Background(), Keys.taskQueueTaskRunningLock(authInfo.Workspace.Name, req.StubId, req.ContainerId, task.ExternalId))
+            }
+
             return nil
 
         case <-cancelFlag:
124 changes: 69 additions & 55 deletions sdk/src/beta9/runner/taskqueue.py
@@ -6,7 +6,6 @@
 import time
 import traceback
 from concurrent import futures
-from concurrent.futures import CancelledError
 from multiprocessing import Event, Process, set_start_method
 from multiprocessing.synchronize import Event as TEvent
 from typing import Any, List, NamedTuple, Type, Union
@@ -197,71 +196,86 @@ def _monitor_task(
         taskqueue_stub: TaskQueueServiceStub,
         gateway_stub: GatewayServiceStub,
     ) -> None:
-        initial_backoff = 5
-        max_retries = 5
-        backoff = initial_backoff
-        retry = 0
-
-        while retry <= max_retries:
-            try:
-                for response in taskqueue_stub.task_queue_monitor(
-                    TaskQueueMonitorRequest(
-                        task_id=task.id,
-                        stub_id=stub_id,
-                        container_id=container_id,
-                    )
-                ):
-                    response: TaskQueueMonitorResponse
-                    if response.cancelled:
-                        print(f"Task cancelled: {task.id}")
-
-                        send_callback(
-                            gateway_stub=gateway_stub,
-                            context=context,
-                            payload={},
-                            task_status=TaskStatus.Cancelled,
-                        )
-                        os._exit(TaskExitCode.Cancelled)
-
-                    if response.complete:
-                        return
-
-                    if response.timed_out:
-                        print(f"Task timed out: {task.id}")
-
-                        send_callback(
-                            gateway_stub=gateway_stub,
-                            context=context,
-                            payload={},
-                            task_status=TaskStatus.Timeout,
-                        )
-                        os._exit(TaskExitCode.Timeout)
-
-                    retry = 0
-                    backoff = initial_backoff
-
-                # If successful, it means the stream is finished.
-                # Break out of the retry loop
-                break
-
-            except (
-                grpc.RpcError,
-                ConnectionRefusedError,
-            ):
-                if retry == max_retries:
-                    print("Lost connection to task monitor, exiting")
-                    os._exit(0)
-
-                time.sleep(backoff)
-                backoff *= 2
-                retry += 1
-
-            except (CancelledError, ValueError):
-                print(f"Lost connection to task monitor, retrying... {retry}")
-
-            except BaseException:
-                print(f"Unexpected error occurred in task monitor: {traceback.format_exc()}")
-                os._exit(0)
+        def _monitor_stream() -> bool:
+            """
+            Returns True if the stream ended with no errors (and should be restarted),
+            or False if an exit event occurred (cancellation, completion, timeout,
+            or a connection issue that caused us to kill the parent process)
+            """
+            initial_backoff = 5
+            max_retries = 5
+            backoff = initial_backoff
+            retry = 0
+
+            while retry <= max_retries:
+                try:
+                    for response in taskqueue_stub.task_queue_monitor(
+                        TaskQueueMonitorRequest(
+                            task_id=task.id,
+                            stub_id=stub_id,
+                            container_id=container_id,
+                        )
+                    ):
+                        response: TaskQueueMonitorResponse
+                        if response.cancelled:
+                            print(f"Task cancelled: {task.id}")
+
+                            send_callback(
+                                gateway_stub=gateway_stub,
+                                context=context,
+                                payload={},
+                                task_status=TaskStatus.Cancelled,
+                            )
+                            os._exit(TaskExitCode.Cancelled)
+
+                        if response.complete:
+                            return False
+
+                        if response.timed_out:
+                            print(f"Task timed out: {task.id}")
+
+                            send_callback(
+                                gateway_stub=gateway_stub,
+                                context=context,
+                                payload={},
+                                task_status=TaskStatus.Timeout,
+                            )
+                            os._exit(TaskExitCode.Timeout)
+
+                        retry = 0
+                        backoff = initial_backoff
+
+                    # Reaching here means that the stream ended with no errors,
+                    # which can occur during a rollout restart of the gateway;
+                    # returning True here tells the outer loop to restart the stream
+                    return True
+
+                except (
+                    grpc.RpcError,
+                    ConnectionRefusedError,
+                ):
+                    if retry == max_retries:
+                        print("Lost connection to task monitor, exiting")
+                        os._exit(0)
+
+                    time.sleep(backoff)
+                    backoff *= 2
+                    retry += 1
+
+                except BaseException:
+                    print(f"Unexpected error occurred in task monitor: {traceback.format_exc()}")
+                    os._exit(0)
+
+        # Outer loop: restart only if the stream ended with no errors
+        while True:
+            should_restart = _monitor_stream()
+            if not should_restart:
+                # Exit condition encountered; exit the monitor task completely
+                return
+
+            # If we reached here, the stream ended with no errors,
+            # so we should restart the monitoring stream
 
     @with_runner_context
     def process_tasks(self, channel: Channel) -> None:
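
For connection errors (as opposed to a cleanly ended stream), the monitor keeps the existing retry behavior: each failed reconnect sleeps for an exponentially growing delay, a successful response resets both the counter and the delay, and once max_retries consecutive failures accumulate the process exits. A small self-contained illustration of that schedule, reusing the initial_backoff = 5 and max_retries = 5 values from the diff above:

# Illustration of the reconnect backoff schedule used on grpc.RpcError /
# ConnectionRefusedError (values mirror the constants in _monitor_stream).

initial_backoff = 5
max_retries = 5

backoff = initial_backoff
for retry in range(max_retries):
    print(f"attempt {retry + 1} failed: sleeping {backoff}s before reconnecting")
    backoff *= 2
# A further failure once retry == max_retries exits the monitor instead of sleeping.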
