Skip to content

Commit

Permalink
Small ports refactoring (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor-S authored Jul 28, 2023
1 parent 9bb93b6 commit 0e099b5
Showing 1 changed file with 79 additions and 69 deletions.
148 changes: 79 additions & 69 deletions cli/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,34 +72,6 @@ def print_run_plan(configuration_file: str, run_plan: RunPlan):
console.print()


def reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]:
host_ports = {}
host_ports_lock = PortsLock()
app_ports = {}
app_ports_lock = PortsLock()
openssh_server_port: Optional[int] = None

for app in apps:
app_ports[app.port] = app.map_to_port or 0
if app.app_name == "openssh-server":
openssh_server_port = app.port
if local_backend and openssh_server_port is None:
return host_ports_lock, app_ports_lock

if not local_backend:
if openssh_server_port is None:
host_ports.update(app_ports)
app_ports = {}
host_ports_lock = PortsLock(host_ports).acquire()

if openssh_server_port is not None:
del app_ports[openssh_server_port]
if app_ports:
app_ports_lock = PortsLock(app_ports).acquire()

return host_ports_lock, app_ports_lock


def poll_run(
hub_client: HubClient,
run: RunHead,
Expand Down Expand Up @@ -241,61 +213,61 @@ def _print_failed_run_message(run: RunHead):
console.print("Provisioning failed\n")


def reserve_ports(apps: List[AppSpec], local_backend: bool) -> Tuple[PortsLock, PortsLock]:
"""
:return: host_ports_lock, app_ports_lock
"""
app_ports = {app.port: app.map_to_port or 0 for app in apps}
ssh_server_port = get_ssh_server_port(apps)

if not local_backend and ssh_server_port is None:
# cloud backand without ssh in the container: use a host ssh tunnel
return PortsLock(app_ports).acquire(), PortsLock()

if ssh_server_port is not None:
# any backend with ssh in the container: use a container ssh tunnel
del app_ports[ssh_server_port]
# for cloud backend: using ProxyJump to access ssh in the container, no host port forwarding needed
# for local backend: the same host, no port forwarding needed
return PortsLock(), PortsLock(app_ports).acquire()

# local backend without ssh in the container: all ports mapped by runner
return PortsLock(), PortsLock()


def _attach(
hub_client: HubClient, job: Job, ssh_key_path: str, ports_locks: Tuple[PortsLock, PortsLock]
) -> Dict[int, int]:
"""
:return: (host tunnel ports, container tunnel ports, ports mapping)
:return: ports_mapping
"""
backend_type = hub_client.get_project_backend_type()
app_ports = {}
openssh_server_port: Optional[int] = None
for app in job.app_specs or []:
app_ports[app.port] = app.map_to_port or 0
if app.app_name == "openssh-server":
openssh_server_port = app.port
if backend_type == "local" and openssh_server_port is None:
app_ports = {app.port: app.map_to_port or 0 for app in job.app_specs or []}
host_ports = {}
ssh_server_port = get_ssh_server_port(job.app_specs or [])

if backend_type == "local" and ssh_server_port is None:
console.print("Provisioning... It may take up to a minute. [green]✓[/]")
# local backend without ssh in container: all ports mapped by runner
return {k: v for k, v in app_ports.items() if v != 0}

console.print("Starting SSH tunnel...")
include_ssh_config(config.ssh_config_path)
ws_port = int(job.env["WS_LOGS_PORT"])

host_ports = {}
host_ports_lock, app_ports_lock = ports_locks

if backend_type != "local" and not ENABLE_LOCAL_CLOUD:
ssh_config_add_host(
config.ssh_config_path,
f"{job.run_name}-host",
{
"HostName": job.host_name,
# TODO: use non-root for all backends
"User": "ubuntu" if backend_type in ("azure", "gcp", "lambda") else "root",
"IdentityFile": ssh_key_path,
"StrictHostKeyChecking": "no",
"UserKnownHostsFile": "/dev/null",
"ControlPath": config.ssh_control_path(f"{job.run_name}-host"),
"ControlMaster": "auto",
"ControlPersist": "yes",
},
)
if openssh_server_port is None:
# cloud backend, need to forward logs websocket
console.print("Starting SSH tunnel...")
if ssh_server_port is None:
# ssh in the container: no need to forward app ports
app_ports = {}
host_ports = PortsLock({ws_port: 0}).acquire().release()
host_ports.update(host_ports_lock.release())
for i in range(3): # retry
time.sleep(2**i)
if run_ssh_tunnel(f"{job.run_name}-host", host_ports):
break
else:
console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]")
host_ports = _run_host_ssh_tunnel(job, ssh_key_path, host_ports_lock, backend_type)

