Skip to content

Commit

Permalink
Substitute gateway hostname from secrets, show SSH errors in CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor-S committed Aug 1, 2023
1 parent 2783de8 commit 33c5d01
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 17 deletions.
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,13 @@ def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:

def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus):
self.predict_build_plan(job) # raises exception on missing build
base_jobs.run_job(self.storage(), self.compute(), job, failed_to_start_job_new_status)
base_jobs.run_job(
self.storage(),
self.compute(),
self.secrets_manager(),
job,
failed_to_start_job_new_status,
)

def restart_job(self, job: Job):
base_jobs.restart_job(self.storage(), self.compute(), job)
Expand Down
31 changes: 22 additions & 9 deletions cli/dstack/_internal/backend/base/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
list_head_objects,
put_head_object,
)
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.error import DstackError
from dstack._internal.core.error import SSHCommandError
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH
from dstack._internal.utils.common import PathLike
from dstack._internal.utils.common import PathLike, removeprefix
from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.random_names import generate_name


Expand All @@ -38,6 +40,17 @@ def delete_gateway(compute: Compute, storage: Storage, instance_name: str):
delete_head_object(storage, head)


def resolve_hostname(secrets_manager: SecretsManager, repo_id: str, hostname: str) -> str:
    """Substitute secret references in a gateway hostname with stored values.

    The hostname is first interpolated against an empty namespace purely to
    collect the names of the variables it references. Every reference in the
    ``secrets.`` namespace is then looked up through ``secrets_manager`` for
    ``repo_id``, and the hostname is interpolated a second time with the
    secrets that were found.

    Args:
        secrets_manager: Source of secret values, queried via ``get_secret``.
        repo_id: Repository whose secrets are consulted.
        hostname: Hostname that may contain variable references.

    Returns:
        The result of interpolating ``hostname`` with the found secrets.
    """
    prefix = "secrets."
    secrets = {}
    # First pass with no variables defined: only the list of referenced
    # (and therefore missing) names is of interest, the text is discarded.
    _, missed = VariablesInterpolator({}).interpolate(hostname, return_missing=True)
    for ns_name in missed:
        # Only references in the secrets namespace can be resolved here.
        # Without this guard any other reference would be sent to the secrets
        # backend (presumably under its full dotted name, since stripping the
        # prefix has nothing to strip), and the result could never match a
        # "secrets.<name>" variable in the final interpolation anyway.
        if not ns_name.startswith(prefix):
            continue
        name = removeprefix(ns_name, prefix)
        value = secrets_manager.get_secret(repo_id, name)
        if value is not None:
            secrets[name] = value.secret_value
    return VariablesInterpolator({"secrets": secrets}).interpolate(hostname)


