diff --git a/cli/dstack/_internal/cli/utils/run.py b/cli/dstack/_internal/cli/utils/run.py index b843d86c5..8968ab97f 100644 --- a/cli/dstack/_internal/cli/utils/run.py +++ b/cli/dstack/_internal/cli/utils/run.py @@ -72,34 +72,6 @@ def print_run_plan(configuration_file: str, run_plan: RunPlan): console.print() -def reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]: - host_ports = {} - host_ports_lock = PortsLock() - app_ports = {} - app_ports_lock = PortsLock() - openssh_server_port: Optional[int] = None - - for app in apps: - app_ports[app.port] = app.map_to_port or 0 - if app.app_name == "openssh-server": - openssh_server_port = app.port - if local_backend and openssh_server_port is None: - return host_ports_lock, app_ports_lock - - if not local_backend: - if openssh_server_port is None: - host_ports.update(app_ports) - app_ports = {} - host_ports_lock = PortsLock(host_ports).acquire() - - if openssh_server_port is not None: - del app_ports[openssh_server_port] - if app_ports: - app_ports_lock = PortsLock(app_ports).acquire() - - return host_ports_lock, app_ports_lock - - def poll_run( hub_client: HubClient, run: RunHead, @@ -241,61 +213,61 @@ def _print_failed_run_message(run: RunHead): console.print("Provisioning failed\n") +def reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]: + """ + :return: host_ports_lock, app_ports_lock + """ + app_ports = {app.port: app.map_to_port or 0 for app in apps} + ssh_server_port = get_ssh_server_port(apps) + + if not local_backend and ssh_server_port is None: + # cloud backend without ssh in the container: use a host ssh tunnel + return PortsLock(app_ports).acquire(), PortsLock() + + if ssh_server_port is not None: + # any backend with ssh in the container: use a container ssh tunnel + del app_ports[ssh_server_port] + # for cloud backend: using ProxyJump to access ssh in the container, no host port forwarding needed + # for local backend: the same host, 
no port forwarding needed + return PortsLock(), PortsLock(app_ports).acquire() + + # local backend without ssh in the container: all ports mapped by runner + return PortsLock(), PortsLock() + + def _attach( hub_client: HubClient, job: Job, ssh_key_path: str, ports_locks: Tuple[PortsLock, PortsLock] ) -> Dict[int, int]: """ - :return: (host tunnel ports, container tunnel ports, ports mapping) + :return: ports_mapping """ backend_type = hub_client.get_project_backend_type() - app_ports = {} - openssh_server_port: Optional[int] = None - for app in job.app_specs or []: - app_ports[app.port] = app.map_to_port or 0 - if app.app_name == "openssh-server": - openssh_server_port = app.port - if backend_type == "local" and openssh_server_port is None: + app_ports = {app.port: app.map_to_port or 0 for app in job.app_specs or []} + host_ports = {} + ssh_server_port = get_ssh_server_port(job.app_specs or []) + + if backend_type == "local" and ssh_server_port is None: console.print("Provisioning... It may take up to a minute. 
[green]✓[/]") + # local backend without ssh in container: all ports mapped by runner return {k: v for k, v in app_ports.items() if v != 0} - console.print("Starting SSH tunnel...") include_ssh_config(config.ssh_config_path) - ws_port = int(job.env["WS_LOGS_PORT"]) - host_ports = {} host_ports_lock, app_ports_lock = ports_locks if backend_type != "local" and not ENABLE_LOCAL_CLOUD: - ssh_config_add_host( - config.ssh_config_path, - f"{job.run_name}-host", - { - "HostName": job.host_name, - # TODO: use non-root for all backends - "User": "ubuntu" if backend_type in ("azure", "gcp", "lambda") else "root", - "IdentityFile": ssh_key_path, - "StrictHostKeyChecking": "no", - "UserKnownHostsFile": "/dev/null", - "ControlPath": config.ssh_control_path(f"{job.run_name}-host"), - "ControlMaster": "auto", - "ControlPersist": "yes", - }, - ) - if openssh_server_port is None: + # cloud backend, need to forward logs websocket + console.print("Starting SSH tunnel...") + if ssh_server_port is None: + # no ssh in the container: app ports go through the host tunnel (held by host_ports_lock), so drop them here + app_ports = {} - host_ports = PortsLock({ws_port: 0}).acquire().release() - host_ports.update(host_ports_lock.release()) - for i in range(3): # retry - time.sleep(2**i) - if run_ssh_tunnel(f"{job.run_name}-host", host_ports): - break - else: - console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]") + host_ports = _run_host_ssh_tunnel(job, ssh_key_path, host_ports_lock, backend_type) - if openssh_server_port is not None: + if ssh_server_port is not None: + # ssh in the container: update ssh config, run tunnel if any apps options = { "HostName": "localhost", - "Port": app_ports[openssh_server_port] or openssh_server_port, + "Port": app_ports[ssh_server_port] or ssh_server_port, "User": "root", "IdentityFile": ssh_key_path, "StrictHostKeyChecking": "no", @@ -307,12 +279,13 @@ def _attach( if backend_type != "local" and not ENABLE_LOCAL_CLOUD: options["ProxyJump"] = f"{job.run_name}-host" 
ssh_config_add_host(config.ssh_config_path, job.run_name, options) - del app_ports[openssh_server_port] + del app_ports[ssh_server_port] if app_ports: + # save mapping, but don't release ports yet app_ports.update(app_ports_lock.dict()) # try to attach in the background threading.Thread( - target=_attach_to_container, + target=_run_container_ssh_tunnel, args=(hub_client, job.run_name, app_ports_lock), daemon=True, ).start() @@ -320,7 +293,37 @@ def _attach( return {**host_ports, **app_ports} -def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: PortsLock): +def _run_host_ssh_tunnel( + job: Job, ssh_key_path: str, ports_lock: PortsLock, backend_type: str +) -> Dict[int, int]: + ssh_config_add_host( + config.ssh_config_path, + f"{job.run_name}-host", + { + "HostName": job.host_name, + # TODO: use non-root for all backends + "User": "ubuntu" if backend_type in ("azure", "gcp", "lambda") else "root", + "IdentityFile": ssh_key_path, + "StrictHostKeyChecking": "no", + "UserKnownHostsFile": "/dev/null", + "ControlPath": config.ssh_control_path(f"{job.run_name}-host"), + "ControlMaster": "auto", + "ControlPersist": "yes", + }, + ) + # get free port for logs + host_ports = PortsLock({int(job.env["WS_LOGS_PORT"]): 0}).acquire().release() + host_ports.update(ports_lock.release()) + for i in range(3): # retry + time.sleep(2**i) + if run_ssh_tunnel(f"{job.run_name}-host", host_ports): + break + else: + console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]") + return host_ports + + +def _run_container_ssh_tunnel(hub_client: HubClient, run_name: str, ports_lock: PortsLock): # idle BUILDING for run in _poll_run_head(hub_client, run_name, loop_statuses=[JobStatus.BUILDING]): pass @@ -337,7 +340,7 @@ def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: Ports "[red]ERROR[/] Can't establish SSH tunnel with the container\n" "[grey58]Aborting...[/]" ) - hub_client.stop_jobs(run_name, abort=True) + 
hub_client.stop_jobs(run_name, terminate=True, abort=True) exit(1) @@ -421,3 +424,10 @@ def _ask_on_interrupt(hub_client: HubClient, run_name: str): ssh_config_remove_host(config.ssh_config_path, f"{run_name}-host") ssh_config_remove_host(config.ssh_config_path, run_name) exit(0) + + +def get_ssh_server_port(apps: List[AppSpec]) -> Optional[int]: + for app in apps: + if app.app_name == "openssh-server": + return app.port + return None