if openssh_server_port is not None:
if ssh_server_port is not None:
# ssh in the container: update ssh config, run tunnel if any apps
options = {
"HostName": "localhost",
"Port": app_ports[openssh_server_port] or openssh_server_port,
"Port": app_ports[ssh_server_port] or ssh_server_port,
"User": "root",
"IdentityFile": ssh_key_path,
"StrictHostKeyChecking": "no",
Expand All @@ -307,20 +279,51 @@ def _attach(
if backend_type != "local" and not ENABLE_LOCAL_CLOUD:
options["ProxyJump"] = f"{job.run_name}-host"
ssh_config_add_host(config.ssh_config_path, job.run_name, options)
del app_ports[openssh_server_port]
del app_ports[ssh_server_port]
if app_ports:
# save mapping, but don't release ports yet
app_ports.update(app_ports_lock.dict())
# try to attach in the background
threading.Thread(
target=_attach_to_container,
target=_run_container_ssh_tunnel,
args=(hub_client, job.run_name, app_ports_lock),
daemon=True,
).start()

return {**host_ports, **app_ports}


def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: PortsLock):
def _run_host_ssh_tunnel(
job: Job, ssh_key_path: str, ports_lock: PortsLock, backend_type: str
) -> Dict[int, int]:
ssh_config_add_host(
config.ssh_config_path,
f"{job.run_name}-host",
{
"HostName": job.host_name,
# TODO: use non-root for all backends
"User": "ubuntu" if backend_type in ("azure", "gcp", "lambda") else "root",
"IdentityFile": ssh_key_path,
"StrictHostKeyChecking": "no",
"UserKnownHostsFile": "/dev/null",
"ControlPath": config.ssh_control_path(f"{job.run_name}-host"),
"ControlMaster": "auto",
"ControlPersist": "yes",
},
)
# get free port for logs
host_ports = PortsLock({int(job.env["WS_LOGS_PORT"]): 0}).acquire().release()
host_ports.update(ports_lock.release())
for i in range(3): # retry
time.sleep(2**i)
if run_ssh_tunnel(f"{job.run_name}-host", host_ports):
break
else:
console.print("[warning]Warning: failed to start SSH tunnel[/warning] [red]✗[/]")
return host_ports


def _run_container_ssh_tunnel(hub_client: HubClient, run_name: str, ports_lock: PortsLock):
# idle BUILDING
for run in _poll_run_head(hub_client, run_name, loop_statuses=[JobStatus.BUILDING]):
pass
Expand All @@ -337,7 +340,7 @@ def _attach_to_container(hub_client: HubClient, run_name: str, ports_lock: Ports
"[red]ERROR[/] Can't establish SSH tunnel with the container\n"
"[grey58]Aborting...[/]"
)
hub_client.stop_jobs(run_name, abort=True)
hub_client.stop_jobs(run_name, terminate=True, abort=True)
exit(1)


Expand Down Expand Up @@ -421,3 +424,10 @@ def _ask_on_interrupt(hub_client: HubClient, run_name: str):
ssh_config_remove_host(config.ssh_config_path, f"{run_name}-host")
ssh_config_remove_host(config.ssh_config_path, run_name)
exit(0)


def get_ssh_server_port(apps: List[AppSpec]) -> Optional[int]:
for app in apps:
if app.app_name == "openssh-server":
return app.port
return None

0 comments on commit 0e099b5

Please sign in to comment.