def publish(
hostname: str,
port: int,
Expand All @@ -63,18 +76,18 @@ def exec_ssh_command(
args += ["-i", id_rsa]
args += [
"-o",
"StrictHostKeyChecking=accept-new",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
f"{user}@{hostname}",
command,
]
if not hostname: # ssh hangs indefinitely with empty hostname
raise SSHCommandError(
args, "ssh: Could not connect to the gateway, because hostname is empty"
)
proc = subprocess.Popen(args, stdin=stdin, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
if proc.returncode != 0:
raise SSHCommandError(args, stderr.decode())
return stdout


class SSHCommandError(DstackError):
def __init__(self, cmd: List[str], message: str):
super().__init__(message)
self.cmd = cmd
8 changes: 7 additions & 1 deletion cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dstack._internal.backend.base.gateway as gateway
from dstack._internal.backend.base import runners
from dstack._internal.backend.base.compute import Compute, InstanceNotFoundError, NoCapacityError
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.error import BackendError, BackendValueError, NoMatchingInstanceError
from dstack._internal.core.instance import InstanceType
Expand Down Expand Up @@ -119,19 +120,24 @@ def predict_job_instance(
def run_job(
storage: Storage,
compute: Compute,
secrets_manager: SecretsManager,
job: Job,
failed_to_start_job_new_status: JobStatus,
):
if job.status != JobStatus.SUBMITTED:
raise BackendError("Can't create a request for a job which status is not SUBMITTED")
try:
if job.configuration_type == ConfigurationType.SERVICE:
job.gateway.hostname = gateway.resolve_hostname(
secrets_manager, job.repo_ref.repo_id, job.gateway.hostname
)
private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=job.run_name)
job.gateway.sock_path = gateway.publish(
job.gateway.hostname, job.gateway.public_port, public_bytes
)
job.gateway.ssh_key = private_bytes.decode()
update_job(storage, job)

_try_run_job(
storage=storage,
compute=compute,
Expand Down Expand Up @@ -165,7 +171,7 @@ def restart_job(


def stop_job(
storage: Storage, compute: Compute, repo_id: str, job_id: str, terminate: str, abort: str
storage: Storage, compute: Compute, repo_id: str, job_id: str, terminate: bool, abort: bool
):
logger.info("Stopping job [repo_id=%s job_id=%s]", repo_id, job_id)
job_head = get_job_head(storage, repo_id, job_id)
Expand Down
10 changes: 9 additions & 1 deletion cli/dstack/_internal/core/error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional


class DstackError(Exception):
Expand Down Expand Up @@ -34,3 +34,11 @@ def __init__(self, message: Optional[str] = None, project_name: Optional[str] =

class NameNotFoundError(DstackError):
pass


class SSHCommandError(BackendError):
    """Backend error describing an SSH command that could not be executed.

    Carries the command's argument list in ``cmd`` in addition to the error
    message handled by the base class.
    """

    # Code that identifies this error type when the error is reported in
    # structured form (compared against the "code" field of an error detail).
    code = "ssh_command_error"

    def __init__(self, cmd: List[str], message: str):
        # cmd: argument list of the SSH invocation that failed.
        # message: human-readable description, forwarded to the base class.
        super().__init__(message)
        self.cmd = cmd
7 changes: 6 additions & 1 deletion cli/dstack/_internal/hub/routers/runners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status

from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.error import NoMatchingInstanceError, SSHCommandError
from dstack._internal.core.job import Job, JobStatus
from dstack._internal.hub.models import StopRunners
from dstack._internal.hub.routers.util import call_backend, error_detail, get_backend, get_project
Expand Down Expand Up @@ -33,6 +33,11 @@ async def run(project_name: str, job: Job):
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)
except SSHCommandError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)


@router.post("/{project_name}/runners/restart")
Expand Down
2 changes: 1 addition & 1 deletion cli/dstack/_internal/scripts/gateway_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main():
# detect conflicts
conf_path = Path("/etc/nginx/sites-enabled") / f"{args.port}-{args.hostname}.conf"
if conf_path.exists() and is_conf_active(conf_path):
exit(f"{args.hostname}:{args.port} is still in use")
exit(f"Could not start the service, because {args.hostname}:{args.port} is in use")

# create temp dir for socket
temp_dir = tempfile.mkdtemp(prefix="dstack-")
Expand Down
9 changes: 6 additions & 3 deletions cli/dstack/api/hub/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BackendNotAvailableError,
BackendValueError,
NoMatchingInstanceError,
SSHCommandError,
)
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.core.job import Job, JobHead
Expand Down Expand Up @@ -177,9 +178,11 @@ def run_job(self, job: Job):
return
elif resp.status_code == 400:
body = resp.json()
if body["detail"]["code"] == NoMatchingInstanceError.code:
raise HubClientError(body["detail"]["msg"])
elif body["detail"]["code"] == BuildNotFoundError.code:
if body["detail"]["code"] in (
NoMatchingInstanceError.code,
BuildNotFoundError.code,
SSHCommandError.code,
):
raise HubClientError(body["detail"]["msg"])
resp.raise_for_status()

Expand Down

0 comments on commit 33c5d01

Please sign in to comment.