From d45d74177309abb19e5996b41fc37254e3749041 Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Fri, 26 Jan 2024 14:45:37 -0500
Subject: [PATCH] =?UTF-8?q?fix:=20=F0=9F=9A=91=EF=B8=8F=20Fix=20tcp=20time?=
 =?UTF-8?q?out=20specification.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Timeouts were hardcoded to 0.1, despite Task() accepting a timeout argument.
This was corrected by moving the specification to the run_blockwise() function
call via the "tcp_timeout" kwarg. Task() no longer takes a timeout argument;
convenience.run_blockwise() takes tcp_timeout, defaulting to 0.1, the value
that was previously hardcoded.
---
 daisy/client.py             | 38 +++++++--------
 daisy/convenience.py        | 32 ++++++------
 daisy/server.py             | 97 +++++++++++++------------------------
 daisy/task.py               | 12 +----
 daisy/tcp/tcp_server.py     |  2 +-
 tests/test_clients_close.py |  1 -
 tests/test_server.py        | 25 +++++-----
 7 files changed, 83 insertions(+), 124 deletions(-)

diff --git a/daisy/client.py b/daisy/client.py
index 55155939..31316884 100644
--- a/daisy/client.py
+++ b/daisy/client.py
@@ -6,7 +6,8 @@
     ReleaseBlock,
     RequestShutdown,
     SendBlock,
-    UnexpectedMessage)
+    UnexpectedMessage,
+)
 from contextlib import contextmanager
 from daisy.tcp import TCPClient, StreamClosedError
 import logging
@@ -14,8 +15,8 @@
 logger = logging.getLogger(__name__)
 
 
-class Client():
-    '''Client code that runs on a remote worker providing task management
+class Client:
+    """Client code that runs on a remote worker providing task management
     API for user code. It communicates with the scheduler through TCP/IP.
 
     Scheduler IP address, port, and other configurations are typically
@@ -35,12 +36,10 @@ def main():
                         break
                     blockwise_process(block)
                     block.state = BlockStatus.SUCCESS # (or FAILED)
-    '''
+    """
 
-    def __init__(
-            self,
-            context=None):
-        '''Initialize a client and connect to the server.
+    def __init__(self, context=None):
+        """Initialize a client and connect to the server.
 
         Args:
 
             context (`class:daisy.Context`, optional):
 
                 If given, will be used to connect to the scheduler. If not
                 given, the context will be read from environment variable
                 ``DAISY_CONTEXT``.
-        '''
+        """
         logger.debug("Client init")
 
         self.context = context
         if self.context is None:
             self.context = Context.from_env()
         logger.debug("Client context: %s", self.context)
 
-        self.host = self.context['hostname']
-        self.port = int(self.context['port'])
-        self.worker_id = int(self.context['worker_id'])
-        self.task_id = self.context['task_id']
+        self.host = self.context["hostname"]
+        self.port = int(self.context["port"])
+        self.worker_id = int(self.context["worker_id"])
+        self.task_id = self.context["task_id"]
+        self.tcp_timeout = self.context["tcp_timeout"]
 
         # Make TCP Connection
         self.tcp_client = TCPClient(self.host, self.port)
 
     @contextmanager
     def acquire_block(self):
-        '''API for client to get a new block.'''
+        """API for client to get a new block."""
 
         self.tcp_client.send_message(AcquireBlock(self.task_id))
         message = None
         try:
             while message is None:
-                message = self.tcp_client.get_message(timeout=0.1)
+                message = self.tcp_client.get_message(timeout=self.tcp_timeout)
         except StreamClosedError:
             logger.debug("TCP stream was closed, server is probably down")
             yield
@@ -89,12 +89,8 @@ def acquire_block(self):
                 block.status = BlockStatus.SUCCESS
             except Exception as e:
                 block.status = BlockStatus.FAILED
-                self.tcp_client.send_message(
-                    BlockFailed(e, block, self.context))
-                logger.exception(
-                    "Block %s failed in worker %d",
-                    block,
-                    self.worker_id)
+                self.tcp_client.send_message(BlockFailed(e, block, self.context))
+                logger.exception("Block %s failed in worker %d", block, self.worker_id)
             finally:
                 # if we somehow got here without setting the block status to
                 # "SUCCESS" (e.g., through KeyboardInterrupt), we assume the
diff --git a/daisy/convenience.py b/daisy/convenience.py
index 767fb057..228df461 100644
--- a/daisy/convenience.py
+++ b/daisy/convenience.py
@@ -5,18 +5,20 @@
 from multiprocessing import Event
 
 
-def run_blockwise(tasks):
-    '''Schedule and run the given tasks.
-
-    Args:
-        list_of_tasks:
-            The tasks to schedule over.
-
-    Return:
-        bool:
-            `True` if all blocks in the given `tasks` were successfully
-            run, else `False`
-    '''
+def run_blockwise(tasks, tcp_timeout=0.1):
+    """Schedule and run the given tasks.
+
+    Args:
+        list_of_tasks:
+            The tasks to schedule over.
+        tcp_timeout (float, optional):
+            The timeout for TCP connections to the scheduler.
+
+    Return:
+        bool:
+            `True` if all blocks in the given `tasks` were successfully
+            run, else `False`
+    """
     task_ids = set()
     all_tasks = []
     while len(tasks) > 0:
@@ -31,7 +33,7 @@ def run_blockwise(tasks):
     IOLooper.clear()
 
     pool = ThreadPool(processes=1)
-    result = pool.apply_async(_run_blockwise, args=(tasks, stop_event))
+    result = pool.apply_async(_run_blockwise, args=(tasks, stop_event, tcp_timeout))
     try:
         return result.get()
     except KeyboardInterrupt:
@@ -39,7 +41,7 @@ def run_blockwise(tasks):
         return result.get()
 
 
-def _run_blockwise(tasks, stop_event):
-    server = Server(stop_event=stop_event)
+def _run_blockwise(tasks, stop_event, tcp_timeout):
+    server = Server(stop_event=stop_event, tcp_timeout=tcp_timeout)
     cl_monitor = CLMonitor(server)  # noqa
     return server.run_blockwise(tasks)
diff --git a/daisy/server.py b/daisy/server.py
index 8d695cba..54d9a8d9 100644
--- a/daisy/server.py
+++ b/daisy/server.py
@@ -8,7 +8,8 @@
     ReleaseBlock,
     SendBlock,
     RequestShutdown,
-    UnexpectedMessage)
+    UnexpectedMessage,
+)
 from .scheduler import Scheduler
 from .server_observer import ServerObservee
 from .task_worker_pools import TaskWorkerPools
@@ -22,9 +23,7 @@
 
 
 class Server(ServerObservee):
-
-    def __init__(self, stop_event=None):
-
+    def __init__(self, stop_event=None, tcp_timeout=0.1):
         super().__init__()
 
         if stop_event is None:
@@ -32,16 +31,13 @@ def __init__(self, stop_event=None):
         else:
             self.stop_event = stop_event
 
+        self.tcp_timeout = tcp_timeout
         self.tcp_server = TCPServer()
         self.hostname, self.port = self.tcp_server.address
 
-        logger.debug(
-            "Started server listening at %s:%s",
-            self.hostname,
-            self.port)
+        logger.debug("Started server listening at %s:%s", self.hostname, self.port)
 
     def run_blockwise(self, tasks, scheduler=None):
-
         if scheduler is None:
             self.scheduler = Scheduler(tasks)
         else:
@@ -53,10 +49,7 @@ def run_blockwise(self, tasks, scheduler=None):
         self.finished_tasks = set()
         self.all_done = False
 
-        self.pending_requests = {
-            task.task_id: Queue()
-            for task in tasks
-        }
+        self.pending_requests = {task.task_id: Queue() for task in tasks}
 
         self._recruit_workers()
 
@@ -74,17 +67,14 @@ def run_blockwise(self, tasks, scheduler=None):
         return True if self.all_done else False
 
     def _event_loop(self):
-
         while not self.stop_event.is_set():
-
             self._handle_client_messages()
             self._check_for_lost_blocks()
             self.worker_pools.check_worker_health()
 
     def _get_client_message(self):
-
         try:
-            message = self.tcp_server.get_message(timeout=0.1)
+            message = self.tcp_server.get_message(timeout=self.tcp_timeout)
         except StreamClosedError:
             return
 
@@ -92,7 +82,6 @@ def _get_client_message(self):
             return message
 
         for task_id, requests in self.pending_requests.items():
-
             if self.pending_requests[task_id].empty():
                 continue
 
@@ -100,14 +89,12 @@ def _get_client_message(self):
             return self.pending_requests[task_id].get()
 
     def _send_client_message(self, stream, message):
-
         try:
             stream.send_message(message)
         except StreamClosedError:
             pass
 
     def _handle_client_messages(self):
-
         message = self._get_client_message()
 
         if message is None:
@@ -116,7 +103,6 @@ def _handle_client_messages(self):
         self._handle_client_message(message)
 
     def _handle_client_message(self, message):
-
         if isinstance(message, AcquireBlock):
             self._handle_acquire_block(message)
         elif isinstance(message, ReleaseBlock):
@@ -129,7 +115,6 @@ def _handle_client_message(self, message):
         self._check_all_tasks_completed()
 
     def _handle_acquire_block(self, message):
-
         logger.debug("Received block request for task %s", message.task_id)
 
         task_state = self.scheduler.task_states[message.task_id]
@@ -139,48 +124,42 @@ def _handle_acquire_block(self, message):
         block = self.scheduler.acquire_block(message.task_id)
 
         if block is None:
-
             assert task_state.ready_count == 0
 
             if task_state.pending_count == 0:
-
                 logger.debug(
-                    "No more pending blocks for task %s, terminating "
-                    "client", message.task_id)
+                    "No more pending blocks for task %s, terminating " "client",
+                    message.task_id,
+                )
 
-                self._send_client_message(
-                    message.stream,
-                    RequestShutdown())
+                self._send_client_message(message.stream, RequestShutdown())
 
                 return
 
             # there are more blocks for this task, but none of them has its
             # dependencies fullfilled
             logger.debug(
-                "No currently ready blocks for task %s, delaying "
-                "request", message.task_id)
+                "No currently ready blocks for task %s, delaying " "request",
+                message.task_id,
+            )
 
             self.pending_requests[message.task_id].put(message)
 
         else:
-
             try:
                 logger.debug("Sending block %s to client", block)
-                self._send_client_message(
-                    message.stream,
-                    SendBlock(block))
+                self._send_client_message(message.stream, SendBlock(block))
             finally:
                 self.block_bookkeeper.notify_block_sent(block, message.stream)
 
         self.notify_acquire_block(message.task_id, task_state)
 
     def _handle_release_block(self, message):
-
         logger.debug("Client releases block %s", message.block)
         self._safe_release_block(message.block, message.stream)
 
     def _release_block(self, block):
-        '''Returns a block to the scheduler and checks whether all tasks are
-        completed.'''
+        """Returns a block to the scheduler and checks whether all tasks are
+        completed."""
 
         self.scheduler.release_block(block)
         task_states = self.scheduler.task_states
@@ -192,7 +171,7 @@ def _release_block(self, block):
         self._recruit_workers()
 
     def _check_all_tasks_completed(self):
-        '''Check if all tasks are completed and stop'''
+        """Check if all tasks are completed and stop"""
 
         self.all_done = True
         task_states = self.scheduler.task_states
@@ -208,18 +187,15 @@ def _check_all_tasks_completed(self):
                 self.all_done = False
 
-            logger.debug(
-                "Task %s has %d ready blocks",
-                task_id,
-                task_state.ready_count)
+            logger.debug("Task %s has %d ready blocks", task_id, task_state.ready_count)
 
         if self.all_done:
             logger.debug("All tasks finished")
             self.stop_event.set()
 
     def _safe_release_block(self, block, stream):
-        '''Releases a block, if the bookkeeper agrees that this is a valid
-        return from the given stream.'''
+        """Releases a block, if the bookkeeper agrees that this is a valid
+        return from the given stream."""
 
         valid = self.block_bookkeeper.is_valid_return(block, stream)
         if valid:
@@ -227,58 +203,53 @@ def _safe_release_block(self, block, stream):
             self.block_bookkeeper.notify_block_returned(block, stream)
         else:
             logger.debug(
-                "Attempted to return unexpected block %s from %s",
-                block, stream)
+                "Attempted to return unexpected block %s from %s", block, stream
+            )
 
     def _handle_client_exception(self, message):
-
         if isinstance(message, BlockFailed):
-
             logger.error(
                 "Block %s failed in worker %s with %s",
                 message.block,
-                message.context['worker_id'],
-                repr(message.exception))
+                message.context["worker_id"],
+                repr(message.exception),
+            )
 
             message.block.status = BlockStatus.FAILED
             self._safe_release_block(message.block, message.stream)
 
-            self.notify_block_failure(
-                message.block,
-                message.exception,
-                message.context)
+            self.notify_block_failure(message.block, message.exception, message.context)
 
         else:
             raise message.exception
 
     def _recruit_workers(self):
-
         ready_tasks = self.scheduler.get_ready_tasks()
         ready_tasks = {task.task_id: task for task in ready_tasks}
 
         for task_id in ready_tasks.keys():
             if task_id not in self.started_tasks:
-                self.notify_task_start(
-                    task_id,
-                    self.scheduler.task_states[task_id])
+                self.notify_task_start(task_id, self.scheduler.task_states[task_id])
                 self.started_tasks.add(task_id)
 
                 # run the task's callback function
-                ready_tasks[task_id].init_callback_fn(Context(
+                ready_tasks[task_id].init_callback_fn(
+                    Context(
                         hostname=self.hostname,
                         port=self.port,
                         task_id=task_id,
-                        worker_id=0))
+                        worker_id=0,
+                        tcp_timeout=self.tcp_timeout,
+                    )
+                )
 
         self.worker_pools.recruit_workers(ready_tasks)
 
     def _check_for_lost_blocks(self):
-
         lost_blocks = self.block_bookkeeper.get_lost_blocks()
 
         # mark as failed and release the lost blocks
         for block in lost_blocks:
-
             logger.error("Block %s was lost, returning it to scheduler", block)
             block.status = BlockStatus.FAILED
             self._release_block(block)
diff --git a/daisy/task.py b/daisy/task.py
index ce4ba857..6d2d3991 100644
--- a/daisy/task.py
+++ b/daisy/task.py
@@ -3,7 +3,7 @@
 
 
 class Task:
-    '''Definition of a ``daisy`` task that is to be run in a block-wise
+    """Definition of a ``daisy`` task that is to be run in a block-wise
     fashion.
 
     Args:
@@ -117,13 +117,8 @@ class Task:
             The maximum number of times a task will be retried if failed
             (either due to failed post check or application crashes or network
             failure)
+    """
 
-        timeout (int, optional):
-
-            Time in seconds to wait for a block to be returned from a worker.
-            The worker is killed (and the block retried) if this time is
-            exceeded.
-    '''
     def __init__(
         self,
         task_id,
@@ -137,7 +132,6 @@ def __init__(
         num_workers=1,
         max_retries=2,
         fit="valid",
-        timeout=None,
         upstream_tasks=None,
     ):
         self.task_id = task_id
@@ -155,7 +149,6 @@ def __init__(
         self.fit = fit
         self.num_workers = num_workers
         self.max_retries = max_retries
-        self.timeout = timeout
         self.upstream_tasks = []
         if upstream_tasks is not None:
             self.upstream_tasks.extend(upstream_tasks)
@@ -170,7 +163,6 @@ def __init__(
         self.spawn_worker_function = lambda: self._process_blocks()
 
     def _process_blocks(self):
-
         client = Client()
         while True:
             with client.acquire_block() as block:
diff --git a/daisy/tcp/tcp_server.py b/daisy/tcp/tcp_server.py
index ed94046c..a949cdf8 100644
--- a/daisy/tcp/tcp_server.py
+++ b/daisy/tcp/tcp_server.py
@@ -57,7 +57,7 @@ def get_message(self, timeout=None):
         Args:
 
-            timeout (int, optional):
+            timeout (float, optional):
 
                 If set, wait up to `timeout` seconds for a message to
                 arrive. If no message is available after the timeout,
                 returns ``None``.
diff --git a/tests/test_clients_close.py b/tests/test_clients_close.py
index 200611ca..239d4589 100644
--- a/tests/test_clients_close.py
+++ b/tests/test_clients_close.py
@@ -30,7 +30,6 @@ def start_worker():
         fit="valid",
         num_workers=num_workers,
         max_retries=2,
-        timeout=None,
     )
 
     server = daisy.Server()
diff --git a/tests/test_server.py b/tests/test_server.py
index 6cde20b4..d70883b9 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -4,22 +4,21 @@
 logging.basicConfig(level=logging.DEBUG)
 
 
-class TestServer(unittest.TestCase):
+class TestServer(unittest.TestCase):
 
     def test_basic(self):
-
         task = daisy.Task(
-            'test_server_task',
-            total_roi=daisy.Roi((0,), (100,)),
-            read_roi=daisy.Roi((0,), (10,)),
-            write_roi=daisy.Roi((1,), (8,)),
-            process_function=lambda b: self.process_block(b),
-            check_function=None,
-            read_write_conflict=True,
-            fit='valid',
-            num_workers=1,
-            max_retries=2,
-            timeout=None)
+            "test_server_task",
+            total_roi=daisy.Roi((0,), (100,)),
+            read_roi=daisy.Roi((0,), (10,)),
+            write_roi=daisy.Roi((1,), (8,)),
+            process_function=lambda b: self.process_block(b),
+            check_function=None,
+            read_write_conflict=True,
+            fit="valid",
+            num_workers=1,
+            max_retries=2,
+        )
 
         server = daisy.Server()
         server.run_blockwise([task])
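
Usage with this change (an illustrative sketch, not part of the patch): the
timeout now travels through run_blockwise() instead of the Task constructor.
The Task/Roi arguments below mirror tests/test_server.py; the task id, the toy
process function, and the 0.5 s timeout are made up for the example, and
daisy.run_blockwise is assumed to be the package-level re-export of
convenience.run_blockwise.

    import daisy


    def process_block(block):
        # placeholder user code for a single block
        print("processing", block)


    task = daisy.Task(
        "example_task",  # hypothetical task id
        total_roi=daisy.Roi((0,), (100,)),
        read_roi=daisy.Roi((0,), (10,)),
        write_roi=daisy.Roi((1,), (8,)),
        process_function=process_block,
        num_workers=1,
        max_retries=2,
        # timeout=None,  # no longer accepted after this patch
    )

    # tcp_timeout replaces the previously hardcoded 0.1 s TCP timeout
    daisy.run_blockwise([task], tcp_timeout=0.5)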