diff --git a/CHANGELOG.md b/CHANGELOG.md index f31fe5a8..c2b3568b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,13 @@ Types of changes # Latch SDK Changelog +## 2.53.8 - 2024-10-18 + +### Added + +* Nextflow + - add `latch nextflow attach` command to attach to a nextflow work directory + ## 2.53.7 - 2024-10-16 ### Added diff --git a/latch_cli/main.py b/latch_cli/main.py index 9767ea47..e573679b 100644 --- a/latch_cli/main.py +++ b/latch_cli/main.py @@ -437,7 +437,7 @@ def execute( ): """Drops the user into an interactive shell from within a task.""" - from latch_cli.services.execute.main import exec + from latch_cli.services.k8s.execute import exec exec(execution_id=execution_id, egn_id=egn_id, container_index=container_index) @@ -999,6 +999,19 @@ def generate_entrypoint( ) +@nextflow.command("attach") +@click.option( + "--execution-id", "-e", type=str, help="Optional execution ID to inspect." +) +@requires_login +def attach(execution_id: Optional[str]): + """Drops the user into an interactive shell to inspect the workdir of a nextflow execution.""" + + from latch_cli.services.k8s.attach import attach + + attach(execution_id) + + """ POD COMMANDS """ diff --git a/latch_cli/services/execute/__init__.py b/latch_cli/services/k8s/__init__.py similarity index 100% rename from latch_cli/services/execute/__init__.py rename to latch_cli/services/k8s/__init__.py diff --git a/latch_cli/services/k8s/attach.py b/latch_cli/services/k8s/attach.py new file mode 100644 index 00000000..8e2637b6 --- /dev/null +++ b/latch_cli/services/k8s/attach.py @@ -0,0 +1,77 @@ +import asyncio +import json +import secrets +import sys +from typing import Optional +from urllib.parse import urljoin, urlparse + +import click +import websockets.client as websockets +import websockets.exceptions as ws_exceptions +from latch_sdk_config.latch import NUCLEUS_URL + +from latch_cli.utils import get_auth_header + +from .utils import get_pvc_info +from .ws_utils import forward_stdio + + +async def connect(execution_id: str, session_id: str): + async with websockets.connect( + urlparse(urljoin(NUCLEUS_URL, "/workflows/cli/attach-nf-workdir")) + ._replace(scheme="wss") + .geturl(), + close_timeout=0, + extra_headers={"Authorization": get_auth_header()}, + ) as ws: + request = {"execution_id": int(execution_id), "session_id": session_id} + + await ws.send(json.dumps(request)) + data = await ws.recv() + + msg = "" + try: + res = json.loads(data) + if "error" in res: + raise RuntimeError(res["error"]) + except json.JSONDecodeError: + msg = "Unable to connect to pod - internal error." + except RuntimeError as e: + msg = str(e) + + if msg != "": + raise RuntimeError(msg) + + await forward_stdio(ws) + + +def get_session_id(): + return secrets.token_bytes(8).hex() + + +def attach(execution_id: Optional[str] = None): + execution_id = get_pvc_info(execution_id) + session_id = get_session_id() + + click.secho( + "Attaching to workdir - this may take a few seconds...", dim=True, italic=True + ) + + import termios + import tty + + old_settings_stdin = termios.tcgetattr(sys.stdin.fileno()) + tty.setraw(sys.stdin) + + msg = "" + try: + asyncio.run(connect(execution_id, session_id)) + except ws_exceptions.ConnectionClosedError as e: + msg = json.loads(e.reason)["error"] + except RuntimeError as e: + msg = str(e) + finally: + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, old_settings_stdin) + + if msg != "": + click.secho(msg, fg="red") diff --git a/latch_cli/services/k8s/execute.py b/latch_cli/services/k8s/execute.py new file mode 100644 index 00000000..79aba127 --- /dev/null +++ b/latch_cli/services/k8s/execute.py @@ -0,0 +1,78 @@ +import asyncio +import json +import sys +from typing import Optional +from urllib.parse import urljoin, urlparse + +import websockets.client as websockets +from latch_sdk_config.latch import NUCLEUS_URL + +from latch_cli.services.k8s.utils import ( + ContainerNode, + EGNNode, + ExecutionInfoNode, + get_container_info, + get_egn_info, + get_execution_info, +) +from latch_cli.utils import get_auth_header + +from .ws_utils import forward_stdio + + +async def connect(egn_info: EGNNode, container_info: Optional[ContainerNode]): + async with websockets.connect( + urlparse(urljoin(NUCLEUS_URL, "/workflows/cli/shell")) + ._replace(scheme="wss") + .geturl(), + close_timeout=0, + extra_headers={"Authorization": get_auth_header()}, + ) as ws: + request = { + "egn_id": egn_info["id"], + "container_index": ( + container_info["index"] if container_info is not None else None + ), + } + + await ws.send(json.dumps(request)) + data = await ws.recv() + + msg = "" + try: + res = json.loads(data) + if "error" in res: + raise RuntimeError(res["error"]) + except json.JSONDecodeError: + msg = "Unable to connect to pod - internal error." + except RuntimeError as e: + msg = str(e) + + if msg != "": + raise RuntimeError(msg) + + await forward_stdio(ws) + + +def exec( + execution_id: Optional[str] = None, + egn_id: Optional[str] = None, + container_index: Optional[int] = None, +): + execution_info: Optional[ExecutionInfoNode] = None + if egn_id is None: + execution_info = get_execution_info(execution_id) + + egn_info = get_egn_info(execution_info, egn_id) + container_info = get_container_info(egn_info, container_index) + + import termios + import tty + + old_settings_stdin = termios.tcgetattr(sys.stdin.fileno()) + tty.setraw(sys.stdin) + + try: + asyncio.run(connect(egn_info, container_info)) + finally: + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, old_settings_stdin) diff --git a/latch_cli/services/execute/utils.py b/latch_cli/services/k8s/utils.py similarity index 81% rename from latch_cli/services/execute/utils.py rename to latch_cli/services/k8s/utils.py index 21f1e5f5..3faeb761 100644 --- a/latch_cli/services/execute/utils.py +++ b/latch_cli/services/k8s/utils.py @@ -9,7 +9,6 @@ from latch_cli.click_utils import bold, color from latch_cli.menus import select_tui -from latch_cli.utils import current_workspace # todo(ayush): put this into latch_sdk_gql @@ -179,35 +178,39 @@ def get_execution_info(execution_id: Optional[str]) -> ExecutionInfoNode: """), *fragments, ), - {"createdBy": current_workspace()}, + {"createdBy": user_config.workspace_id}, )["runningExecutions"] - if len(res["nodes"]) == 0: - click.secho("You have no executions currently running.", dim=True) + nodes = res["nodes"] + + if len(nodes) == 0: + click.secho( + f"You have no executions currently running.", + dim=True, + ) raise click.exceptions.Exit(0) - if len(res["nodes"]) == 1: - execution = res["nodes"][0] + if len(nodes) == 1: + execution = nodes[0] click.secho( "Selecting execution" f" {color(execution['displayName'])} as it is" - " the only" - " one currently running in Workspace" + " the only one currently running in Workspace" f" {color(workspace_str)}.", ) return execution selected_execution = select_tui( - "You have multiple executions running in this workspace" - f" ({color(workspace_str)}). Which" - " execution would you like to inspect?", + "You have multiple executions running in" + f" this workspace ({color(workspace_str)}). Which execution would you like to" + " inspect?", [ { "display_name": f'{x["displayName"]} ({x["workflow"]["displayName"]})', "value": x, } - for x in res["nodes"] + for x in nodes ], clear_terminal=False, ) @@ -368,3 +371,61 @@ def get_container_info( raise click.exceptions.Exit(0) return selected_container_info + + +class Node(TypedDict): + id: str + displayName: str + + +class nfAvailablePvcs(TypedDict): + nodes: List[Node] + + +def get_pvc_info(execution_id: Optional[str]) -> str: + if execution_id is not None: + return execution_id + + workspace_str: str = user_config.workspace_name or user_config.workspace_id + + res: nfAvailablePvcs = execute( + gql.gql(""" + query NFWorkdirs($wsId: BigInt!) { + nfAvailablePvcs(argWsId: $wsId) { + nodes { + id + displayName + } + } + } + """), + {"wsId": user_config.workspace_id}, + )["nfAvailablePvcs"] + + nodes = res["nodes"] + + if len(nodes) == 0: + click.secho( + f"You have no available workdirs (all have expired).", + dim=True, + ) + raise click.exceptions.Exit(0) + + if len(nodes) == 1: + execution = nodes[0] + click.secho( + f"Selecting execution {color(execution['displayName'])} as it is the only" + f" one without an expired workDir in Workspace {color(workspace_str)}.", + ) + return execution["id"] + + selected_execution = select_tui( + "You have multiple available workDirs in this workspace" + f" ({color(workspace_str)}). Which execution would you like to attach to?", + options=[{"display_name": x["displayName"], "value": x["id"]} for x in nodes], + ) + if selected_execution is None: + click.secho("No execution selected. Exiting.", dim=True) + raise click.exceptions.Exit(0) + + return selected_execution diff --git a/latch_cli/services/execute/main.py b/latch_cli/services/k8s/ws_utils.py similarity index 54% rename from latch_cli/services/execute/main.py rename to latch_cli/services/k8s/ws_utils.py index b98219d0..176ba396 100644 --- a/latch_cli/services/execute/main.py +++ b/latch_cli/services/k8s/ws_utils.py @@ -4,23 +4,11 @@ import os import signal import sys -from typing import Generic, Literal, Optional, Tuple, TypedDict, TypeVar, Union -from urllib.parse import urljoin, urlparse +from typing import Literal, TypedDict, Union import websockets.client as websockets -from latch_sdk_config.latch import NUCLEUS_URL from typing_extensions import TypeAlias -from latch_cli.services.execute.utils import ( - ContainerNode, - EGNNode, - ExecutionInfoNode, - get_container_info, - get_egn_info, - get_execution_info, -) -from latch_cli.utils import get_auth_header - class StdoutResponse(TypedDict): stream: Union[Literal["stdout"], Literal["stderr"]] @@ -125,72 +113,32 @@ async def propagate_resize_events( ) -async def connect(egn_info: EGNNode, container_info: Optional[ContainerNode]): +async def forward_stdio(ws: websockets.WebSocketClientProtocol): loop = asyncio.get_event_loop() - resize_queue: asyncio.Queue = asyncio.Queue() - await resize_queue.put(os.get_terminal_size()) + resize_event_queue: asyncio.Queue = asyncio.Queue() + await resize_event_queue.put(os.get_terminal_size()) loop.add_signal_handler( signal.SIGWINCH, - lambda: asyncio.create_task(handle_resize(resize_queue)), + lambda: asyncio.create_task(handle_resize(resize_event_queue)), ) local_stdin, local_stdout = await get_stdio_streams() - async with websockets.connect( - urlparse(urljoin(NUCLEUS_URL, "/workflows/cli/shell")) - ._replace(scheme="wss") - .geturl(), - close_timeout=0, - extra_headers={"Authorization": get_auth_header()}, - ) as ws: - request = { - "egn_id": egn_info["id"], - "container_index": ( - container_info["index"] if container_info is not None else None - ), - } - - await ws.send(json.dumps(request)) - - # ayush: can't use TaskGroups bc only supported on >= 3.11 - try: - _, pending = await asyncio.wait( - [ - asyncio.create_task(pipe_from_remote_stdout(ws, local_stdout)), - asyncio.create_task(pipe_to_remote_stdin(ws, local_stdin)), - asyncio.create_task(propagate_resize_events(ws, resize_queue)), - ], - return_when=asyncio.FIRST_COMPLETED, - ) - - for unfinished in pending: - unfinished.cancel() - - except asyncio.CancelledError: - pass - - -def exec( - execution_id: Optional[str] = None, - egn_id: Optional[str] = None, - container_index: Optional[int] = None, -): - execution_info: Optional[ExecutionInfoNode] = None - if egn_id is None: - execution_info = get_execution_info(execution_id) - - egn_info = get_egn_info(execution_info, egn_id) - container_info = get_container_info(egn_info, container_index) - - import termios - import tty + # ayush: can't use TaskGroups bc only supported on >= 3.11 + try: + _, pending = await asyncio.wait( + [ + asyncio.create_task(pipe_from_remote_stdout(ws, local_stdout)), + asyncio.create_task(pipe_to_remote_stdin(ws, local_stdin)), + asyncio.create_task(propagate_resize_events(ws, resize_event_queue)), + ], + return_when=asyncio.FIRST_COMPLETED, + ) - old_settings_stdin = termios.tcgetattr(sys.stdin.fileno()) - tty.setraw(sys.stdin) + for unfinished in pending: + unfinished.cancel() - try: - asyncio.run(connect(egn_info, container_info)) - finally: - termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, old_settings_stdin) + except asyncio.CancelledError: + pass