diff --git a/cli/dstack/_internal/backend/base/__init__.py b/cli/dstack/_internal/backend/base/__init__.py
index d67ec8195..84db72556 100644
--- a/cli/dstack/_internal/backend/base/__init__.py
+++ b/cli/dstack/_internal/backend/base/__init__.py
@@ -2,6 +2,7 @@ from datetime import datetime
 from typing import Generator, List, Optional
 
+import dstack._internal.backend.base.gateway as gateway
 import dstack._internal.core.build
 from dstack._internal.backend.base import artifacts as base_artifacts
 from dstack._internal.backend.base import build as base_build
@@ -17,6 +18,7 @@ from dstack._internal.backend.base.storage import Storage
 from dstack._internal.core.artifact import Artifact
 from dstack._internal.core.build import BuildPlan
+from dstack._internal.core.gateway import GatewayHead
 from dstack._internal.core.instance import InstanceType
 from dstack._internal.core.job import Job, JobHead, JobStatus
 from dstack._internal.core.log_event import LogEvent
@@ -248,6 +250,18 @@ def get_signed_upload_url(self, object_key: str) -> str:
     def predict_build_plan(self, job: Job) -> BuildPlan:
         pass
 
+    @abstractmethod
+    def create_gateway(self, ssh_key_pub: str) -> GatewayHead:
+        pass
+
+    @abstractmethod
+    def list_gateways(self) -> List[GatewayHead]:
+        pass
+
+    @abstractmethod
+    def delete_gateway(self, instance_name: str):
+        pass
+
 
 class ComponentBasedBackend(Backend):
     @abstractmethod
@@ -468,3 +482,12 @@ def predict_build_plan(self, job: Job) -> BuildPlan:
         return base_build.predict_build_plan(
             self.storage(), job, dstack._internal.core.build.DockerPlatform.amd64
         )
+
+    def create_gateway(self, ssh_key_pub: str) -> GatewayHead:
+        return gateway.create_gateway(self.compute(), self.storage(), ssh_key_pub)
+
+    def list_gateways(self) -> List[GatewayHead]:
+        return gateway.list_gateways(self.storage())
+
+    def delete_gateway(self, instance_name: str):
+        gateway.delete_gateway(self.compute(), self.storage(), instance_name)
diff --git a/cli/dstack/_internal/backend/base/build.py b/cli/dstack/_internal/backend/base/build.py
index 94669584a..4aabc9718 100644
--- a/cli/dstack/_internal/backend/base/build.py
+++ b/cli/dstack/_internal/backend/base/build.py
@@ -26,8 +26,6 @@ def predict_build_plan(
             raise BuildNotFoundError("Build not found. Run `dstack build` or add `--build` flag")
         return BuildPlan.yes
-    if job.optional_build_commands and job.build_policy == BuildPolicy.BUILD:
-        return BuildPlan.yes
     return BuildPlan.no
diff --git a/cli/dstack/_internal/backend/base/compute.py b/cli/dstack/_internal/backend/base/compute.py
index 8e4357e9e..5305301a8 100644
--- a/cli/dstack/_internal/backend/base/compute.py
+++ b/cli/dstack/_internal/backend/base/compute.py
@@ -5,6 +5,7 @@
 import dstack.version as version
 from dstack._internal.core.error import BackendError
+from dstack._internal.core.gateway import GatewayHead
 from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job, Requirements
 from dstack._internal.core.request import RequestHead
@@ -49,6 +50,14 @@ def terminate_instance(self, runner: Runner):
     def cancel_spot_request(self, runner: Runner):
         pass
 
+    def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
+        # todo: make abstract & implement for each backend
+        raise NotImplementedError()
+
+    def delete_instance(self, instance_name: str):
+        # todo: make abstract & implement for each backend
+        raise NotImplementedError()
+
 
 def choose_instance_type(
     instance_types: List[InstanceType],
diff --git a/cli/dstack/_internal/backend/base/gateway.py b/cli/dstack/_internal/backend/base/gateway.py
new file mode 100644
index 000000000..f7a723d93
--- /dev/null
+++ b/cli/dstack/_internal/backend/base/gateway.py
@@ -0,0 +1,70 @@
+import subprocess
+import time
+from typing import List, Optional
+
+from dstack._internal.backend.base.compute import Compute
+from dstack._internal.backend.base.head import (
+    delete_head_object,
+    list_head_objects,
+    put_head_object,
+)
+from dstack._internal.backend.base.storage import Storage
+from dstack._internal.core.error import DstackError
+from dstack._internal.core.gateway import GatewayHead
+from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH
+from dstack._internal.utils.common import PathLike
+from dstack._internal.utils.random_names import generate_name
+
+
+def create_gateway(compute: Compute, storage: Storage, ssh_key_pub: str) -> GatewayHead:
+    # todo: regenerate the name until it is unique
+    instance_name = f"dstack-gateway-{generate_name()}"
+    head = compute.create_gateway(instance_name, ssh_key_pub)
+    put_head_object(storage, head)
+    return head
+
+
+def list_gateways(storage: Storage) -> List[GatewayHead]:
+    return list_head_objects(storage, GatewayHead)
+
+
+def delete_gateway(compute: Compute, storage: Storage, instance_name: str):
+    heads = list_gateways(storage)
+    for head in heads:
+        if head.instance_name != instance_name:
+            continue
+        compute.delete_instance(instance_name)
+        delete_head_object(storage, head)
+
+
+def ssh_copy_id(
+    hostname: str,
+    public_key: bytes,
+    user: str = "ubuntu",
+    id_rsa: Optional[PathLike] = HUB_PRIVATE_KEY_PATH,
+):
+    command = f"echo '{public_key.decode()}' >> ~/.ssh/authorized_keys"
+    exec_ssh_command(hostname, command, user=user, id_rsa=id_rsa)
+
+
+def exec_ssh_command(hostname: str, command: str, user: str, id_rsa: Optional[PathLike]) -> bytes:
+    args = ["ssh"]
+    if id_rsa is not None:
+        args += ["-i", id_rsa]
+    args += [
+        "-o",
+        "StrictHostKeyChecking=accept-new",
+        f"{user}@{hostname}",
+        command,
+    ]
+    proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    stdout, stderr = proc.communicate()
+    if proc.returncode != 0:
+        raise SSHCommandError(args, stderr.decode())
+    return stdout
+
+
+class SSHCommandError(DstackError):
+    def __init__(self, cmd: List[str], message: str):
+        super().__init__(message)
+        self.cmd = cmd
diff --git a/cli/dstack/_internal/backend/base/head.py b/cli/dstack/_internal/backend/base/head.py
new file mode 100644
index 000000000..bf4c21bdc
--- /dev/null
+++ b/cli/dstack/_internal/backend/base/head.py
@@ -0,0 +1,19 @@
+from typing import List, Type
+
+from dstack._internal.backend.base.storage import Storage
+from dstack._internal.core.head import BaseHead, T
+
+
+def put_head_object(storage: Storage, head: BaseHead) -> str:
+    key = head.encode()
+    storage.put_object(key, content="")
+    return key
+
+
+def list_head_objects(storage: Storage, cls: Type[T]) -> List[T]:
+    keys = storage.list_objects(cls.prefix())
+    return [cls.decode(key) for key in keys]
+
+
+def delete_head_object(storage: Storage, head: BaseHead):
+    storage.delete_object(head.encode())
diff --git a/cli/dstack/_internal/backend/base/jobs.py b/cli/dstack/_internal/backend/base/jobs.py
index 1af9049cc..3a1098683 100644
--- a/cli/dstack/_internal/backend/base/jobs.py
+++ b/cli/dstack/_internal/backend/base/jobs.py
@@ -2,12 +2,14 @@
 
 import yaml
 
+import dstack._internal.backend.base.gateway as gateway
 from dstack._internal.backend.base import runners
 from dstack._internal.backend.base.compute import Compute, InstanceNotFoundError, NoCapacityError
 from dstack._internal.backend.base.storage import Storage
 from dstack._internal.core.error import BackendError, BackendValueError, NoMatchingInstanceError
 from dstack._internal.core.instance import InstanceType
 from dstack._internal.core.job import (
+    ConfigurationType,
     Job,
     JobErrorCode,
     JobHead,
@@ -18,6 +20,7 @@ from dstack._internal.core.repo import RepoRef
 from dstack._internal.core.runners import Runner
 from dstack._internal.utils.common import get_milliseconds_since_epoch
+from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
 from dstack._internal.utils.escape import escape_head, unescape_head
 from dstack._internal.utils.logging import get_logger
@@ -118,6 +121,11 @@ def run_job(
     if job.status != JobStatus.SUBMITTED:
         raise BackendError("Can't create a request for a job which status is not SUBMITTED")
     try:
+        if job.configuration_type == ConfigurationType.SERVICE:
+            private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=job.run_name)
+            gateway.ssh_copy_id(job.gateway.hostname, public_bytes)
+            job.gateway.ssh_key = private_bytes.decode()
+            update_job(storage, job)
         _try_run_job(
             storage=storage,
             compute=compute,
diff --git a/cli/dstack/_internal/backend/gcp/compute.py b/cli/dstack/_internal/backend/gcp/compute.py
index e2a5b7773..2e15caac7 100644
--- a/cli/dstack/_internal/backend/gcp/compute.py
+++ b/cli/dstack/_internal/backend/gcp/compute.py
@@ -6,6 +6,7 @@
 from google.cloud import compute_v1
 from google.oauth2 import service_account
 
+import dstack._internal.backend.gcp.gateway as gateway
 from dstack import version
 from dstack._internal.backend.base.compute import (
     WS_PORT,
@@ -20,6 +21,7 @@
 from dstack._internal.backend.gcp import utils as gcp_utils
 from dstack._internal.backend.gcp.config import GCPConfig
 from dstack._internal.core.error import BackendValueError
+from dstack._internal.core.gateway import GatewayHead
 from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job, Requirements
 from dstack._internal.core.request import RequestHead, RequestStatus
@@ -157,6 +159,35 @@ def cancel_spot_request(self, runner: Runner):
             instance_name=runner.request_id,
         )
 
+    def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
+        instance = gateway.create_gateway_instance(
+            instances_client=self.instances_client,
+            firewalls_client=self.firewalls_client,
+            project_id=self.gcp_config.project_id,
+            network=_get_network_resource(self.gcp_config.vpc),
+            subnet=_get_subnet_resource(self.gcp_config.region, self.gcp_config.subnet),
+            zone=self.gcp_config.zone,
+            instance_name=instance_name,
+            service_account=self.credentials.service_account_email,
+            labels=dict(
+                role="gateway",
+                owner="dstack",
+            ),
+            ssh_key_pub=ssh_key_pub,
+        )
+        return GatewayHead(
+            instance_name=instance_name,
+            external_ip=instance.network_interfaces[0].access_configs[0].nat_i_p,
+            internal_ip=instance.network_interfaces[0].network_i_p,
+        )
+
+    def delete_instance(self, instance_name: str):
+        _terminate_instance(
+            client=self.instances_client,
+            gcp_config=self.gcp_config,
+            instance_name=instance_name,
+        )
+
 
 def _get_instance_status(
     instances_client: compute_v1.InstancesClient,
@@ -740,7 +771,7 @@ def _create_firewall_rules(
     network: str = "global/networks/default",
 ):
     """
-    Creates a simple firewall rule allowing for incoming HTTP and HTTPS access from the entire Internet.
+    Creates a simple firewall rule allowing for incoming SSH access from the entire Internet.
 
     Args:
         project_id: project ID or project number of the Cloud project you want to use.
diff --git a/cli/dstack/_internal/backend/gcp/gateway.py b/cli/dstack/_internal/backend/gcp/gateway.py
new file mode 100644
index 000000000..02b8b6d5e
--- /dev/null
+++ b/cli/dstack/_internal/backend/gcp/gateway.py
@@ -0,0 +1,117 @@
+from typing import Dict, List
+
+import google.api_core.exceptions
+from google.cloud import compute_v1
+
+import dstack._internal.backend.gcp.utils as gcp_utils
+
+DSTACK_GATEWAY_TAG = "dstack-gateway"
+
+
+def create_gateway_instance(
+    instances_client: compute_v1.InstancesClient,
+    firewalls_client: compute_v1.FirewallsClient,
+    project_id: str,
+    network: str,
+    subnet: str,
+    zone: str,
+    instance_name: str,
+    service_account: str,
+    labels: Dict[str, str],
+    ssh_key_pub: str,
+    machine_type: str = "e2-micro",
+) -> compute_v1.Instance:
+    try:
+        create_gateway_firewall_rules(
+            firewalls_client=firewalls_client,
+            project_id=project_id,
+            network=network,
+        )
+    except google.api_core.exceptions.Conflict:
+        pass
+
+    network_interface = compute_v1.NetworkInterface()
+    network_interface.name = network
+    network_interface.subnetwork = subnet
+
+    access = compute_v1.AccessConfig()
+    access.type_ = compute_v1.AccessConfig.Type.ONE_TO_ONE_NAT.name
+    access.name = "External NAT"
+    access.network_tier = access.NetworkTier.PREMIUM.name
+    network_interface.access_configs = [access]
+
+    instance = compute_v1.Instance()
+    instance.network_interfaces = [network_interface]
+    instance.name = instance_name
+    instance.disks = gateway_disks(zone)
+    instance.machine_type = f"zones/{zone}/machineTypes/{machine_type}"
+
+    metadata_items = [
+        compute_v1.Items(key="ssh-keys", value=f"ubuntu:{ssh_key_pub}"),
+        compute_v1.Items(key="user-data", value=gateway_user_data_script()),
+    ]
+    instance.metadata = compute_v1.Metadata(items=metadata_items)
+    instance.labels = labels
+    instance.tags = compute_v1.Tags(items=[DSTACK_GATEWAY_TAG])  # to apply firewall rules
+
+    instance.service_accounts = [
+        compute_v1.ServiceAccount(
+            email=service_account,
+            scopes=["https://www.googleapis.com/auth/cloud-platform"],
+        )
+    ]
+
+    request = compute_v1.InsertInstanceRequest()
+    request.zone = zone
+    request.project = project_id
+    request.instance_resource = instance
+
+    operation = instances_client.insert(request=request)
+    gcp_utils.wait_for_extended_operation(operation, "instance creation")
+
+    return instances_client.get(project=project_id, zone=zone, instance=instance_name)
+
+
+def create_gateway_firewall_rules(
+    firewalls_client: compute_v1.FirewallsClient,
+    project_id: str,
+    network: str,
+):
+    firewall_rule = compute_v1.Firewall()
+    firewall_rule.name = "dstack-gateway-in-" + network.replace("/", "-")
+    firewall_rule.direction = "INGRESS"
+
+    allowed_ports = compute_v1.Allowed()
+    allowed_ports.I_p_protocol = "tcp"
+    allowed_ports.ports = ["22", "80", "443"]
+
+    firewall_rule.allowed = [allowed_ports]
+    firewall_rule.source_ranges = ["0.0.0.0/0"]
+    firewall_rule.network = network
+    firewall_rule.description = "Allowing TCP traffic on ports 22, 80, and 443 from the Internet."
+
+    firewall_rule.target_tags = [DSTACK_GATEWAY_TAG]
+
+    operation = firewalls_client.insert(project=project_id, firewall_resource=firewall_rule)
+    gcp_utils.wait_for_extended_operation(operation, "firewall rule creation")
+
+
+def gateway_disks(zone: str) -> List[compute_v1.AttachedDisk]:
+    disk = compute_v1.AttachedDisk()
+
+    initialize_params = compute_v1.AttachedDiskInitializeParams()
+    initialize_params.source_image = (
+        "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"
+    )
+    initialize_params.disk_size_gb = 10
+    initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
+
+    disk.initialize_params = initialize_params
+    disk.auto_delete = True
+    disk.boot = True
+    return [disk]
+
+
+def gateway_user_data_script() -> str:
+    return """#!/bin/sh
+sudo apt-get update
+sudo DEBIAN_FRONTEND=noninteractive apt-get install -y -q nginx"""
diff --git a/cli/dstack/_internal/cli/commands/gateway/__init__.py b/cli/dstack/_internal/cli/commands/gateway/__init__.py
new file mode 100644
index 000000000..337d0a7c1
--- /dev/null
+++ b/cli/dstack/_internal/cli/commands/gateway/__init__.py
@@ -0,0 +1,84 @@
+from argparse import Namespace
+from typing import List
+
+from rich.prompt import Confirm
+from rich.table import Table
+from rich_argparse import RichHelpFormatter
+
+from dstack._internal.cli.commands import BasicCommand
+from dstack._internal.cli.utils.common import add_project_argument, check_init, console
+from dstack._internal.cli.utils.config import get_hub_client
+from dstack._internal.core.gateway import GatewayHead
+from dstack.api.hub import HubClient
+
+
+class GatewayCommand(BasicCommand):
+    NAME = "gateway"
+    DESCRIPTION = "Manage gateways"
+
+    def __init__(self, parser):
+        super().__init__(parser)
+
+    def register(self):
+        add_project_argument(self._parser)
+        subparsers = self._parser.add_subparsers(dest="action", required=True)
+
+        list_parser = subparsers.add_parser(
+            "list", help="List gateways", formatter_class=RichHelpFormatter
+        )
+        add_project_argument(list_parser)
+        list_parser.set_defaults(sub_func=self.list_gateways)
+
+        create_parser = subparsers.add_parser(
+            "create", help="Create a gateway", formatter_class=RichHelpFormatter
+        )
+        add_project_argument(create_parser)
+        create_parser.set_defaults(sub_func=self.create_gateway)
+
+        delete_gateway_parser = subparsers.add_parser(
+            "delete", help="Delete a gateway", formatter_class=RichHelpFormatter
+        )
+        add_project_argument(delete_gateway_parser)
+        delete_gateway_parser.add_argument(
+            "-y", "--yes", action="store_true", help="Don't ask for confirmation"
+        )
+        delete_gateway_parser.add_argument(
+            "instance_name", metavar="NAME", type=str, help="The name of the gateway"
+        )
+        delete_gateway_parser.set_defaults(sub_func=self.delete_gateway)
+
+    @check_init
+    def _command(self, args: Namespace):
+        hub_client = get_hub_client(project_name=args.project)
+        args.sub_func(hub_client, args)
+
+    def create_gateway(self, hub_client: HubClient, args: Namespace):
+        print("Creating the gateway. It may take some time...")
+        head = hub_client.create_gateway()
+        print_gateways_table([head])
+
+    def list_gateways(self, hub_client: HubClient, args: Namespace):
+        heads = hub_client.list_gateways()
+        print_gateways_table(heads)
+
+    def delete_gateway(self, hub_client: HubClient, args: Namespace):
+        heads = hub_client.list_gateways()
+        if args.instance_name not in [head.instance_name for head in heads]:
+            exit(f"No such gateway '{args.instance_name}'")
+        if args.yes or Confirm.ask(f"[red]Delete the gateway '{args.instance_name}'?[/]"):
+            hub_client.delete_gateway(args.instance_name)
+            console.print("Gateway deleted")
+            exit(0)
+
+
+def print_gateways_table(heads: List[GatewayHead]):
+    table = Table(box=None)
+    table.add_column("NAME")
+    table.add_column("ADDRESS")
+    for head in heads:
+        table.add_row(
+            head.instance_name,
+            head.external_ip,
+        )
+    console.print(table)
+    console.print()
diff --git a/cli/dstack/_internal/cli/commands/init/__init__.py b/cli/dstack/_internal/cli/commands/init/__init__.py
index 4bcae7195..6d7d56259 100644
--- a/cli/dstack/_internal/cli/commands/init/__init__.py
+++ b/cli/dstack/_internal/cli/commands/init/__init__.py
@@ -13,7 +13,7 @@
 from dstack._internal.cli.utils.config import config, get_hub_client
 from dstack._internal.core.repo import LocalRepo, RemoteRepo
 from dstack._internal.core.userconfig import RepoUserConfig
-from dstack._internal.utils.crypto import generage_rsa_key_pair
+from dstack._internal.utils.crypto import generate_rsa_key_pair
 
 
 class InitCommand(BasicCommand):
@@ -106,5 +106,5 @@ def get_ssh_keypair(
     if dstack_key_path is None:
         return None
     if not dstack_key_path.exists():
-        generage_rsa_key_pair(private_key_path=dstack_key_path)
+        generate_rsa_key_pair(private_key_path=dstack_key_path)
     return str(dstack_key_path)
diff --git a/cli/dstack/_internal/cli/handlers.py b/cli/dstack/_internal/cli/handlers.py
index 5fe495d52..bfade01f9 100644
--- a/cli/dstack/_internal/cli/handlers.py
+++ b/cli/dstack/_internal/cli/handlers.py
@@ -1,6 +1,7 @@
 from dstack._internal.cli.commands.build import BuildCommand
 from dstack._internal.cli.commands.config import ConfigCommand
 from dstack._internal.cli.commands.cp import CpCommand
+from dstack._internal.cli.commands.gateway import GatewayCommand
 from dstack._internal.cli.commands.init import InitCommand
 from dstack._internal.cli.commands.logs import LogCommand
 from dstack._internal.cli.commands.ls import LsCommand
@@ -15,15 +16,16 @@ from dstack._internal.cli.commands.tags import TAGCommand
 
 commands_classes = [
+    BuildCommand,
     ConfigCommand,
     CpCommand,
+    GatewayCommand,
     InitCommand,
     LogCommand,
     LsCommand,
-    BuildCommand,
-    PruneCommand,
     PSCommand,
     RestartCommand,
+    PruneCommand,
     RMCommand,
     RunCommand,
     SecretCommand,
diff --git a/cli/dstack/_internal/cli/utils/configuration.py b/cli/dstack/_internal/cli/utils/configuration.py
index f7e60029f..ca1738332 100644
--- a/cli/dstack/_internal/cli/utils/configuration.py
+++ b/cli/dstack/_internal/cli/utils/configuration.py
@@ -7,9 +7,11 @@
 from dstack._internal.cli.profiles import load_profiles
 from dstack._internal.configurators import JobConfigurator
 from dstack._internal.configurators.dev_environment import DevEnvironmentConfigurator
+from dstack._internal.configurators.service import ServiceConfigurator
 from dstack._internal.configurators.task import TaskConfigurator
 from dstack._internal.core.configuration import (
     DevEnvironmentConfiguration,
+    ServiceConfiguration,
     TaskConfiguration,
     parse,
 )
@@ -39,6 +41,9 @@ def load_configuration(
         )
     elif isinstance(configuration, TaskConfiguration):
         return TaskConfigurator(working_dir, str(configuration_path), configuration, profile)
+    elif isinstance(configuration, ServiceConfiguration):
+        return ServiceConfigurator(working_dir, str(configuration_path), configuration, profile)
+
     exit(f"Unsupported configuration {type(configuration)}")
diff --git a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py
index 6b7e152e1..7751c27ef 100644
--- a/cli/dstack/_internal/configurators/__init__.py
+++ b/cli/dstack/_internal/configurators/__init__.py
@@ -12,7 +12,11 @@
 import dstack._internal.core.job as job
 import dstack.version as version
 from dstack._internal.core.build import BuildPolicy
-from dstack._internal.core.configuration import BaseConfiguration, PythonVersion
+from dstack._internal.core.configuration import (
+    BaseConfiguration,
+    BaseConfigurationWithPorts,
+    PythonVersion,
+)
 from dstack._internal.core.error import DstackError
 from dstack._internal.core.plan import RunPlan
 from dstack._internal.core.profile import Profile, parse_duration, parse_max_duration
@@ -44,10 +48,6 @@ def get_parser(
         if parser is None:
             parser = argparse.ArgumentParser(prog=prog, formatter_class=RichHelpFormatter)
 
-        parser.add_argument(
-            "-p", "--ports", metavar="PORT", type=port_mapping, nargs=argparse.ONE_OR_MORE
-        )
-
         spot_group = parser.add_mutually_exclusive_group()
         spot_group.add_argument(
             "--spot", action="store_const", dest="spot_policy", const=job.SpotPolicy.SPOT
@@ -81,9 +81,6 @@ def get_parser(
         return parser
 
     def apply_args(self, args: argparse.Namespace):
-        if args.ports is not None:
-            self.conf.ports = list(ports.merge_ports(self.conf.ports, args.ports).values())
-
         if args.spot_policy is not None:
             self.profile.spot_policy = args.spot_policy
@@ -133,37 +130,37 @@ def get_jobs(
         self.ssh_key_pub = ssh_key_pub
         created_at = get_milliseconds_since_epoch()
         configured_job = job.Job(
-            job_id=f"{run_name},,0",
-            runner_id=uuid.uuid4().hex,
-            repo_ref=repo.repo_ref,
-            repo_data=repo.repo_data,
-            repo_code_filename=repo_code_filename,
-            run_name=run_name,
-            configuration_type=job.ConfigurationType(self.conf.type),
+            app_specs=self.app_specs(),
+            artifact_specs=self.artifact_specs(),
+            build_commands=self.build_commands(),
+            build_policy=self.build_policy,
+            cache_specs=self.cache_specs(),
+            commands=self.commands(),
             configuration_path=self.configuration_path,
-            status=job.JobStatus.SUBMITTED,
+            configuration_type=job.ConfigurationType(self.conf.type),
             created_at=created_at,
-            submitted_at=created_at,
-            image_name=self.image_name(run_plan),
-            registry_auth=self.registry_auth(),
+            dep_specs=self.dep_specs(),
             entrypoint=self.entrypoint(),
-            build_commands=self.build_commands(),
-            setup=self.setup(),
-            commands=self.commands(),
-            working_dir=self.working_dir,
-            home_dir=self.home_dir(),
             env=self.env(),
-            artifact_specs=self.artifact_specs(),
-            cache_specs=self.cache_specs(),
-            app_specs=self.app_specs(),
-            dep_specs=self.dep_specs(),
-            spot_policy=self.spot_policy(),
-            retry_policy=self.retry_policy(),
+            gateway=self.gateway(),
+            home_dir=self.home_dir(),
+            image_name=self.image_name(run_plan),
+            job_id=f"{run_name},,0",
             max_duration=self.max_duration(),
-            build_policy=self.build_policy,
-            termination_policy=self.termination_policy(),
+            registry_auth=self.registry_auth(),
+            repo_code_filename=repo_code_filename,
+            repo_data=repo.repo_data,
+            repo_ref=repo.repo_ref,
             requirements=self.requirements(),
+            retry_policy=self.retry_policy(),
+            run_name=run_name,
+            runner_id=uuid.uuid4().hex,
+            setup=self.setup(),
+            spot_policy=self.spot_policy(),
             ssh_key_pub=ssh_key_pub,
+            status=job.JobStatus.SUBMITTED,
+            submitted_at=created_at,
+            working_dir=self.working_dir,
         )
         return [configured_job]
@@ -188,7 +185,11 @@ def dep_specs(self) -> List[job.DepSpec]:
         pass
 
     @abstractmethod
-    def default_max_duration(self) -> int:
+    def default_max_duration(self) -> Optional[int]:
+        pass
+
+    @abstractmethod
+    def ports(self) -> Dict[int, ports.PortMapping]:
         pass
 
     @abstractmethod
@@ -244,14 +245,6 @@ def python(self) -> str:
         version_info = sys.version_info
         return PythonVersion(f"{version_info.major}.{version_info.minor}").value
 
-    def ports(self) -> Dict[int, ports.PortMapping]:
-        ports.unique_ports_constraint([pm.container_port for pm in self.conf.ports])
-        ports.unique_ports_constraint(
-            [pm.local_port for pm in self.conf.ports if pm.local_port is not None],
-            error="Mapped port {} is already in use",
-        )
-        return {pm.container_port: pm for pm in self.conf.ports}
-
     def env(self) -> Dict[str, str]:
         return self.conf.env
@@ -289,6 +282,35 @@ def max_duration(self) -> Optional[int]:
             return None
         return self.profile.max_duration
 
+    def gateway(self) -> Optional[job.Gateway]:
+        return None
+
+
+class JobConfiguratorWithPorts(JobConfigurator, ABC):
+    conf: BaseConfigurationWithPorts
+
+    def get_parser(
+        self, prog: Optional[str] = None, parser: Optional[argparse.ArgumentParser] = None
+    ) -> argparse.ArgumentParser:
+        parser = super().get_parser(prog, parser)
+        parser.add_argument(
+            "-p", "--ports", metavar="PORT", type=port_mapping, nargs=argparse.ONE_OR_MORE
+        )
+        return parser
+
+    def apply_args(self, args: argparse.Namespace):
+        super().apply_args(args)
+        if args.ports is not None:
+            self.conf.ports = list(ports.merge_ports(self.conf.ports, args.ports).values())
+
+    def ports(self) -> Dict[int, ports.PortMapping]:
+        ports.unique_ports_constraint([pm.container_port for pm in self.conf.ports])
+        ports.unique_ports_constraint(
+            [pm.local_port for pm in self.conf.ports if pm.local_port is not None],
+            error="Mapped port {} is already in use",
+        )
+        return {pm.container_port: pm for pm in self.conf.ports}
+
 
 def validate_local_path(path: str, home: Optional[str], working_dir: str) -> str:
     if path == "~" or path.startswith("~/"):
diff --git a/cli/dstack/_internal/configurators/dev_environment.py b/cli/dstack/_internal/configurators/dev_environment.py
index 00e89cb9e..ddf2d0a06 100644
--- a/cli/dstack/_internal/configurators/dev_environment.py
+++ b/cli/dstack/_internal/configurators/dev_environment.py
@@ -1,7 +1,7 @@
 from typing import List, Optional
 
 import dstack._internal.core.job as job
-from dstack._internal.configurators import JobConfigurator
+from dstack._internal.configurators import JobConfiguratorWithPorts
 from dstack._internal.configurators.extensions import IDEExtension
 from dstack._internal.configurators.extensions.ssh import SSHd
 from dstack._internal.configurators.extensions.vscode import VSCodeDesktop
@@ -15,7 +15,7 @@
 install_ipykernel = f'(pip install -q --no-cache-dir ipykernel 2> /dev/null) || echo "no pip, ipykernel was not installed"'
 
 
-class DevEnvironmentConfigurator(JobConfigurator):
+class DevEnvironmentConfigurator(JobConfiguratorWithPorts):
     conf: DevEnvironmentConfiguration
     sshd: Optional[SSHd]
     ide: Optional[IDEExtension]
@@ -71,7 +71,7 @@ def commands(self) -> List[str]:
         commands += ["cat"]  # idle
         return commands
 
-    def default_max_duration(self) -> int:
+    def default_max_duration(self) -> Optional[int]:
         return DEFAULT_MAX_DURATION_SECONDS
 
     def termination_policy(self) -> job.TerminationPolicy:
diff --git a/cli/dstack/_internal/configurators/service.py b/cli/dstack/_internal/configurators/service.py
new file mode 100644
index 000000000..f9b7a4704
--- /dev/null
+++ b/cli/dstack/_internal/configurators/service.py
@@ -0,0 +1,38 @@
+from typing import Dict, List, Optional
+
+import dstack._internal.configurators.ports as ports
+import dstack._internal.core.job as job
+from dstack._internal.configurators import JobConfigurator
+from dstack._internal.core.configuration import ServiceConfiguration
+
+
+class ServiceConfigurator(JobConfigurator):
+    conf: ServiceConfiguration
+
+    def commands(self) -> List[str]:
+        return self.conf.commands
+
+    def artifact_specs(self) -> List[job.ArtifactSpec]:
+        return []  # not implemented
+
+    def dep_specs(self) -> List[job.DepSpec]:
+        return []  # not implemented
+
+    def default_max_duration(self) -> Optional[int]:
+        return None  # infinite
+
+    def ports(self) -> Dict[int, ports.PortMapping]:
+        port = self.conf.gateway.service_port
+        return {port: ports.PortMapping(container_port=port)}
+
+    def gateway(self) -> Optional[job.Gateway]:
+        return job.Gateway.parse_obj(self.conf.gateway)
+
+    def build_commands(self) -> List[str]:
+        return self.conf.build
+
+    def setup(self) -> List[str]:
+        return self.conf.setup
+
+    def termination_policy(self) -> job.TerminationPolicy:
+        return self.profile.termination_policy or job.TerminationPolicy.TERMINATE
diff --git a/cli/dstack/_internal/configurators/task.py b/cli/dstack/_internal/configurators/task.py
index a9da79cd5..9ae03cab4 100644
--- a/cli/dstack/_internal/configurators/task.py
+++ b/cli/dstack/_internal/configurators/task.py
@@ -1,6 +1,6 @@
 from typing import List, Optional
 
-from dstack._internal.configurators import JobConfigurator, validate_local_path
+from dstack._internal.configurators import JobConfiguratorWithPorts, validate_local_path
 from dstack._internal.configurators.extensions.ssh import SSHd
 from dstack._internal.configurators.ports import get_map_to_port
 from dstack._internal.core import job as job
@@ -11,7 +11,7 @@
 DEFAULT_MAX_DURATION_SECONDS = 72 * 3600
 
 
-class TaskConfigurator(JobConfigurator):
+class TaskConfigurator(JobConfiguratorWithPorts):
     conf: TaskConfiguration
     sshd: Optional[SSHd]
 
@@ -27,9 +27,6 @@ def get_jobs(
         self.sshd.map_to_port = get_map_to_port(self.ports(), self.sshd.port)
         return super().get_jobs(repo, run_name, repo_code_filename, ssh_key_pub, run_plan)
 
-    def optional_build_commands(self) -> List[str]:
-        return []  # not needed
-
     def build_commands(self) -> List[str]:
         return self.conf.build
 
@@ -48,7 +45,7 @@ def commands(self) -> List[str]:
         commands += self.conf.commands
         return commands
 
-    def default_max_duration(self) -> int:
+    def default_max_duration(self) -> Optional[int]:
         return DEFAULT_MAX_DURATION_SECONDS
 
     def termination_policy(self) -> job.TerminationPolicy:
diff --git a/cli/dstack/_internal/core/configuration.py b/cli/dstack/_internal/core/configuration.py
index ecf116c11..b5973d154 100644
--- a/cli/dstack/_internal/core/configuration.py
+++ b/cli/dstack/_internal/core/configuration.py
@@ -64,6 +64,14 @@ class Artifact(ForbidExtra):
     ] = False
 
 
+class Gateway(ForbidExtra):
address or domain name")] + public_port: Annotated[ + ValidPort, Field(description="The port that the gateway listens to") + ] = 80 + service_port: Annotated[ValidPort, Field(description="The port that the service listens to")] + + class BaseConfiguration(ForbidExtra): type: Literal["none"] image: Annotated[Optional[str], Field(description="The name of the Docker image to run")] @@ -78,10 +86,6 @@ class BaseConfiguration(ForbidExtra): Optional[PythonVersion], Field(description="The major version of Python\nMutually exclusive with the image"), ] - ports: Annotated[ - List[Union[constr(regex=r"^(?:([0-9]+|\*):)?[0-9]+$"), ValidPort, PortMapping]], - Field(description="Port numbers/mapping to expose"), - ] = [] env: Annotated[ Union[List[constr(regex=r"^[a-zA-Z_][a-zA-Z0-9_]*=.*$")], Dict[str, str]], Field(description="The mapping or the list of environment variables"), @@ -106,6 +110,19 @@ def convert_python(cls, v, values) -> Optional[PythonVersion]: return PythonVersion(v) return v + @validator("env") + def convert_env(cls, v) -> Dict[str, str]: + if isinstance(v, list): + return dict(pair.split(sep="=", maxsplit=1) for pair in v) + return v + + +class BaseConfigurationWithPorts(BaseConfiguration): + ports: Annotated[ + List[Union[constr(regex=r"^(?:([0-9]+|\*):)?[0-9]+$"), ValidPort, PortMapping]], + Field(description="Port numbers/mapping to expose"), + ] = [] + @validator("ports", each_item=True) def convert_ports(cls, v) -> PortMapping: if isinstance(v, int): @@ -114,28 +131,29 @@ def convert_ports(cls, v) -> PortMapping: return PortMapping.parse(v) return v - @validator("env") - def convert_env(cls, v) -> Dict[str, str]: - if isinstance(v, list): - return dict(pair.split(sep="=", maxsplit=1) for pair in v) - return v - -class DevEnvironmentConfiguration(BaseConfiguration): +class DevEnvironmentConfiguration(BaseConfigurationWithPorts): type: Literal["dev-environment"] = "dev-environment" ide: Annotated[Literal["vscode"], Field(description="The IDE to run")] init: Annotated[CommandsList, Field(description="The bash commands to run")] = [] -class TaskConfiguration(BaseConfiguration): +class TaskConfiguration(BaseConfigurationWithPorts): type: Literal["task"] = "task" commands: Annotated[CommandsList, Field(description="The bash commands to run")] artifacts: Annotated[List[Artifact], Field(description="The list of output artifacts")] = [] +class ServiceConfiguration(BaseConfiguration): + type: Literal["service"] = "service" + commands: Annotated[CommandsList, Field(description="The bash commands to run")] + gateway: Annotated[Gateway, Field(description="The gateway to publish the service")] + + class DstackConfiguration(BaseModel): __root__: Annotated[ - Union[DevEnvironmentConfiguration, TaskConfiguration], Field(discriminator="type") + Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration], + Field(discriminator="type"), ] class Config: diff --git a/cli/dstack/_internal/core/gateway.py b/cli/dstack/_internal/core/gateway.py new file mode 100644 index 000000000..c2edde599 --- /dev/null +++ b/cli/dstack/_internal/core/gateway.py @@ -0,0 +1,16 @@ +import time + +from pydantic import Field + +from dstack._internal.core.head import BaseHead + + +class GatewayHead(BaseHead): + instance_name: str + external_ip: str + internal_ip: str + created_at: int = Field(default_factory=lambda: int(time.time() * 1000)) + + @classmethod + def prefix(cls) -> str: + return "gateways/l;" diff --git a/cli/dstack/_internal/core/head.py b/cli/dstack/_internal/core/head.py new file mode 
new file mode 100644
index 000000000..02c396067
--- /dev/null
+++ b/cli/dstack/_internal/core/head.py
@@ -0,0 +1,28 @@
+from abc import ABC, abstractmethod
+from typing import Type, TypeVar
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound="BaseHead")
+
+
+class BaseHead(BaseModel, ABC):
+    @classmethod
+    @abstractmethod
+    def prefix(cls) -> str:
+        pass
+
+    def encode(self) -> str:
+        tokens = []
+        data = self.dict(exclude_none=True)
+        for key in self.__fields__.keys():
+            # replace missing values with an empty token
+            tokens.append(str(data.get(key, "")))
+        return self.prefix() + ";".join(tokens)
+
+    @classmethod
+    def decode(cls: Type[T], key: str) -> T:
+        # maxsplit lets the last field absorb any extra ";"-separated values
+        values = key[len(cls.prefix()) :].split(";", maxsplit=len(cls.__fields__) - 1)
+        # dicts in Python 3 are ordered, so map values to field names positionally
+        return cls.parse_obj(dict(zip(cls.__fields__.keys(), values)))
diff --git a/cli/dstack/_internal/core/job.py b/cli/dstack/_internal/core/job.py
index 2224ffcad..35daaa58b 100644
--- a/cli/dstack/_internal/core/job.py
+++ b/cli/dstack/_internal/core/job.py
@@ -1,8 +1,10 @@
+import json
 from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field, root_validator
+from typing_extensions import Annotated
 
 from dstack._internal.core.app import AppSpec
 from dstack._internal.core.artifact import ArtifactSpec
@@ -20,6 +22,13 @@
 )
 
 
+class Gateway(BaseModel):
+    hostname: str
+    ssh_key: Optional[str]
+    service_port: int
+    public_port: int = 80
+
+
 class GpusRequirements(BaseModel):
     count: Optional[int] = None
     memory_mib: Optional[int] = None
@@ -34,26 +43,6 @@ class Requirements(BaseModel):
     spot: Optional[bool] = None
     local: Optional[bool] = None
 
-    def serialize(self) -> Dict[str, Any]:
-        req_data = {}
-        if self.cpus:
-            req_data["cpus"] = self.cpus
-        if self.memory_mib:
-            req_data["memory_mib"] = self.memory_mib
-        if self.gpus:
-            req_data["gpus"] = {"count": self.gpus.count}
-            if self.gpus.memory_mib:
-                req_data["gpus"]["memory_mib"] = self.gpus.memory_mib
-            if self.gpus.name:
-                req_data["gpus"]["name"] = self.gpus.name
-        if self.shm_size_mib:
-            req_data["shm_size_mib"] = self.shm_size_mib
-        if self.spot:
-            req_data["spot"] = self.spot
-        if self.local:
-            req_data["local"] = self.local
-        return req_data
-
     def pretty_format(self):
         res = ""
         res += f"{self.cpus}xCPUs"
@@ -86,6 +75,7 @@ def set_id(self, job_id: Optional[str]):
 class ConfigurationType(str, Enum):
     DEV_ENVIRONMENT = "dev-environment"
     TASK = "task"
+    SERVICE = "service"
 
 
 class JobStatus(str, Enum):
@@ -148,8 +138,8 @@ class JobHead(JobRef):
     repo_ref: RepoRef
     hub_user_name: str = ""
     run_name: str
-    workflow_name: Optional[str]
-    provider_name: str
+    workflow_name: Optional[str] = ""  # deprecated
+    provider_name: Optional[str] = ""  # deprecated
     configuration_path: Optional[str]
     status: JobStatus
     error_code: Optional[JobErrorCode]
@@ -172,63 +162,57 @@ class RegistryAuth(BaseModel):
     username: Optional[str] = None
     password: Optional[str] = None
 
-    def serialize(self) -> Dict[str, Any]:
-        return self.dict(exclude_none=True)
-
-
-def check_dict(element: Any, field: str):
-    if type(element) == dict:
-        return element.get(field)
-    if hasattr(element, field):
-        return getattr(element, field)
-    return None
-
 
 class Job(JobHead):
-    job_id: Optional[str]
-    repo_data: Union[RepoData, RemoteRepoData, LocalRepoData] = Field(
-        ..., discriminator="repo_type"
-    )
-    repo_code_filename: Optional[str] = None
-    run_name: str
-    workflow_name: Optional[str]  # deprecated
-    provider_name: Optional[str]  # deprecated
-    configuration_type: Optional[ConfigurationType]
+    app_names: Optional[List[str]]
+    app_specs: Optional[List[AppSpec]]
+    artifact_paths: Optional[List[str]]
+    artifact_specs: Optional[List[ArtifactSpec]]
+    build_commands: Optional[List[str]]
+    build_policy: BuildPolicy = BuildPolicy.USE_BUILD
+    cache_specs: List[CacheSpec]
+    commands: Optional[List[str]]
     configuration_path: Optional[str]
-    status: JobStatus
-    error_code: Optional[JobErrorCode]
+    configuration_type: Optional[ConfigurationType]
     container_exit_code: Optional[int]
     created_at: int
-    submitted_at: int
-    submission_num: int = 1
-    image_name: str
-    registry_auth: Optional[RegistryAuth]
-    setup: Optional[List[str]]
-    commands: Optional[List[str]]
+    dep_specs: Optional[List[DepSpec]]
     entrypoint: Optional[List[str]]
     env: Optional[Dict[str, str]]
+    error_code: Optional[JobErrorCode]
+    gateway: Optional[Gateway]
     home_dir: Optional[str]
-    working_dir: Optional[str]
-    artifact_specs: Optional[List[ArtifactSpec]]
-    cache_specs: List[CacheSpec]
     host_name: Optional[str]
+    hub_user_name: str = ""
+    image_name: str
+    instance_spot_type: Optional[str]
+    instance_type: Optional[str]
+    job_id: str
+    location: Optional[str]
+    master_job: Optional[str]  # not implemented
+    max_duration: Optional[int]
+    provider_name: Optional[str] = ""  # deprecated
+    registry_auth: Optional[RegistryAuth]
+    repo_code_filename: Optional[str]
+    repo_data: Annotated[
+        Union[RepoData, RemoteRepoData, LocalRepoData], Field(discriminator="repo_type")
+    ]
+    repo_ref: RepoRef
+    request_id: Optional[str]
     requirements: Optional[Requirements]
-    spot_policy: Optional[SpotPolicy]
     retry_policy: Optional[RetryPolicy]
-    termination_policy: Optional[TerminationPolicy]
-    max_duration: Optional[int]
-    dep_specs: Optional[List[DepSpec]]
-    master_job: Optional[JobRef]
-    app_specs: Optional[List[AppSpec]]
+    run_name: str
     runner_id: Optional[str]
-    request_id: Optional[str]
-    location: Optional[str]
-    tag_name: Optional[str]
+    setup: Optional[List[str]]
+    spot_policy: Optional[SpotPolicy]
     ssh_key_pub: Optional[str]
-    build_policy: BuildPolicy = BuildPolicy.USE_BUILD
-    build_commands: Optional[List[str]]
-    optional_build_commands: Optional[List[str]]
-    run_env: Optional[Dict[str, str]]  # deprecated
+    status: JobStatus
+    submission_num: int = 1
+    submitted_at: int
+    tag_name: Optional[str]
+    termination_policy: Optional[TerminationPolicy]
+    workflow_name: Optional[str] = ""  # deprecated
+    working_dir: Optional[str]
 
     @root_validator(pre=True)
     def preprocess_data(cls, data):
@@ -251,230 +235,12 @@ def get_instance_spot_type(self) -> str:
         return "on-demand"
 
     def serialize(self) -> dict:
-        deps = []
-        if self.dep_specs:
-            for dep in self.dep_specs:
-                deps.append(
-                    {
-                        "repo_id": dep.repo_ref.repo_id,
-                        "hub_user_name": self.hub_user_name,
-                        "run_name": dep.run_name,
-                        "mount": dep.mount,
-                    }
-                )
-        artifacts = []
-        if self.artifact_specs:
-            for artifact_spec in self.artifact_specs:
-                artifacts.append(
-                    {"path": artifact_spec.artifact_path, "mount": artifact_spec.mount}
-                )
-        job_data = {
-            "job_id": self.job_id,
-            "repo_id": self.repo.repo_id,
-            "hub_user_name": self.hub_user_name,
-            "repo_type": self.repo.repo_data.repo_type,
-            "run_name": self.run_name,
-            "workflow_name": self.workflow_name or "",
-            "provider_name": self.provider_name,
-            "configuration_type": self.configuration_type.value
-            if self.configuration_type
-            else None,
-            "configuration_path": self.configuration_path,
-            "status": self.status.value,
-            "error_code": self.error_code.value if self.error_code is not None else "",
-            "container_exit_code": self.container_exit_code or "",
-            "created_at": self.created_at,
-            "submitted_at": self.submitted_at,
-            "submission_num": self.submission_num,
-            "image_name": self.image_name,
-            "registry_auth": self.registry_auth.serialize() if self.registry_auth else {},
-            "setup": self.setup or [],
-            "commands": self.commands or [],
-            "entrypoint": self.entrypoint,
-            "env": self.env or {},
-            "home_dir": self.home_dir or "",
-            "working_dir": self.working_dir or "",
-            "artifacts": artifacts,
-            "cache": [item.dict() for item in self.cache_specs],
-            "host_name": self.host_name or "",
-            "spot_policy": self.spot_policy.value if self.spot_policy else None,
-            "retry_policy": self.retry_policy.dict() if self.retry_policy else None,
-            "termination_policy": self.termination_policy.value
-            if self.termination_policy
-            else None,
-            "max_duration": self.max_duration or None,
-            "requirements": self.requirements.serialize() if self.requirements else {},
-            "deps": deps,
-            "master_job_id": self.master_job.get_id() if self.master_job else "",
-            "apps": [
-                {
-                    "port": a.port,
-                    "map_to_port": a.map_to_port,
-                    "app_name": a.app_name,
-                    "url_path": a.url_path or "",
-                    "url_query_params": a.url_query_params or {},
-                }
-                for a in self.app_specs
-            ]
-            if self.app_specs
-            else [],
-            "runner_id": self.runner_id or "",
-            "request_id": self.request_id or "",
-            "location": self.location or "",
-            "tag_name": self.tag_name or "",
-            "ssh_key_pub": self.ssh_key_pub or "",
-            "repo_code_filename": self.repo_code_filename,
-            "instance_type": self.instance_type,
-            "build_policy": self.build_policy.value,
-            "build_commands": self.build_commands or [],
-            "optional_build_commands": self.optional_build_commands or [],
-            "run_env": self.run_env or {},
-        }
-        if isinstance(self.repo_data, RemoteRepoData):
-            job_data["repo_host_name"] = self.repo_data.repo_host_name
-            job_data["repo_port"] = self.repo_data.repo_port or 0
-            job_data["repo_user_name"] = self.repo_data.repo_user_name
-            job_data["repo_name"] = self.repo_data.repo_name
-            job_data["repo_branch"] = self.repo_data.repo_branch or ""
-            job_data["repo_hash"] = self.repo_data.repo_hash or ""
-            job_data["repo_config_name"] = self.repo_data.repo_config_name or ""
-            job_data["repo_config_email"] = self.repo_data.repo_config_email or ""
-        return job_data
+        # hack to convert enum to string
+        return json.loads(self.json(exclude_none=True))
 
     @staticmethod
-    def unserialize(job_data: dict):
-        _requirements = job_data.get("requirements")
-        requirements = (
-            Requirements(
-                cpus=_requirements.get("cpus") or None,
-                memory_mib=_requirements.get("memory_mib") or None,
-                gpus=GpusRequirements(
-                    count=_requirements["gpus"].get("count") or None,
-                    memory_mib=_requirements["gpus"].get("memory") or None,
-                    name=_requirements["gpus"].get("name") or None,
-                )
-                if _requirements.get("gpus")
-                else None,
-                shm_size_mib=_requirements.get("shm_size_mib") or None,
-                spot=_requirements.get("spot") or _requirements.get("interruptible"),
-                local=_requirements.get("local") or None,
-            )
-            if _requirements
-            else Requirements()
-        )
-        spot_policy = job_data.get("spot_policy")
-        retry_policy = None
-        if job_data.get("retry_policy") is not None:
-            retry_policy = RetryPolicy.parse_obj(job_data.get("retry_policy"))
-        termination_policy = job_data.get("termination_policy")
-        dep_specs = []
-        if job_data.get("deps"):
-            for dep in job_data["deps"]:
-                dep_spec = DepSpec(
-                    repo_ref=RepoRef(repo_id=dep["repo_id"]),
-                    run_name=dep["run_name"],
-                    mount=dep.get("mount") is True,
-                )
-                dep_specs.append(dep_spec)
-        artifact_specs = []
-        if job_data.get("artifacts"):
-            for artifact in job_data["artifacts"]:
-                if isinstance(artifact, str):
-                    artifact_spec = ArtifactSpec(artifact_path=artifact, mount=False)
-                else:
-                    artifact_spec = ArtifactSpec(
-                        artifact_path=artifact["path"], mount=artifact.get("mount") is True
-                    )
-                artifact_specs.append(artifact_spec)
-        master_job = (
-            JobRefId(job_id=job_data["master_job_id"]) if job_data.get("master_job_id") else None
-        )
-        app_specs = (
-            [
-                AppSpec(
-                    port=a.get("port", 0),
-                    map_to_port=a.get("map_to_port"),
-                    app_name=a["app_name"],
-                    url_path=a.get("url_path") or None,
-                    url_query_params=a.get("url_query_params") or None,
-                )
-                for a in (job_data.get("apps") or [])
-            ]
-        ) or None
-        error_code = job_data.get("error_code")
-        container_exit_code = job_data.get("container_exit_code")
-        configuration_type = job_data.get("configuration_type")
-
-        if job_data["repo_type"] == "remote":
-            repo_data = RemoteRepoData(
-                repo_host_name=job_data["repo_host_name"],
-                repo_port=job_data.get("repo_port") or None,
-                repo_user_name=job_data["repo_user_name"],
-                repo_name=job_data["repo_name"],
-                repo_branch=job_data.get("repo_branch") or None,
-                repo_hash=job_data.get("repo_hash") or None,
-                repo_config_name=job_data.get("repo_config_name") or None,
-                repo_config_email=job_data.get("repo_config_email") or None,
-            )
-        elif job_data["repo_type"] == "local":
-            repo_data = LocalRepoData(repo_dir=job_data.get("repo_dir", ""))
-        else:
-            raise TypeError(f"Unknown repo_type: {job_data['repo_type']}")
-
-        job = Job(
-            job_id=job_data["job_id"],
-            repo_ref=RepoRef(repo_id=job_data["repo_id"]),
-            hub_user_name=job_data["hub_user_name"],
-            repo_data=repo_data,
-            repo_code_filename=job_data.get("repo_code_filename"),
-            run_name=job_data["run_name"],
-            workflow_name=job_data.get("workflow_name") or None,
-            provider_name=job_data["provider_name"],
-            configuration_type=ConfigurationType(configuration_type)
-            if configuration_type
-            else None,
-            configuration_path=job_data.get("configuration_path"),
-            status=JobStatus(job_data["status"]),
-            error_code=JobErrorCode(error_code) if error_code else None,
-            container_exit_code=int(container_exit_code) if container_exit_code else None,
-            created_at=job_data.get("created_at") or job_data["submitted_at"],
-            submitted_at=job_data["submitted_at"],
-            submission_num=job_data.get("submission_num") or 1,
-            image_name=job_data["image_name"],
-            registry_auth=RegistryAuth(**job_data.get("registry_auth", {})),
-            setup=job_data.get("setup"),
-            commands=job_data.get("commands") or None,
-            entrypoint=job_data.get("entrypoint") or None,
-            env=job_data["env"] or None,
-            home_dir=job_data.get("home_dir") or None,
-            working_dir=job_data.get("working_dir") or None,
-            artifact_specs=artifact_specs,
-            cache_specs=[CacheSpec(**item) for item in job_data.get("cache", [])],
-            host_name=job_data.get("host_name") or None,
-            spot_policy=SpotPolicy(spot_policy) if spot_policy else None,
-            retry_policy=retry_policy,
-            termination_policy=TerminationPolicy(termination_policy)
-            if termination_policy
-            else None,
-            max_duration=int(job_data.get("max_duration"))
-            if job_data.get("max_duration")
-            else None,
-            requirements=requirements,
-            dep_specs=dep_specs or None,
-            master_job=master_job,
-            app_specs=app_specs,
-            runner_id=job_data.get("runner_id") or None,
-            request_id=job_data.get("request_id") or None,
-            location=job_data.get("location") or None,
-            tag_name=job_data.get("tag_name") or None,
-            ssh_key_pub=job_data.get("ssh_key_pub") or None,
-            instance_type=job_data.get("instance_type") or None,
-            build_policy=job_data.get("build_policy") or BuildPolicy.USE_BUILD,
-            build_commands=job_data.get("build_commands") or None,
-            optional_build_commands=job_data.get("optional_build_commands") or None,
-            run_env=job_data.get("run_env") or None,
-        )
-        return job
+    def unserialize(job_data: dict) -> "Job":
+        return Job.parse_obj(job_data)
 
     @property
     def repo(self) -> Repo:
@@ -484,23 +250,9 @@ def repo(self) -> Repo:
         return LocalRepo(repo_ref=self.repo_ref, repo_data=self.repo_data)
 
 
-class JobSpec(JobRef):
-    image_name: str
-    job_id: Optional[str] = None
-    registry_auth: Optional[RegistryAuth] = None
-    commands: Optional[List[str]] = None
-    entrypoint: Optional[List[str]] = None
-    env: Optional[Dict[str, str]] = None
-    run_env: Optional[Dict[str, str]] = None
-    working_dir: Optional[str] = None
-    artifact_specs: Optional[List[ArtifactSpec]] = None
-    requirements: Optional[Requirements] = None
-    master_job: Optional[JobRef] = None
-    app_specs: Optional[List[AppSpec]] = None
-    build_commands: Optional[List[str]] = None
-
-    def get_id(self) -> Optional[str]:
-        return self.job_id
-
-    def set_id(self, job_id: Optional[str]):
-        self.job_id = job_id
+def check_dict(element: Any, field: str):
+    if type(element) == dict:
+        return element.get(field)
+    if hasattr(element, field):
+        return getattr(element, field)
+    return None
diff --git a/cli/dstack/_internal/hub/main.py b/cli/dstack/_internal/hub/main.py
index 2f1ad439b..58a24e7c6 100644
--- a/cli/dstack/_internal/hub/main.py
+++ b/cli/dstack/_internal/hub/main.py
@@ -17,6 +17,7 @@
     artifacts,
     backends,
     configurations,
+    gateways,
     jobs,
     link,
     logs,
@@ -54,6 +55,7 @@
 app.include_router(repos.router)
 app.include_router(link.router)
 app.include_router(configurations.router)
+app.include_router(gateways.router)
 
 DEFAULT_PROJECT_NAME = "local"
diff --git a/cli/dstack/_internal/hub/routers/gateways.py b/cli/dstack/_internal/hub/routers/gateways.py
new file mode 100644
index 000000000..e42aa705c
--- /dev/null
+++ b/cli/dstack/_internal/hub/routers/gateways.py
@@ -0,0 +1,42 @@
+from typing import List
+
+from fastapi import APIRouter, Body, Depends, HTTPException, status
+
+from dstack._internal.core.gateway import GatewayHead
+from dstack._internal.hub.routers.util import error_detail, get_backend, get_project
+from dstack._internal.hub.security.permissions import ProjectAdmin, ProjectMember
+from dstack._internal.hub.utils.common import run_async
+from dstack._internal.hub.utils.ssh import get_hub_ssh_public_key
+
+router = APIRouter(
+    prefix="/api/project", tags=["gateways"], dependencies=[Depends(ProjectMember())]
+)
+
+
+@router.post("/{project_name}/gateways/create", dependencies=[Depends(ProjectAdmin())])
+async def gateways_create(project_name: str) -> GatewayHead:
+    project = await get_project(project_name=project_name)
+    backend = await get_backend(project)
+    try:
+        return await run_async(backend.create_gateway, get_hub_ssh_public_key())
+    except NotImplementedError:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=error_detail(
+                msg=f"Can't create gateway for {backend.name} backend", code="not_implemented"
+            ),
+        )
+
+
+@router.get("/{project_name}/gateways")
+async def gateways_list(project_name: str) -> List[GatewayHead]:
+    project = await get_project(project_name=project_name)
+    backend = await get_backend(project)
+    return backend.list_gateways()
+
+
+@router.post("/{project_name}/gateways/delete", dependencies=[Depends(ProjectAdmin())])
+async def gateways_delete(project_name: str, instance_name: str = Body()):
+    project = await get_project(project_name=project_name)
+    backend = await get_backend(project)
+    await run_async(backend.delete_gateway, instance_name)
diff --git a/cli/dstack/_internal/hub/utils/ssh.py b/cli/dstack/_internal/hub/utils/ssh.py
index f9d39b2ed..1a82e1c3b 100644
--- a/cli/dstack/_internal/hub/utils/ssh.py
+++ b/cli/dstack/_internal/hub/utils/ssh.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from dstack._internal.utils.crypto import generage_rsa_key_pair
+from dstack._internal.utils.crypto import generate_rsa_key_pair
 
 HUB_PRIVATE_KEY_PATH = Path.home() / ".dstack" / "hub" / "ssh" / "hub_ssh_key"
 HUB_PUBLIC_KEY_PATH = Path.home() / ".dstack" / "hub" / "ssh" / "hub_ssh_key.pub"
@@ -10,7 +10,7 @@ def generate_hub_ssh_key_pair():
     if HUB_PRIVATE_KEY_PATH.exists():
         return
     HUB_PRIVATE_KEY_PATH.parent.mkdir(parents=True, exist_ok=True)
-    generage_rsa_key_pair(
+    generate_rsa_key_pair(
         private_key_path=HUB_PRIVATE_KEY_PATH, public_key_path=HUB_PUBLIC_KEY_PATH
     )
diff --git a/cli/dstack/_internal/utils/crypto.py b/cli/dstack/_internal/utils/crypto.py
index fce579abd..538a71eb6 100644
--- a/cli/dstack/_internal/utils/crypto.py
+++ b/cli/dstack/_internal/utils/crypto.py
@@ -1,36 +1,40 @@
 import os
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
 
 from cryptography.hazmat.backends import default_backend as crypto_default_backend
 from cryptography.hazmat.primitives import serialization as crypto_serialization
 from cryptography.hazmat.primitives.asymmetric import rsa
 
 
-def generage_rsa_key_pair(private_key_path: Path, public_key_path: Optional[Path] = None):
+def generate_rsa_key_pair(private_key_path: Path, public_key_path: Optional[Path] = None):
     if public_key_path is None:
         public_key_path = private_key_path.with_suffix(private_key_path.suffix + ".pub")
 
-    key = rsa.generate_private_key(
-        backend=crypto_default_backend(), public_exponent=65537, key_size=2048
-    )
+    private_bytes, public_bytes = generate_rsa_key_pair_bytes()
 
     def key_opener(path, flags):
         return os.open(path, flags, 0o600)
 
     with open(private_key_path, "wb", opener=key_opener) as f:
-        f.write(
-            key.private_bytes(
-                crypto_serialization.Encoding.PEM,
-                crypto_serialization.PrivateFormat.PKCS8,
-                crypto_serialization.NoEncryption(),
-            )
-        )
+        f.write(private_bytes)
     with open(public_key_path, "wb", opener=key_opener) as f:
-        f.write(
-            key.public_key().public_bytes(
-                crypto_serialization.Encoding.OpenSSH,
-                crypto_serialization.PublicFormat.OpenSSH,
-            )
-        )
-        f.write(b" dstack\n")
+        f.write(public_bytes)
+
+
+def generate_rsa_key_pair_bytes(comment: str = "dstack") -> Tuple[bytes, bytes]:
+    key = rsa.generate_private_key(
+        backend=crypto_default_backend(), public_exponent=65537, key_size=2048
+    )
+    private_bytes = key.private_bytes(
+        crypto_serialization.Encoding.PEM,
+        crypto_serialization.PrivateFormat.PKCS8,
+        crypto_serialization.NoEncryption(),
+    )
+    public_bytes = key.public_key().public_bytes(
+        crypto_serialization.Encoding.OpenSSH,
+        crypto_serialization.PublicFormat.OpenSSH,
+    )
+    public_bytes += f" {comment}\n".encode()
+
+    return private_bytes, public_bytes
diff --git a/cli/dstack/api/hub/_api_client.py b/cli/dstack/api/hub/_api_client.py
index 4a5b591a5..3aa9212d5 100644
--- a/cli/dstack/api/hub/_api_client.py
+++ b/cli/dstack/api/hub/_api_client.py
@@ -1,8 +1,10 @@
+import json
 from datetime import datetime
 from typing import Dict, Generator, List, Optional
 from urllib.parse import urlencode, urlparse, urlunparse
 
 import requests
+from pydantic import parse_obj_as
 
 from dstack._internal.core.artifact import Artifact
 from dstack._internal.core.build import BuildNotFoundError
@@ -11,6 +13,7 @@
     BackendValueError,
     NoMatchingInstanceError,
 )
+from dstack._internal.core.gateway import GatewayHead
 from dstack._internal.core.job import Job, JobHead
 from dstack._internal.core.log_event import LogEvent
 from dstack._internal.core.plan import RunPlan
@@ -685,6 +688,45 @@ def delete_configuration_cache(self, configuration_path: str):
             return
         resp.raise_for_status()
 
+    def create_gateway(self) -> GatewayHead:
+        url = _project_url(url=self.url, project=self.project, additional_path="/gateways/create")
+        resp = _make_hub_request(
+            requests.post,
+            host=self.url,
+            url=url,
+            headers=self._headers(),
+        )
+        if resp.ok:
+            return GatewayHead.parse_obj(resp.json())
+        if resp.status_code == 400:
+            body = resp.json()
+            if body["detail"]["code"] == "not_implemented":
+                raise HubClientError(body["detail"]["msg"])
+        resp.raise_for_status()
+
+    def list_gateways(self) -> List[GatewayHead]:
+        url = _project_url(url=self.url, project=self.project, additional_path="/gateways")
+        resp = _make_hub_request(
+            requests.get,
+            host=self.url,
+            url=url,
+            headers=self._headers(),
+        )
+        if not resp.ok:
+            resp.raise_for_status()
+        return parse_obj_as(List[GatewayHead], resp.json())
+
+    def delete_gateway(self, instance_name: str):
+        url = _project_url(url=self.url, project=self.project, additional_path="/gateways/delete")
+        resp = _make_hub_request(
+            requests.post,
+            host=self.url,
+            url=url,
+            headers=self._headers(),
+            data=json.dumps(instance_name),
+        )
+        resp.raise_for_status()
+
 
 def _project_url(url: str, project: str, additional_path: str, query: Optional[dict] = None):
     query = {} if query is None else query
diff --git a/cli/dstack/api/hub/_client.py b/cli/dstack/api/hub/_client.py
index 33df23687..d82600015 100644
--- a/cli/dstack/api/hub/_client.py
+++ b/cli/dstack/api/hub/_client.py
@@ -11,6 +11,7 @@
 from dstack._internal.api.repos import get_local_repo_credentials
 from dstack._internal.backend.base import artifacts as base_artifacts
 from dstack._internal.core.artifact import Artifact
+from dstack._internal.core.gateway import GatewayHead
 from dstack._internal.core.job import Job, JobHead, JobStatus
 from dstack._internal.core.log_event import LogEvent
 from dstack._internal.core.plan import RunPlan
@@ -309,3 +310,12 @@ def _upload_code_file(self) -> str:
             f.seek(0)
             self._storage.upload_file(f.name, repo_code_filename, lambda _: ...)
        return repo_code_filename
+
+    def create_gateway(self) -> GatewayHead:
+        return self._api_client.create_gateway()
+
+    def list_gateways(self) -> List[GatewayHead]:
+        return self._api_client.list_gateways()
+
+    def delete_gateway(self, instance_name: str):
+        self._api_client.delete_gateway(instance_name)
diff --git a/cli/tests/core/__init__.py b/cli/tests/core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/cli/tests/core/test_head.py b/cli/tests/core/test_head.py
new file mode 100644
index 000000000..2c574bf24
--- /dev/null
+++ b/cli/tests/core/test_head.py
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from dstack._internal.core.head import BaseHead
+
+
+class TestHead(BaseHead):
+    id: int
+    a: str
+    b: Optional[str]
+    c: str
+
+    @classmethod
+    def prefix(cls) -> str:
+        return "test/l;"
+
+
+def test_prefix():
+    assert TestHead.prefix() == "test/l;"
+
+
+def test_decode():
+    h = TestHead.decode("test/l;123;var;;a;b;c;d")
+    assert h.id == 123
+    assert h.a == "var"
+    assert h.b == ""
+    assert h.c == "a;b;c;d"
+
+
+def test_encode():
+    assert TestHead(id=123, a="var", c="a;b;c;d").encode() == "test/l;123;var;;a;b;c;d"
diff --git a/docs/docs/reference/cli/gateway.md b/docs/docs/reference/cli/gateway.md
new file mode 100644
index 000000000..c4b72ab7e
--- /dev/null
+++ b/docs/docs/reference/cli/gateway.md
@@ -0,0 +1,80 @@
+# dstack gateway
+
+A gateway makes running jobs (`type: service`) accessible from the public internet.
+
+!!! info "NOTE:"
+    Multiple domains can be attached to the same gateway, and multiple jobs can share the same gateway.
+
+## dstack gateway list
+
+The `dstack gateway list` command displays the names and addresses of the gateways configured in the selected project.
+
+### Usage
+
+<div class="termy">
+ +```shell +$ dstack gateway list --project gcp +``` + +
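For reference, the same listing is exposed programmatically by this change. Below is a minimal sketch over raw HTTP; the base URL, project name, and token are placeholders, and the URL layout is only inferred from `_project_url()`, while the `/gateways` path, `parse_obj_as`, and `GatewayHead` come from the diff itself:

```python
# Hedged sketch: list gateways via the hub REST API added in this PR.
# The base URL, project name ("main"), and token are placeholders;
# GatewayHead is assumed to expose instance_name, as used in this diff.
from typing import List

import requests
from pydantic import parse_obj_as

from dstack._internal.core.gateway import GatewayHead

resp = requests.get(
    "http://127.0.0.1:3000/api/project/main/gateways",  # hypothetical URL layout
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
for head in parse_obj_as(List[GatewayHead], resp.json()):
    print(head.instance_name)
```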
+ +## dstack gateway create + +The `dstack gateway create` command creates a new gateway instance in the project. + +### Usage + +
+ +```shell +$ dstack gateway create --help +Usage: dstack gateway create [-h] [--project PROJECT] + +Optional Arguments: + -h, --help show this help message and exit + --project PROJECT The name of the project +``` + +
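A detail worth surfacing from `_api_client.create_gateway`: a backend without gateway support answers with HTTP 400 and `detail.code == "not_implemented"`, which the client raises as `HubClientError`. A sketch of the same handling over raw HTTP, with placeholder URL and token:

```python
# Hedged sketch mirroring _api_client.create_gateway() from this PR:
# parse a GatewayHead on success, surface the hub's "not_implemented"
# message otherwise. The URL and token are placeholders.
import requests

from dstack._internal.core.gateway import GatewayHead

resp = requests.post(
    "http://127.0.0.1:3000/api/project/main/gateways/create",  # hypothetical URL layout
    headers={"Authorization": "Bearer <token>"},
)
if resp.ok:
    head = GatewayHead.parse_obj(resp.json())
elif resp.status_code == 400 and resp.json()["detail"]["code"] == "not_implemented":
    raise RuntimeError(resp.json()["detail"]["msg"])  # HubClientError in the real client
else:
    resp.raise_for_status()
```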
+ +### Arguments reference + +The following arguments are optional: + +- `--project PROJECT` - (Optional) The name of the project to execute the command for + + +## dstack gateway delete + +The `dstack gateway delete` command deletes the specified gateway. + +### Usage + +
+ +```shell +$ dstack gateway delete --help +Usage: dstack gateway delete [-h] [--project PROJECT] [-y] NAME + +Positional Arguments: + NAME The name of the gateway + +Optional Arguments: + -h, --help show this help message and exit + --project PROJECT The name of the project + -y, --yes Don't ask for confirmation +``` + +
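Note the wire format used by `_api_client.delete_gateway`: the request body is the bare JSON-encoded gateway name (`data=json.dumps(instance_name)`), matching the `instance_name: str = Body()` parameter on the hub route. A sketch with placeholder values:

```python
# Hedged sketch of the delete call from this PR: the body is a JSON
# string, not an object. URL, token, and gateway name are placeholders.
import json

import requests

resp = requests.post(
    "http://127.0.0.1:3000/api/project/main/gateways/delete",  # hypothetical URL layout
    headers={"Authorization": "Bearer <token>"},
    data=json.dumps("dstack-gateway-example"),
)
resp.raise_for_status()
```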
+
+### Arguments reference
+
+The following arguments are required:
+
+- `NAME` - (Required) The name of the gateway
+
+The following arguments are optional:
+
+- `--project PROJECT` - (Optional) The name of the project to execute the command for
+- `-y`, `--yes` - (Optional) Don't ask for confirmation
diff --git a/docs/docs/reference/dstack.yml.md b/docs/docs/reference/dstack.yml.md
index bf47b9423..459bbd20f 100644
--- a/docs/docs/reference/dstack.yml.md
+++ b/docs/docs/reference/dstack.yml.md
@@ -1,7 +1,7 @@
 # .dstack.yml
 
-Configurations are YAML files that describe what you want to run with `dstack`. Configurations can be of two
-types: `dev-environment` and `task`.
+Configurations are YAML files that describe what you want to run with `dstack`. Configurations can be of three
+types: `dev-environment`, `task`, and `service`.
 
 !!! info "Filename"
     The configuration file must be named with the suffix `.dstack.yml`. For example,
@@ -12,12 +12,12 @@ types: `dev-environment` and `task`.
 
 Below is a full reference of all available properties.
 
-- `type` - (Required) The type of the configurations. Can be `dev-environment` or `task`.
+- `type` - (Required) The type of the configuration. Can be `dev-environment`, `task`, or `service`.
 - `image` - (Optional) The name of the Docker image.
 - `entrypoint` - (Optional) The Docker entrypoint.
 - `build` - (Optional) The list of bash commands to build the environment.
 - `ide` - (Required if `type` is `dev-environment`). Can be `vscode`.
-- `ports` - (Optional) The list of port numbers to expose.
+- `ports` - (Optional) The list of port numbers to expose (only for `dev-environment` and `task`).
 - `env` - (Optional) The mapping or the list of environment variables (e.g. `PYTHONPATH: src` or `PYTHONPATH=src`).
 - `registry_auth` - (Optional) Credentials to pull the private Docker image.
     - `username` - (Required) Username.
@@ -26,6 +26,9 @@ Below is a full reference of all available properties.
 - `commands` - (Required if `type` is `task`). The list of bash commands to run as a task.
 - `python` - (Optional) The major version of Python to pre-install (e.g., `"3.11"`). Defaults to the current version installed locally. Mutually exclusive with `image`.
 - `cache` - (Optional) The directories to be cached between runs.
+- `gateway` - (Required if `type` is `service`) Gateway configuration.
+    - `hostname` - (Required) The address or domain pointing to the gateway.
+    - `service_port` - (Required) The application port.
 
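To make the new `gateway` block concrete, here is a minimal `service` configuration using the properties documented above, loaded with PyYAML purely for illustration; the hostname and port values are placeholders:

```python
# Hedged sketch: a minimal `service` configuration with the documented
# `gateway` properties. Only the key names come from the reference above;
# the values are placeholders.
import yaml

config = yaml.safe_load(
    """
type: service
gateway:
  hostname: gateway.example.com
  service_port: 8000
"""
)
assert config["type"] == "service"
assert config["gateway"]["service_port"] == 8000
```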
[//]: # (- `home_dir` - (Optional) The absolute path to the home directory inside the container) diff --git a/mkdocs.yml b/mkdocs.yml index 1da80cc12..66c811064 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -166,6 +166,7 @@ nav: - dstack secrets: docs/reference/cli/secrets.md - dstack prune: docs/reference/cli/prune.md - dstack build: docs/reference/cli/build.md + - dstack gateway: docs/reference/cli/gateway.md - API: - Python: docs/reference/api/python.md - Backends: diff --git a/runner/internal/backend/aws/backend.go b/runner/internal/backend/aws/backend.go index 138946920..bed999b8f 100644 --- a/runner/internal/backend/aws/backend.go +++ b/runner/internal/backend/aws/backend.go @@ -195,7 +195,7 @@ func (s *AWSBackend) MasterJob(ctx context.Context) *models.Job { log.Trace(ctx, "State not exist") return nil } - theFile, err := base.GetObject(ctx, s.storage, fmt.Sprintf("jobs/%s/%s.yaml", s.State.Job.RepoId, s.State.Job.MasterJobID)) + theFile, err := base.GetObject(ctx, s.storage, fmt.Sprintf("jobs/%s/%s.yaml", s.State.Job.RepoRef.RepoId, s.State.Job.MasterJobID)) if err != nil { return nil } @@ -269,7 +269,7 @@ func (s *AWSBackend) Secrets(ctx context.Context) (map[string]string, error) { for _, secretPath := range listSecrets { clearName := strings.ReplaceAll(secretPath, prefix, "") secrets[clearName] = fmt.Sprintf("%s/%s", - s.State.Job.RepoId, + s.State.Job.RepoRef.RepoId, clearName) } return s.cliSecret.fetchSecret(ctx, s.bucket, secrets) @@ -289,7 +289,7 @@ func (s *AWSBackend) GitCredentials(ctx context.Context) *models.GitCredentials log.Error(ctx, "Job is empty") return nil } - return s.cliSecret.fetchCredentials(ctx, s.bucket, s.State.Job.RepoId) + return s.cliSecret.fetchCredentials(ctx, s.bucket, s.State.Job.RepoRef.RepoId) } func (s *AWSBackend) GetRepoDiff(ctx context.Context, path string) (string, error) { diff --git a/runner/internal/backend/azure/backend.go b/runner/internal/backend/azure/backend.go index c1bf8d72e..aec52fbe2 100644 --- a/runner/internal/backend/azure/backend.go +++ b/runner/internal/backend/azure/backend.go @@ -186,7 +186,7 @@ func (azbackend *AzureBackend) Secrets(ctx context.Context) (map[string]string, secrets := make(map[string]string, 0) for _, secretFilename := range secretFilenames { secretName := strings.ReplaceAll(secretFilename, prefix, "") - secretValue, err := azbackend.secretManager.FetchSecret(ctx, azbackend.state.Job.RepoId, secretName) + secretValue, err := azbackend.secretManager.FetchSecret(ctx, azbackend.state.Job.RepoRef.RepoId, secretName) if err != nil { if errors.Is(err, ErrSecretNotFound) { continue @@ -200,7 +200,7 @@ func (azbackend *AzureBackend) Secrets(ctx context.Context) (map[string]string, func (azbackend *AzureBackend) GitCredentials(ctx context.Context) *models.GitCredentials { log.Trace(ctx, "Getting credentials") - creds, err := azbackend.secretManager.FetchCredentials(ctx, azbackend.state.Job.RepoId) + creds, err := azbackend.secretManager.FetchCredentials(ctx, azbackend.state.Job.RepoRef.RepoId) if err != nil { log.Error(ctx, "Getting credentials failure: %+v", err) return nil diff --git a/runner/internal/backend/base/backend.go b/runner/internal/backend/base/backend.go index 3124a51b2..3b5209504 100644 --- a/runner/internal/backend/base/backend.go +++ b/runner/internal/backend/base/backend.go @@ -66,7 +66,7 @@ func UpdateState(ctx context.Context, storage Storage, job *models.Job) error { return gerrors.Wrap(err) } // should it be a job.HubUserName? 
- log.Trace(ctx, "Fetching list jobs", "Repo username", job.RepoUserName, "Repo name", job.RepoName, "Job ID", job.JobID) + log.Trace(ctx, "Fetching list jobs", "Repo username", job.RepoData.RepoUserName, "Repo name", job.RepoData.RepoName, "Job ID", job.JobID) files, err := ListObjects(ctx, storage, job.JobHeadFilepathPrefix()) if err != nil { return gerrors.Wrap(err) @@ -138,7 +138,7 @@ func getBuildDiffPrefix(spec *docker.BuildSpec) string { return fmt.Sprintf( "builds/%s/%s;%s;%s;%s;%s;", spec.RepoId, - models.EscapeHead(spec.ConfigurationType), + models.EscapeHead(string(spec.ConfigurationType)), models.EscapeHead(spec.ConfigurationPath), models.EscapeHead(spec.WorkDir), models.EscapeHead(spec.BaseImageName), diff --git a/runner/internal/backend/gcp/backend.go b/runner/internal/backend/gcp/backend.go index 40b2ee994..794b4fb50 100644 --- a/runner/internal/backend/gcp/backend.go +++ b/runner/internal/backend/gcp/backend.go @@ -192,7 +192,7 @@ func (gbackend *GCPBackend) Secrets(ctx context.Context) (map[string]string, err secrets := make(map[string]string, 0) for _, secretFilename := range secretFilenames { secretName := strings.ReplaceAll(secretFilename, prefix, "") - secretValue, err := gbackend.secretManager.FetchSecret(ctx, gbackend.state.Job.RepoId, secretName) + secretValue, err := gbackend.secretManager.FetchSecret(ctx, gbackend.state.Job.RepoRef.RepoId, secretName) if err != nil { if errors.Is(err, ErrSecretNotFound) { continue @@ -206,7 +206,7 @@ func (gbackend *GCPBackend) Secrets(ctx context.Context) (map[string]string, err func (gbackend *GCPBackend) GitCredentials(ctx context.Context) *models.GitCredentials { log.Trace(ctx, "Getting credentials") - creds, err := gbackend.secretManager.FetchCredentials(ctx, gbackend.state.Job.RepoId) + creds, err := gbackend.secretManager.FetchCredentials(ctx, gbackend.state.Job.RepoRef.RepoId) if err != nil { return nil } diff --git a/runner/internal/backend/local/backend.go b/runner/internal/backend/local/backend.go index a7bb6d8e8..f42b202a6 100644 --- a/runner/internal/backend/local/backend.go +++ b/runner/internal/backend/local/backend.go @@ -88,7 +88,7 @@ func (l *Local) RefetchJob(ctx context.Context) (*models.Job, error) { } func (l *Local) MasterJob(ctx context.Context) *models.Job { - contents, err := base.GetObject(ctx, l.storage, filepath.Join("jobs", l.state.Job.RepoUserName, l.state.Job.RepoName, fmt.Sprintf("%s.yaml", l.state.Job.MasterJobID))) + contents, err := base.GetObject(ctx, l.storage, filepath.Join("jobs", l.state.Job.RepoData.RepoUserName, l.state.Job.RepoData.RepoName, fmt.Sprintf("%s.yaml", l.state.Job.MasterJobID))) if err != nil { return nil } @@ -156,12 +156,12 @@ func (l *Local) GetJobByPath(ctx context.Context, path string) (*models.Job, err func (l *Local) GitCredentials(ctx context.Context) *models.GitCredentials { log.Trace(ctx, "Getting credentials") - return l.cliSecret.fetchCredentials(ctx, l.state.Job.RepoId) + return l.cliSecret.fetchCredentials(ctx, l.state.Job.RepoRef.RepoId) } func (l *Local) Secrets(ctx context.Context) (map[string]string, error) { log.Trace(ctx, "Getting secrets") - templatePath := fmt.Sprintf("secrets/%s", l.state.Job.RepoId) + templatePath := fmt.Sprintf("secrets/%s", l.state.Job.RepoRef.RepoId) if _, err := os.Stat(filepath.Join(l.path, templatePath)); err != nil { return map[string]string{}, nil } @@ -177,7 +177,7 @@ func (l *Local) Secrets(ctx context.Context) (map[string]string, error) { if strings.HasPrefix(file.Name(), "l;") { clearName := strings.ReplaceAll(file.Name(), 
"l;", "") secrets[clearName] = fmt.Sprintf("%s/%s", - l.state.Job.RepoId, + l.state.Job.RepoRef.RepoId, clearName) } } diff --git a/runner/internal/docker/build.go b/runner/internal/docker/build.go index 99bde70dc..5d3b4ae87 100644 --- a/runner/internal/docker/build.go +++ b/runner/internal/docker/build.go @@ -4,13 +4,14 @@ import ( "bytes" "crypto/sha256" "fmt" + "github.com/dstackai/dstack/runner/internal/models" ) type BuildSpec struct { BaseImageID string WorkDir string ConfigurationPath string - ConfigurationType string + ConfigurationType models.ConfigurationType Commands []string Entrypoint []string @@ -32,7 +33,7 @@ func (s *BuildSpec) Hash() string { buffer.WriteString("\n") buffer.WriteString(s.ConfigurationPath) buffer.WriteString("\n") - buffer.WriteString(s.ConfigurationType) + buffer.WriteString(string(s.ConfigurationType)) buffer.WriteString("\n") return fmt.Sprintf("%x", sha256.Sum256(buffer.Bytes())) } diff --git a/runner/internal/docker/engine.go b/runner/internal/docker/engine.go index 02965ab78..e59b1625f 100644 --- a/runner/internal/docker/engine.go +++ b/runner/internal/docker/engine.go @@ -342,8 +342,6 @@ func (e *Engine) NewBuildSpec(ctx context.Context, job *models.Job, spec *Spec, return nil, gerrors.Wrap(err) } - commands := append([]string{}, job.BuildCommands...) - commands = append(commands, job.OptionalBuildCommands...) env := environment.New() env.AddMapString(secrets) @@ -353,12 +351,12 @@ func (e *Engine) NewBuildSpec(ctx context.Context, job *models.Job, spec *Spec, WorkDir: spec.WorkDir, ConfigurationPath: job.ConfigurationPath, ConfigurationType: job.ConfigurationType, - Commands: ShellCommands(InsertEnvs(commands, job.Environment)), + Commands: ShellCommands(InsertEnvs(job.BuildCommands, job.Environment)), Entrypoint: spec.Entrypoint, Env: env.ToSlice(), RegistryAuthBase64: spec.RegistryAuthBase64, RepoPath: repoPath, - RepoId: job.RepoId, + RepoId: job.RepoRef.RepoId, ShmSize: spec.ShmSize, } if daemonInfo.Architecture == "aarch64" { diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index e7d65694e..8cf3b3818 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -6,6 +6,8 @@ import ( "encoding/json" "errors" "fmt" + "github.com/docker/go-connections/nat" + "github.com/dstackai/dstack/runner/internal/gateway" "io" "os" "path" @@ -102,12 +104,12 @@ func (ex *Executor) Init(ctx context.Context, configDir string) error { } for _, artifact := range job.Artifacts { - artOut := ex.backend.GetArtifact(ctx, job.RunName, artifact.Path, path.Join("artifacts", job.RepoId, job.JobID, artifact.Path), artifact.Mount) + artOut := ex.backend.GetArtifact(ctx, job.RunName, artifact.Path, path.Join("artifacts", job.RepoRef.RepoId, job.JobID, artifact.Path), artifact.Mount) if artOut != nil { ex.artifactsOut = append(ex.artifactsOut, artOut) } if artifact.Mount { - art := ex.backend.GetArtifact(ctx, job.RunName, artifact.Path, path.Join("artifacts", job.RepoId, job.JobID, artifact.Path), artifact.Mount) + art := ex.backend.GetArtifact(ctx, job.RunName, artifact.Path, path.Join("artifacts", job.RepoRef.RepoId, job.JobID, artifact.Path), artifact.Mount) if art != nil { ex.artifactsFUSE = append(ex.artifactsFUSE, art) } @@ -259,8 +261,8 @@ func (ex *Executor) runJob(ctx context.Context, erCh chan error, stoppedCh chan } }() - logger := ex.backend.CreateLogger(ctx, fmt.Sprintf("/dstack/jobs/%s/%s", ex.backend.Bucket(ctx), job.RepoId), job.RunName) - logGroup := fmt.Sprintf("/jobs/%s", job.RepoId) 
+ logger := ex.backend.CreateLogger(ctx, fmt.Sprintf("/dstack/jobs/%s/%s", ex.backend.Bucket(ctx), job.RepoRef.RepoId), job.RunName) + logGroup := fmt.Sprintf("/jobs/%s", job.RepoRef.RepoId) fileLog, err := createLocalLog(filepath.Join(ex.configDir, "logs", logGroup), job.RunName) if err != nil { erCh <- gerrors.Wrap(err) @@ -286,7 +288,7 @@ func (ex *Executor) startJob(ctx context.Context, erCh chan error, stoppedCh cha } var err error - switch job.RepoType { + switch job.RepoData.RepoType { case "remote": log.Trace(ctx, "Fetching git repository") if err = ex.prepareGit(ctx); err != nil { @@ -300,7 +302,7 @@ func (ex *Executor) startJob(ctx context.Context, erCh chan error, stoppedCh cha return } default: - log.Error(ctx, "Unknown RepoType", "RepoType", job.RepoType) + log.Error(ctx, "Unknown RepoType", "RepoType", job.RepoData.RepoType) } if job.BuildPolicy != models.BuildOnly { @@ -379,6 +381,27 @@ func (ex *Executor) startJob(ctx context.Context, erCh chan error, stoppedCh cha erCh <- gerrors.Wrap(err) return } + + var gatewayControl *gateway.SSHControl + if job.ConfigurationType == "service" { + binding, ok := spec.BindingPorts[nat.Port(fmt.Sprintf("%d/tcp", job.Gateway.ServicePort))] + if !ok { + erCh <- gerrors.Newf("gateway: job doesn't expose port %d", job.Gateway.ServicePort) + return + } + localPort := binding[0].HostPort + gatewayControl, err = gateway.NewSSHControl(job.Gateway.Hostname, job.Gateway.SSHKey) + if err != nil { + erCh <- gerrors.Wrap(err) + return + } + defer gatewayControl.Cleanup() + if err := gatewayControl.Publish(localPort, strconv.Itoa(job.Gateway.PublicPort)); err != nil { + erCh <- gerrors.Wrap(err) + return + } + } + container, err := ex.engine.CreateNamed(ctx, spec, job.RunName, allLogs) if err != nil { erCh <- gerrors.Wrap(err) @@ -425,7 +448,7 @@ func (ex *Executor) prepareGit(ctx context.Context) error { } } - ex.repo = repo.NewManager(ctx, fmt.Sprintf(consts.REPO_HTTPS_URL, job.RepoHostNameWithPort(), job.RepoUserName, job.RepoName), job.RepoBranch, job.RepoHash).WithLocalPath(dir) + ex.repo = repo.NewManager(ctx, fmt.Sprintf(consts.REPO_HTTPS_URL, job.RepoHostNameWithPort(), job.RepoData.RepoUserName, job.RepoData.RepoName), job.RepoData.RepoBranch, job.RepoData.RepoHash).WithLocalPath(dir) cred := ex.backend.GitCredentials(ctx) if cred != nil { log.Trace(ctx, "Credentials is not empty") @@ -447,7 +470,7 @@ func (ex *Executor) prepareGit(ctx context.Context) error { if cred.Passphrase != nil { password = *cred.Passphrase } - ex.repo = repo.NewManager(ctx, fmt.Sprintf(consts.REPO_GIT_URL, job.RepoHostNameWithPort(), job.RepoUserName, job.RepoName), job.RepoBranch, job.RepoHash).WithLocalPath(dir) + ex.repo = repo.NewManager(ctx, fmt.Sprintf(consts.REPO_GIT_URL, job.RepoHostNameWithPort(), job.RepoData.RepoUserName, job.RepoData.RepoName), job.RepoData.RepoBranch, job.RepoData.RepoHash).WithLocalPath(dir) ex.repo.WithSSHAuth(*cred.PrivateKey, password) default: log.Error(ctx, "Unsupported protocol", "protocol", cred.Protocol) @@ -458,7 +481,7 @@ func (ex *Executor) prepareGit(ctx context.Context) error { log.Trace(ctx, "GIT checkout error", "err", err, "GIT URL", ex.repo.URL()) return gerrors.Wrap(err) } - if err := ex.repo.SetConfig(job.RepoConfigName, job.RepoConfigEmail); err != nil { + if err := ex.repo.SetConfig(job.RepoData.RepoConfigName, job.RepoData.RepoConfigEmail); err != nil { return gerrors.Wrap(err) } @@ -505,7 +528,7 @@ func (ex *Executor) processDeps(ctx context.Context) error { return gerrors.Wrap(err) } for _, artifact := range 
jobDep.Artifacts { - artIn := ex.backend.GetArtifact(ctx, jobDep.RunName, artifact.Path, path.Join("artifacts", jobDep.RepoId, jobDep.JobID, artifact.Path), artifact.Mount) + artIn := ex.backend.GetArtifact(ctx, jobDep.RunName, artifact.Path, path.Join("artifacts", jobDep.RepoRef.RepoId, jobDep.JobID, artifact.Path), artifact.Mount) if artIn != nil { ex.artifactsIn = append(ex.artifactsIn, artIn) } @@ -518,7 +541,7 @@ func (ex *Executor) processDeps(ctx context.Context) error { func (ex *Executor) processCache(ctx context.Context) error { job := ex.backend.Job(ctx) for _, cache := range job.Cache { - cacheArt := ex.backend.GetCache(ctx, job.RunName, cache.Path, path.Join("cache", job.RepoId, job.HubUserName, models.EscapeHead(job.ConfigurationPath), cache.Path)) + cacheArt := ex.backend.GetCache(ctx, job.RunName, cache.Path, path.Join("cache", job.RepoRef.RepoId, job.HubUserName, models.EscapeHead(job.ConfigurationPath), cache.Path)) if cacheArt != nil { ex.cacheArtifacts = append(ex.cacheArtifacts, cacheArt) } @@ -564,7 +587,7 @@ func (ex *Executor) newSpec(ctx context.Context, credPath string) (*docker.Spec, } bindings = append(bindings, art...) } - if job.RepoType == "remote" && job.HomeDir != "" { + if job.RepoData.RepoType == "remote" && job.HomeDir != "" { cred := ex.backend.GitCredentials(ctx) if cred != nil { log.Trace(ctx, "Trying to mount git credentials") @@ -580,7 +603,7 @@ func (ex *Executor) newSpec(ctx context.Context, credPath string) (*docker.Spec, case "https": if cred.OAuthToken != nil { credMountPath = path.Join(job.HomeDir, ".config/gh/hosts.yml") - ghHost := fmt.Sprintf("%s:\n oauth_token: \"%s\"\n", job.RepoHostName, *cred.OAuthToken) + ghHost := fmt.Sprintf("%s:\n oauth_token: \"%s\"\n", job.RepoData.RepoHostName, *cred.OAuthToken) if err := os.WriteFile(credPath, []byte(ghHost), 0644); err != nil { log.Error(ctx, "Failed writing credentials", "err", err) } @@ -652,7 +675,7 @@ func (ex *Executor) environment(ctx context.Context, includeRun bool) []string { if includeRun { cons := make(map[string]string) cons["PYTHONUNBUFFERED"] = "1" - cons["DSTACK_REPO"] = job.RepoId + cons["DSTACK_REPO"] = job.RepoRef.RepoId cons["JOB_ID"] = job.JobID cons["RUN_NAME"] = job.RunName @@ -668,7 +691,6 @@ func (ex *Executor) environment(ctx context.Context, includeRun bool) []string { cons["MASTER_JOB_ID"] = master.JobID cons["MASTER_JOB_HOSTNAME"] = master.HostName } - env.AddMapString(job.RunEnvironment) env.AddMapString(cons) } secrets, err := ex.backend.Secrets(ctx) @@ -734,7 +756,7 @@ func (ex *Executor) runContainer(ctx context.Context, container *docker.Containe func (ex *Executor) build(ctx context.Context, spec *docker.Spec, stoppedCh chan bool, logs io.Writer) error { job := ex.backend.Job(ctx) - if len(job.BuildCommands) == 0 && len(job.OptionalBuildCommands) == 0 { + if len(job.BuildCommands) == 0 { return nil } secrets, err := ex.backend.Secrets(ctx) diff --git a/runner/internal/gateway/ssh.go b/runner/internal/gateway/ssh.go new file mode 100644 index 000000000..0dc59a4f0 --- /dev/null +++ b/runner/internal/gateway/ssh.go @@ -0,0 +1,119 @@ +package gateway + +import ( + "fmt" + "github.com/dstackai/dstack/runner/internal/gerrors" + "os" + "os/exec" + "path" + "path/filepath" + "strings" +) + +type SSHControl struct { + keyPath string + controlPath string + hostname string + user string + remoteTempDir string + localTempDir string +} + +func NewSSHControl(hostname, sshKey string) (*SSHControl, error) { + localTempDir, err := os.MkdirTemp("", "") + if err != nil { + return 
nil, gerrors.Wrap(err)
+	}
+	keyPath := filepath.Join(localTempDir, "id_rsa")
+	if err := os.WriteFile(keyPath, []byte(sshKey), 0o600); err != nil {
+		return nil, gerrors.Wrap(err)
+	}
+	c := &SSHControl{
+		keyPath:      keyPath,
+		controlPath:  filepath.Join(localTempDir, "ssh.control"),
+		hostname:     hostname,
+		user:         "ubuntu",
+		localTempDir: localTempDir,
+	}
+	err = c.mkTempDir()
+	return c, gerrors.Wrap(err)
+}
+
+// exec runs an ssh command against the gateway, multiplexing all calls
+// over a single control connection (ControlMaster/ControlPath).
+func (c *SSHControl) exec(args []string, command string) ([]byte, error) {
+	allArgs := []string{
+		"-i", c.keyPath,
+		"-o", "StrictHostKeyChecking=accept-new",
+		"-o", fmt.Sprintf("ControlPath=%s", c.controlPath),
+		"-o", "ControlMaster=auto",
+		"-o", "ControlPersist=yes",
+	}
+	if args != nil {
+		allArgs = append(allArgs, args...)
+	}
+	allArgs = append(allArgs, fmt.Sprintf("%s@%s", c.user, c.hostname))
+	if command != "" {
+		allArgs = append(allArgs, command)
+	}
+	cmd := exec.Command("ssh", allArgs...)
+	stdout, err := cmd.Output()
+	return stdout, gerrors.Wrap(err)
+}
+
+func (c *SSHControl) mkTempDir() error {
+	tempDir, err := c.exec(nil, "mktemp -d /tmp/dstack-XXXXXXXX")
+	if err != nil {
+		return gerrors.Wrap(err)
+	}
+	c.remoteTempDir = strings.Trim(string(tempDir), "\n")
+	return nil
+}
+
+// Publish forwards the job's local port to a unix socket on the gateway
+// and configures nginx to serve that socket on the public port.
+func (c *SSHControl) Publish(localPort, publicPort string) error {
+	// run tunnel in background
+	_, err := c.exec([]string{
+		"-f", "-N",
+		"-R", fmt.Sprintf("%s/http.sock:localhost:%s", c.remoteTempDir, localPort),
+	}, "")
+	if err != nil {
+		return gerrors.Wrap(err)
+	}
+	// \\n will be converted to \n by remote printf
+	nginxConf := strings.ReplaceAll(fmt.Sprintf(nginxConfFmt, c.hostname, publicPort, c.remoteTempDir, path.Base(c.remoteTempDir)), "\n", "\\n")
+	script := []string{
+		fmt.Sprintf("sudo chown -R %s:www-data %s", c.user, c.remoteTempDir),
+		fmt.Sprintf("chmod 0770 %s", c.remoteTempDir),
+		fmt.Sprintf("chmod 0660 %s/http.sock", c.remoteTempDir),
+		// todo check if conflicts
+		fmt.Sprintf("printf '%s' | sudo tee /etc/nginx/sites-enabled/%s-%s.conf", nginxConf, publicPort, c.hostname),
+		"sudo systemctl reload nginx.service",
+	}
+	_, err = c.exec(nil, strings.Join(script, " && "))
+	return gerrors.Wrap(err)
+}
+
+func (c *SSHControl) Cleanup() {
+	// todo cleanup remote
+	_ = exec.Command("ssh", "-o", "ControlPath="+c.controlPath, "-O", "exit", c.hostname).Run()
+	_ = os.RemoveAll(c.localTempDir)
+}
+
+// 1: hostname
+// 2: port
+// 3: temp dir
+// 4: upstream name
+var nginxConfFmt = `upstream %[4]s {
+    server unix:%[3]s/http.sock;
+}
+
+server {
+    server_name %[1]s;
+    listen %[2]s;
+
+    location / {
+        proxy_pass http://%[4]s;
+        proxy_set_header X-Real-IP $remote_addr;
+        proxy_set_header Host $host;
+    }
+}
+`
diff --git a/runner/internal/models/backend.go b/runner/internal/models/backend.go
index 363d3063e..a102a1378 100644
--- a/runner/internal/models/backend.go
+++ b/runner/internal/models/backend.go
@@ -16,63 +16,57 @@ type Resource struct {
 }
 
 type Job struct {
-	Apps                  []App             `yaml:"apps"`
-	Artifacts             []Artifact        `yaml:"artifacts"`
-	Cache                 []Cache           `yaml:"cache"`
-	BuildCommands         []string          `yaml:"build_commands"`
-	OptionalBuildCommands []string          `yaml:"optional_build_commands"`
-	Setup                 []string          `yaml:"setup"`
-	Commands              []string          `yaml:"commands"`
-	BuildPolicy           BuildPolicy       `yaml:"build_policy"`
-	Entrypoint            []string          `yaml:"entrypoint"`
-	Environment           map[string]string `yaml:"env"`
-	RunEnvironment        map[string]string `yaml:"run_env"`
-	HostName              string            `yaml:"host_name"`
-	Image                 string            `yaml:"image_name"`
-	JobID                 string            `yaml:"job_id"`
-
MasterJobID string `yaml:"master_job_id"` - Deps []Dep `yaml:"deps"` - ProviderName string `yaml:"provider_name"` - - RepoId string `yaml:"repo_id"` - RepoType string `yaml:"repo_type"` - HubUserName string `yaml:"hub_user_name"` - - RepoHostName string `yaml:"repo_host_name,omitempty"` - RepoPort int `yaml:"repo_port,omitempty"` - RepoUserName string `yaml:"repo_user_name,omitempty"` - RepoName string `yaml:"repo_name,omitempty"` - RepoBranch string `yaml:"repo_branch,omitempty"` - RepoHash string `yaml:"repo_hash,omitempty"` - RepoConfigName string `yaml:"repo_config_name,omitempty"` - RepoConfigEmail string `yaml:"repo_config_email,omitempty"` - - RepoCodeFilename string `yaml:"repo_code_filename"` - - RequestID string `yaml:"request_id"` - Location string `yaml:"location"` - Requirements Requirements `yaml:"requirements"` - RunName string `yaml:"run_name"` - RunnerID string `yaml:"runner_id"` - SpotPolicy string `yaml:"spot_policy"` - RetryPolicy RetryPolicy `yaml:"retry_policy"` - TerminationPolicy string `yaml:"termination_policy"` - MaxDuration uint64 `yaml:"max_duration,omitempty"` - Status string `yaml:"status"` - ErrorCode string `yaml:"error_code,omitempty"` - ContainerExitCode string `yaml:"container_exit_code,omitempty"` - CreatedAt uint64 `yaml:"created_at"` - SubmittedAt uint64 `yaml:"submitted_at"` - SubmissionNum int `yaml:"submission_num"` - TagName string `yaml:"tag_name"` - InstanceType string `yaml:"instance_type"` - ConfigurationPath string `yaml:"configuration_path"` - ConfigurationType string `yaml:"configuration_type"` - WorkflowName string `yaml:"workflow_name"` // deprecated - HomeDir string `yaml:"home_dir"` - WorkingDir string `yaml:"working_dir"` - - RegistryAuth RegistryAuth `yaml:"registry_auth"` + // apply omitempty to every Optional[] in pydantic model + AppNames []string `yaml:"app_names,omitempty"` // head + Apps []App `yaml:"app_specs,omitempty"` + ArtifactPaths []string `yaml:"artifact_paths,omitempty"` // head + Artifacts []Artifact `yaml:"artifact_specs,omitempty"` + BuildCommands []string `yaml:"build_commands,omitempty"` + BuildPolicy BuildPolicy `yaml:"build_policy"` + Cache []Cache `yaml:"cache_specs"` + Commands []string `yaml:"commands,omitempty"` + ConfigurationPath string `yaml:"configuration_path,omitempty"` // head + ConfigurationType ConfigurationType `yaml:"configuration_type,omitempty"` + ContainerExitCode string `yaml:"container_exit_code,omitempty"` // head + CreatedAt uint64 `yaml:"created_at"` + Deps []Dep `yaml:"dep_specs,omitempty"` + Entrypoint []string `yaml:"entrypoint,omitempty"` + Environment map[string]string `yaml:"env,omitempty"` + ErrorCode ErrorCode `yaml:"error_code,omitempty"` // head + Gateway Gateway `yaml:"gateway,omitempty"` + HomeDir string `yaml:"home_dir,omitempty"` + HostName string `yaml:"host_name,omitempty"` + HubUserName string `yaml:"hub_user_name"` // head + Image string `yaml:"image_name"` + InstanceSpotType string `yaml:"instance_spot_type,omitempty"` // head + InstanceType string `yaml:"instance_type,omitempty"` // head + JobID string `yaml:"job_id"` // head + Location string `yaml:"location,omitempty"` + MasterJobID string `yaml:"master_job,omitempty"` + MaxDuration uint64 `yaml:"max_duration,omitempty"` + ProviderName string `yaml:"provider_name,omitempty"` // deprecated, head + RegistryAuth RegistryAuth `yaml:"registry_auth,omitempty"` + RepoCodeFilename string `yaml:"repo_code_filename,omitempty"` + RepoData RepoData `yaml:"repo_data"` + RepoRef RepoRef `yaml:"repo_ref"` // head + RequestID string 
`yaml:"request_id,omitempty"` + Requirements Requirements `yaml:"requirements,omitempty"` + RetryPolicy RetryPolicy `yaml:"retry_policy,omitempty"` + RunName string `yaml:"run_name"` // head + RunnerID string `yaml:"runner_id,omitempty"` + Setup []string `yaml:"setup"` + SpotPolicy SpotPolicy `yaml:"spot_policy,omitempty"` + Status JobStatus `yaml:"status"` // head + SubmissionNum int `yaml:"submission_num"` + SubmittedAt uint64 `yaml:"submitted_at"` // head + TagName string `yaml:"tag_name,omitempty"` // head + TerminationPolicy string `yaml:"termination_policy,omitempty"` + WorkflowName string `yaml:"workflow_name,omitempty"` // deprecated, head + WorkingDir string `yaml:"working_dir,omitempty"` +} + +type RepoRef struct { + RepoId string `yaml:"repo_id"` } type Dep struct { @@ -138,11 +132,39 @@ type RegistryAuth struct { Password string `yaml:"password,omitempty"` } +type RepoData struct { + RepoType RepoType `yaml:"repo_type"` + // type=remote + RepoHostName string `yaml:"repo_host_name,omitempty"` + RepoPort int `yaml:"repo_port,omitempty"` + RepoUserName string `yaml:"repo_user_name,omitempty"` + RepoName string `yaml:"repo_name,omitempty"` + RepoBranch string `yaml:"repo_branch,omitempty"` + RepoHash string `yaml:"repo_hash,omitempty"` + RepoConfigName string `yaml:"repo_config_name,omitempty"` + RepoConfigEmail string `yaml:"repo_config_email,omitempty"` + // type=local + RepoDir string `yaml:"repo_dir"` +} + +type Gateway struct { + Hostname string `yaml:"hostname"` + SSHKey string `yaml:"ssh_key,omitempty"` + ServicePort int `yaml:"service_port"` + PublicPort int `yaml:"public_port"` +} + type RunnerMetadata struct { Status string `yaml:"status"` } +type ConfigurationType string +type ErrorCode string +type SpotPolicy string +type TerminationPolicy string +type JobStatus string type BuildPolicy string +type RepoType string const ( UseBuild BuildPolicy = "use-build" @@ -152,18 +174,18 @@ const ( ) func (j *Job) RepoHostNameWithPort() string { - if j.RepoPort == 0 { - return j.RepoHostName + if j.RepoData.RepoPort == 0 { + return j.RepoData.RepoHostName } - return fmt.Sprintf("%s:%d", j.RepoHostName, j.RepoPort) + return fmt.Sprintf("%s:%d", j.RepoData.RepoHostName, j.RepoData.RepoPort) } func (j *Job) JobFilepath() string { - return fmt.Sprintf("jobs/%s/%s.yaml", j.RepoId, j.JobID) + return fmt.Sprintf("jobs/%s/%s.yaml", j.RepoRef.RepoId, j.JobID) } func (j *Job) JobHeadFilepathPrefix() string { - return fmt.Sprintf("jobs/%s/l;%s;", j.RepoId, j.JobID) + return fmt.Sprintf("jobs/%s/l;%s;", j.RepoRef.RepoId, j.JobID) } func (j *Job) JobHeadFilepath() string { @@ -177,12 +199,12 @@ func (j *Job) JobHeadFilepath() string { } return fmt.Sprintf( "jobs/%s/l;%s;%s;%s;%d;%s;%s;%s;%s;%s;%s;%s", - j.RepoId, + j.RepoRef.RepoId, j.JobID, - j.ProviderName, + "", // ProviderName j.HubUserName, j.SubmittedAt, - strings.Join([]string{j.Status, j.ErrorCode, j.ContainerExitCode}, ","), + strings.Join([]string{string(j.Status), string(j.ErrorCode), j.ContainerExitCode}, ","), strings.Join(artifactSlice, ","), strings.Join(appsSlice, ","), j.TagName, @@ -200,7 +222,7 @@ func (j *Job) GetInstanceType() string { } func (j *Job) SecretsPrefix() string { - return fmt.Sprintf("secrets/%s/l;", j.RepoId) + return fmt.Sprintf("secrets/%s/l;", j.RepoRef.RepoId) } func (j *Job) MaxDurationExceeded() bool {