Skip to content

Commit

Permalink
fix: 🚑️ Fix tcp timeout specification.
Browse files Browse the repository at this point in the history
Timeouts were hardcoded to 0.1, despite Task() accepting a timeout as an input argument. This was corrected by moving the timeout specification to the run_blockwise() function call via the "tcp_timeout" kwarg.

Task() no longer takes a timeout argument. convenience.run_blockwise() now takes tcp_timeout, which defaults to 0.1, the value that was previously hardcoded.
  • Loading branch information
rhoadesScholar committed Jan 26, 2024
1 parent 0082824 commit d45d741
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 124 deletions.
38 changes: 17 additions & 21 deletions daisy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
ReleaseBlock,
RequestShutdown,
SendBlock,
UnexpectedMessage)
UnexpectedMessage,
)
from contextlib import contextmanager
from daisy.tcp import TCPClient, StreamClosedError
import logging

logger = logging.getLogger(__name__)


class Client():
'''Client code that runs on a remote worker providing task management
class Client:
"""Client code that runs on a remote worker providing task management
API for user code. It communicates with the scheduler through TCP/IP.
Scheduler IP address, port, and other configurations are typically
Expand All @@ -35,12 +36,10 @@ def main():
break
blockwise_process(block)
block.state = BlockStatus.SUCCESS # (or FAILED)
'''
"""

def __init__(
self,
context=None):
'''Initialize a client and connect to the server.
def __init__(self, context=None):
"""Initialize a client and connect to the server.
Args:
Expand All @@ -50,29 +49,30 @@ def __init__(
given, the context will be read from environment variable
``DAISY_CONTEXT``.
'''
"""
logger.debug("Client init")
self.context = context
if self.context is None:
self.context = Context.from_env()
logger.debug("Client context: %s", self.context)

self.host = self.context['hostname']
self.port = int(self.context['port'])
self.worker_id = int(self.context['worker_id'])
self.task_id = self.context['task_id']
self.host = self.context["hostname"]
self.port = int(self.context["port"])
self.worker_id = int(self.context["worker_id"])
self.task_id = self.context["task_id"]
self.tcp_timeout = self.context["tcp_timeout"]

# Make TCP Connection
self.tcp_client = TCPClient(self.host, self.port)

@contextmanager
def acquire_block(self):
'''API for client to get a new block.'''
"""API for client to get a new block."""
self.tcp_client.send_message(AcquireBlock(self.task_id))
message = None
try:
while message is None:
message = self.tcp_client.get_message(timeout=0.1)
message = self.tcp_client.get_message(timeout=self.tcp_timeout)
except StreamClosedError:
logger.debug("TCP stream was closed, server is probably down")
yield
Expand All @@ -89,12 +89,8 @@ def acquire_block(self):
block.status = BlockStatus.SUCCESS
except Exception as e:
block.status = BlockStatus.FAILED
self.tcp_client.send_message(
BlockFailed(e, block, self.context))
logger.exception(
"Block %s failed in worker %d",
block,
self.worker_id)
self.tcp_client.send_message(BlockFailed(e, block, self.context))
logger.exception("Block %s failed in worker %d", block, self.worker_id)
finally:
# if we somehow got here without setting the block status to
# "SUCCESS" (e.g., through KeyboardInterrupt), we assume the
Expand Down
32 changes: 17 additions & 15 deletions daisy/convenience.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
from multiprocessing import Event


def run_blockwise(tasks):
'''Schedule and run the given tasks.
Args:
list_of_tasks:
The tasks to schedule over.
Return:
bool:
`True` if all blocks in the given `tasks` were successfully
run, else `False`
'''
def run_blockwise(tasks, tcp_timeout=0.1):
"""Schedule and run the given tasks.
Args:
list_of_tasks:
The tasks to schedule over.
tcp_timeout (float, optional):
The timeout for TCP connections to the scheduler.
Return:
bool:
`True` if all blocks in the given `tasks` were successfully
run, else `False`
"""
task_ids = set()
all_tasks = []
while len(tasks) > 0:
Expand All @@ -31,15 +33,15 @@ def run_blockwise(tasks):

IOLooper.clear()
pool = ThreadPool(processes=1)
result = pool.apply_async(_run_blockwise, args=(tasks, stop_event))
result = pool.apply_async(_run_blockwise, args=(tasks, stop_event, tcp_timeout))
try:
return result.get()
except KeyboardInterrupt:
stop_event.set()
return result.get()


def _run_blockwise(tasks, stop_event):
server = Server(stop_event=stop_event)
def _run_blockwise(tasks, stop_event, tcp_timeout):
server = Server(stop_event=stop_event, tcp_timeout=tcp_timeout)
cl_monitor = CLMonitor(server) # noqa
return server.run_blockwise(tasks)
Loading

0 comments on commit d45d741

Please sign in to comment.