From f58a46ea6f5ef04f5a84aa8527376947382e7df0 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 24 Jan 2025 08:52:33 +0000 Subject: [PATCH] add client changes --- modal/_utils/shell_utils.py | 58 +++++++++++++++++++++++++++++++++++++ modal/cli/container.py | 4 ++- modal/cli/run.py | 5 ++++ modal/container_process.py | 23 +++++++++++---- modal/runner.py | 11 +++++-- 5 files changed, 93 insertions(+), 8 deletions(-) diff --git a/modal/_utils/shell_utils.py b/modal/_utils/shell_utils.py index f05ee0c89..af4b42b72 100644 --- a/modal/_utils/shell_utils.py +++ b/modal/_utils/shell_utils.py @@ -3,10 +3,17 @@ import asyncio import contextlib import errno +import fcntl import os import select +import signal +import struct import sys +import termios +import threading from collections.abc import Coroutine +from queue import Empty, Queue +from types import FrameType from typing import Callable, Optional from modal._pty import raw_terminal, set_nonblocking @@ -77,3 +84,54 @@ async def _write(): yield os.write(quit_pipe_write, b"\n") write_task.cancel() + + +class WindowSizeHandler: + """Handles terminal window resize events.""" + + def __init__(self): + """Initialize window size handler. Must be called from the main thread to set signals properly. + In case this is invoked from a thread that is not the main thread, e.g. in tests, the context manager + becomes a no-op.""" + self._is_main_thread = threading.current_thread() is threading.main_thread() + self._event_queue: Queue[tuple[int, int]] = Queue() + + if self._is_main_thread and hasattr(signal, "SIGWINCH"): + signal.signal(signal.SIGWINCH, self._queue_resize_event) + + def _queue_resize_event(self, signum: Optional[int] = None, frame: Optional[FrameType] = None) -> None: + """Signal handler for SIGWINCH that queues events.""" + try: + hw = struct.unpack("hh", fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"1234")) + rows, cols = hw + self._event_queue.put((rows, cols)) + except Exception: + # ignore failed window size reads + pass + + @contextlib.asynccontextmanager + async def watch_window_size(self, handler: Callable[[int, int], Coroutine]): + """Context manager that processes window resize events from the queue. + Can be run from any thread. If the window manager was initialized from a thread that is not the main thread, + e.g. in tests, this context manager is a no-op. + + Args: + handler: Callback function to handle window resize events + """ + if not self._is_main_thread: + yield + return + + async def process_events(): + while True: + try: + rows, cols = self._event_queue.get_nowait() + await handler(rows, cols) + except Empty: + await asyncio.sleep(0.1) + + event_task = asyncio.create_task(process_events()) + try: + yield + finally: + event_task.cancel() diff --git a/modal/cli/container.py b/modal/cli/container.py index 792d03c34..f9f665abc 100644 --- a/modal/cli/container.py +++ b/modal/cli/container.py @@ -8,6 +8,7 @@ from modal._pty import get_pty_info from modal._utils.async_utils import synchronizer from modal._utils.grpc_utils import retry_transient_errors +from modal._utils.shell_utils import WindowSizeHandler from modal.cli.utils import ENV_OPTION, display_table, is_tty, stream_app_logs, timestamp_to_local from modal.client import _Client from modal.config import config @@ -79,7 +80,8 @@ async def exec( res: api_pb2.ContainerExecResponse = await client.stub.ContainerExec(req) if pty: - await _ContainerProcess(res.exec_id, client).attach() + window_size_handler = WindowSizeHandler() + await _ContainerProcess(res.exec_id, client).attach(window_size_handler=window_size_handler) else: # TODO: redirect stderr to its own stream? await _ContainerProcess(res.exec_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT).wait() diff --git a/modal/cli/run.py b/modal/cli/run.py index 2261602c0..e54f17eac 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -14,6 +14,7 @@ import typer from typing_extensions import TypedDict +from .._utils.shell_utils import WindowSizeHandler from ..app import App, LocalEntrypoint from ..config import config from ..environments import ensure_env @@ -461,6 +462,8 @@ def shell( if pty is None: pty = is_tty() + window_size_handler = WindowSizeHandler() + if platform.system() == "Windows": raise InvalidError("`modal shell` is currently not supported on Windows") @@ -503,6 +506,7 @@ def shell( volumes=function_spec.volumes, region=function_spec.scheduler_placement.proto.regions if function_spec.scheduler_placement else None, pty=pty, + window_size_handler=window_size_handler, proxy=function_spec.proxy, ) else: @@ -518,6 +522,7 @@ def shell( volumes=volumes, region=region.split(",") if region else [], pty=pty, + window_size_handler=window_size_handler, ) # NB: invoking under bash makes --cmd a lot more flexible. diff --git a/modal/container_process.py b/modal/container_process.py index 1f1bd7360..368165d17 100644 --- a/modal/container_process.py +++ b/modal/container_process.py @@ -1,6 +1,7 @@ # Copyright Modal Labs 2024 import asyncio import platform +import struct from typing import Generic, Optional, TypeVar from modal_proto import api_pb2 @@ -8,7 +9,7 @@ from ._utils.async_utils import TaskContext, synchronize_api from ._utils.deprecation import deprecation_error from ._utils.grpc_utils import retry_transient_errors -from ._utils.shell_utils import stream_from_stdin, write_to_fd +from ._utils.shell_utils import WindowSizeHandler, stream_from_stdin, write_to_fd from .client import _Client from .exception import InteractiveTimeoutError, InvalidError from .io_streams import _StreamReader, _StreamWriter @@ -115,7 +116,7 @@ async def wait(self) -> int: self._returncode = resp.exit_code return self._returncode - async def attach(self, *, pty: Optional[bool] = None): + async def attach(self, *, window_size_handler: WindowSizeHandler, pty: Optional[bool] = None): if platform.system() == "Windows": print("interactive exec is not currently supported on Windows.") return @@ -151,6 +152,17 @@ async def _handle_input(data: bytes, message_index: int): self.stdin.write(data) await self.stdin.drain() + async def _send_window_resize(rows: int, cols: int): + # create resize sequence: + # - magic byte 0xC1 to identify the resize sequence + # - 2 bytes for the number of rows (big-endian) + # - 2 bytes for the number of columns (big-endian) + magic = bytes([0xC1]) + dims = struct.pack(">HH", rows, cols) + resize_data = magic + dims + self.stdin.write(resize_data) + await self.stdin.drain() + async with TaskContext() as tc: stdout_task = tc.create_task(_write_to_fd_loop(self.stdout)) stderr_task = tc.create_task(_write_to_fd_loop(self.stderr)) @@ -159,9 +171,10 @@ async def _handle_input(data: bytes, message_index: int): # time out if we can't connect to the server fast enough await asyncio.wait_for(on_connect.wait(), timeout=60) - async with stream_from_stdin(_handle_input, use_raw_terminal=True): - await stdout_task - await stderr_task + async with window_size_handler.watch_window_size(_send_window_resize): + async with stream_from_stdin(_handle_input, use_raw_terminal=True): + await stdout_task + await stderr_task # TODO: this doesn't work right now. # if exit_status != 0: diff --git a/modal/runner.py b/modal/runner.py index 9dc1ea3fd..866aea4e2 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -24,6 +24,7 @@ from ._utils.deprecation import deprecation_error from ._utils.grpc_utils import retry_transient_errors from ._utils.name_utils import check_object_name, is_valid_tag +from ._utils.shell_utils import WindowSizeHandler from .client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client from .cls import _Cls from .config import config, logger @@ -563,7 +564,12 @@ def heartbeat(): async def _interactive_shell( - _app: _App, cmds: list[str], environment_name: str = "", pty: bool = True, **kwargs: Any + _app: _App, + cmds: list[str], + environment_name: str = "", + pty: bool = True, + window_size_handler: Optional[WindowSizeHandler] = None, + **kwargs: Any, ) -> None: """Run an interactive shell (like `bash`) within the image for this app. @@ -611,7 +617,8 @@ async def _interactive_shell( container_process = await sandbox.exec( *sandbox_cmds, pty_info=get_pty_info(shell=True) if pty else None ) - await container_process.attach() + assert window_size_handler is not None, "window_size_handler must be provided when pty is True" + await container_process.attach(window_size_handler=window_size_handler) else: container_process = await sandbox.exec( *sandbox_cmds, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT