Skip to content

Commit

Permalink
Introduce gateways for services publication (#596)
Browse files Browse the repository at this point in the history
* Add service configuration type

* Use pydantic serialization for Job

* Add gateway to Job model, generate new keypair on run_job

* Deploy public key for service configuration

* Establish an SSH tunnel with the gateway, configure nginx

* Cleanup on exit, handle multiple domains

* GCP: create gateway

* Fix BaseHead tests

* Resolve merge bug

* Delete gateway, prettify commands

* Update gateway docs
  • Loading branch information
Egor-S authored Jul 28, 2023
1 parent 04b07dc commit 9bb93b6
Show file tree
Hide file tree
Showing 41 changed files with 1,117 additions and 504 deletions.
23 changes: 23 additions & 0 deletions cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime
from typing import Generator, List, Optional

import dstack._internal.backend.base.gateway as gateway
import dstack._internal.core.build
from dstack._internal.backend.base import artifacts as base_artifacts
from dstack._internal.backend.base import build as base_build
Expand All @@ -17,6 +18,7 @@
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.artifact import Artifact
from dstack._internal.core.build import BuildPlan
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobHead, JobStatus
from dstack._internal.core.log_event import LogEvent
Expand Down Expand Up @@ -248,6 +250,18 @@ def get_signed_upload_url(self, object_key: str) -> str:
def predict_build_plan(self, job: Job) -> BuildPlan:
    """Interface declaration taking a ``Job`` and annotated to return a ``BuildPlan``.

    The body of this definition is empty.
    """

@abstractmethod
def create_gateway(self, ssh_key_pub: str) -> GatewayHead:
    """Abstract declaration: create a gateway and return its ``GatewayHead``.

    Args:
        ssh_key_pub: public SSH key passed to the implementation.

    This definition has an empty body.
    """

@abstractmethod
def list_gateways(self) -> List[GatewayHead]:
    """Abstract declaration annotated to return a list of ``GatewayHead`` objects.

    This definition has an empty body.
    """

@abstractmethod
def delete_gateway(self, instance_name: str):
    """Abstract declaration: delete the gateway identified by ``instance_name``.

    This definition has an empty body.
    """


class ComponentBasedBackend(Backend):
@abstractmethod
Expand Down Expand Up @@ -468,3 +482,12 @@ def predict_build_plan(self, job: Job) -> BuildPlan:
return base_build.predict_build_plan(
self.storage(), job, dstack._internal.core.build.DockerPlatform.amd64
)

def create_gateway(self, ssh_key_pub: str) -> GatewayHead:
    """Create a gateway by delegating to ``gateway.create_gateway``.

    The backend's compute and storage components are handed over together
    with ``ssh_key_pub``; the ``GatewayHead`` returned by the helper is
    returned unchanged.
    """
    compute_component = self.compute()
    storage_component = self.storage()
    return gateway.create_gateway(compute_component, storage_component, ssh_key_pub)

def list_gateways(self) -> List[GatewayHead]:
    """Return the gateways reported by ``gateway.list_gateways`` for this backend's storage."""
    storage_component = self.storage()
    return gateway.list_gateways(storage_component)

def delete_gateway(self, instance_name: str):
    """Delete the gateway named ``instance_name`` via ``gateway.delete_gateway``.

    The backend's compute and storage components are passed to the helper;
    nothing is returned.
    """
    compute_component = self.compute()
    storage_component = self.storage()
    gateway.delete_gateway(compute_component, storage_component, instance_name)
2 changes: 0 additions & 2 deletions cli/dstack/_internal/backend/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def predict_build_plan(
raise BuildNotFoundError("Build not found. Run `dstack build` or add `--build` flag")
return BuildPlan.yes

if job.optional_build_commands and job.build_policy == BuildPolicy.BUILD:
return BuildPlan.yes
return BuildPlan.no


Expand Down
9 changes: 9 additions & 0 deletions cli/dstack/_internal/backend/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dstack.version as version
from dstack._internal.core.error import BackendError
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
from dstack._internal.core.job import Job, Requirements
from dstack._internal.core.request import RequestHead
Expand Down Expand Up @@ -49,6 +50,14 @@ def terminate_instance(self, runner: Runner):
def cancel_spot_request(self, runner: Runner):
    """Interface declaration taking a ``Runner``; this definition has an empty body."""

def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
    """Placeholder for creating a gateway instance; always raises.

    Args:
        instance_name: name for the gateway instance.
        ssh_key_pub: public SSH key for the gateway.

    Raises:
        NotImplementedError: unconditionally in this base definition.
    """
    # TODO: make abstract & implement for each backend
    raise NotImplementedError()

def delete_instance(self, instance_name: str):
    """Placeholder for deleting the instance ``instance_name``; always raises.

    Raises:
        NotImplementedError: unconditionally in this base definition.
    """
    # TODO: make abstract & implement for each backend
    raise NotImplementedError()


def choose_instance_type(
instance_types: List[InstanceType],
Expand Down
70 changes: 70 additions & 0 deletions cli/dstack/_internal/backend/base/gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import subprocess
import time
from typing import List, Optional

from dstack._internal.backend.base.compute import Compute
from dstack._internal.backend.base.head import (
delete_head_object,
list_head_objects,
put_head_object,
)
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.error import DstackError
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH
from dstack._internal.utils.common import PathLike
from dstack._internal.utils.random_names import generate_name


def create_gateway(compute: Compute, storage: Storage, ssh_key_pub: str) -> GatewayHead:
    """Create a gateway through ``compute`` and record its head in ``storage``.

    Args:
        compute: backend compute component; its ``create_gateway`` is called.
        storage: storage where the resulting head object is written.
        ssh_key_pub: public SSH key forwarded to ``compute.create_gateway``.

    Returns:
        The ``GatewayHead`` produced by ``compute.create_gateway``.
    """
    # TODO: regenerate the name while it is not unique
    gateway_name = f"dstack-gateway-{generate_name()}"
    gateway_head = compute.create_gateway(gateway_name, ssh_key_pub)
    put_head_object(storage, gateway_head)
    return gateway_head


def list_gateways(storage: Storage) -> List[GatewayHead]:
    """Return all ``GatewayHead`` objects that ``list_head_objects`` finds in ``storage``."""
    gateway_heads = list_head_objects(storage, GatewayHead)
    return gateway_heads


def delete_gateway(compute: Compute, storage: Storage, instance_name: str):
    """Delete the gateway instance(s) named ``instance_name`` and their head objects.

    Every stored gateway head is inspected. For each one whose
    ``instance_name`` matches, the instance is deleted through ``compute``
    and then the head object is removed from ``storage``. If no head matches,
    nothing is done.
    """
    for gateway_head in list_gateways(storage):
        if gateway_head.instance_name == instance_name:
            compute.delete_instance(instance_name)
            delete_head_object(storage, gateway_head)


def ssh_copy_id(
    hostname: str,
    public_key: bytes,
    user: str = "ubuntu",
    id_rsa: Optional[PathLike] = HUB_PRIVATE_KEY_PATH,
):
    """Append ``public_key`` to ``~/.ssh/authorized_keys`` on a remote host.

    The key is appended by running an ``echo ... >>`` command over SSH via
    ``exec_ssh_command``.

    Args:
        hostname: host to connect to.
        public_key: public key bytes; decoded with the default codec.
        user: remote user name.
        id_rsa: identity file passed to ``ssh -i``; ``None`` omits the option.

    Raises:
        SSHCommandError: propagated from ``exec_ssh_command`` when the remote
            command exits with a non-zero status.
    """
    # Local import keeps this block self-contained without touching file-level imports.
    import shlex

    # The key text (including its free-form comment) ends up inside a command
    # line interpreted by the remote shell. Quote it properly so an embedded
    # single quote or shell metacharacter cannot break out of the argument.
    quoted_key = shlex.quote(public_key.decode())
    command = f"echo {quoted_key} >> ~/.ssh/authorized_keys"
    exec_ssh_command(hostname, command, user=user, id_rsa=id_rsa)


def exec_ssh_command(hostname: str, command: str, user: str, id_rsa: Optional[PathLike]) -> bytes:
    """Run ``command`` on ``user@hostname`` with the ``ssh`` client and return its stdout.

    Args:
        hostname: host to connect to.
        command: command string handed to ``ssh`` as the remote command.
        user: remote user name.
        id_rsa: identity file passed with ``-i``; omitted when ``None``.

    Returns:
        The raw bytes captured from the process's standard output.

    Raises:
        SSHCommandError: if the ``ssh`` process exits with a non-zero code;
            carries the argument list and the decoded standard error.
    """
    identity_options = [] if id_rsa is None else ["-i", id_rsa]
    ssh_argv = [
        "ssh",
        *identity_options,
        "-o",
        "StrictHostKeyChecking=accept-new",
        f"{user}@{hostname}",
        command,
    ]
    completed = subprocess.run(ssh_argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if completed.returncode != 0:
        raise SSHCommandError(ssh_argv, completed.stderr.decode())
    return completed.stdout


class SSHCommandError(DstackError):
    """Error carrying the SSH argument list alongside the error message.

    ``exec_ssh_command`` raises it when the ``ssh`` process returns a
    non-zero exit code.
    """

    def __init__(self, cmd: List[str], message: str):
        """Store ``message`` on the base error and keep ``cmd`` for inspection."""
        super().__init__(message)
        # Argument list of the command that failed.
        self.cmd = cmd
19 changes: 19 additions & 0 deletions cli/dstack/_internal/backend/base/head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import List, Type

from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.head import BaseHead, T


def put_head_object(storage: Storage, head: BaseHead) -> str:
    """Store ``head`` in ``storage`` as an empty object named by its encoded key.

    Returns:
        The key produced by ``head.encode()``.
    """
    object_key = head.encode()
    # Only the key carries information; the object content is left empty.
    storage.put_object(object_key, content="")
    return object_key


def list_head_objects(storage: Storage, cls: Type[T]) -> List[T]:
    """List the objects under ``cls.prefix()`` and decode each key with ``cls.decode``.

    Returns:
        A list of decoded head objects, in the order ``storage.list_objects``
        yields the keys.
    """
    decoded_heads = []
    for object_key in storage.list_objects(cls.prefix()):
        decoded_heads.append(cls.decode(object_key))
    return decoded_heads


def delete_head_object(storage: Storage, head: BaseHead) -> None:
    """Delete the storage object whose key is ``head.encode()``.

    Args:
        storage: storage to delete from (annotated to match the sibling
            ``put_head_object`` / ``list_head_objects`` helpers).
        head: head object whose encoded key identifies the object to remove.
    """
    storage.delete_object(head.encode())
8 changes: 8 additions & 0 deletions cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import yaml

import dstack._internal.backend.base.gateway as gateway
from dstack._internal.backend.base import runners
from dstack._internal.backend.base.compute import Compute, InstanceNotFoundError, NoCapacityError
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.error import BackendError, BackendValueError, NoMatchingInstanceError
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import (
ConfigurationType,
Job,
JobErrorCode,
JobHead,
Expand All @@ -18,6 +20,7 @@
from dstack._internal.core.repo import RepoRef
from dstack._internal.core.runners import Runner
from dstack._internal.utils.common import get_milliseconds_since_epoch
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
from dstack._internal.utils.escape import escape_head, unescape_head
from dstack._internal.utils.logging import get_logger

Expand Down Expand Up @@ -118,6 +121,11 @@ def run_job(
if job.status != JobStatus.SUBMITTED:
raise BackendError("Can't create a request for a job which status is not SUBMITTED")
try:
if job.configuration_type == ConfigurationType.SERVICE:
private_bytes, public_bytes = generate_rsa_key_pair_bytes(comment=job.run_name)
gateway.ssh_copy_id(job.gateway.hostname, public_bytes)
job.gateway.ssh_key = private_bytes.decode()
update_job(storage, job)
_try_run_job(
storage=storage,
compute=compute,
Expand Down
33 changes: 32 additions & 1 deletion cli/dstack/_internal/backend/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from google.cloud import compute_v1
from google.oauth2 import service_account

import dstack._internal.backend.gcp.gateway as gateway
from dstack import version
from dstack._internal.backend.base.compute import (
WS_PORT,
Expand All @@ -20,6 +21,7 @@
from dstack._internal.backend.gcp import utils as gcp_utils
from dstack._internal.backend.gcp.config import GCPConfig
from dstack._internal.core.error import BackendValueError
from dstack._internal.core.gateway import GatewayHead
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
from dstack._internal.core.job import Job, Requirements
from dstack._internal.core.request import RequestHead, RequestStatus
Expand Down Expand Up @@ -157,6 +159,35 @@ def cancel_spot_request(self, runner: Runner):
instance_name=runner.request_id,
)

def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
    """Create a GCP gateway instance and describe it as a ``GatewayHead``.

    The instance is created by ``gateway.create_gateway_instance`` using this
    compute object's clients, GCP configuration and service account email.
    The returned head carries the external (NAT) and internal IP addresses
    read from the instance's first network interface.
    """
    config = self.gcp_config
    gateway_labels = dict(
        role="gateway",
        owner="dstack",
    )
    created = gateway.create_gateway_instance(
        instances_client=self.instances_client,
        firewalls_client=self.firewalls_client,
        project_id=config.project_id,
        network=_get_network_resource(config.vpc),
        subnet=_get_subnet_resource(config.region, config.subnet),
        zone=config.zone,
        instance_name=instance_name,
        service_account=self.credentials.service_account_email,
        labels=gateway_labels,
        ssh_key_pub=ssh_key_pub,
    )
    primary_interface = created.network_interfaces[0]
    return GatewayHead(
        instance_name=instance_name,
        external_ip=primary_interface.access_configs[0].nat_i_p,
        internal_ip=primary_interface.network_i_p,
    )

def delete_instance(self, instance_name: str):
    """Terminate the instance ``instance_name`` via ``_terminate_instance``.

    Uses this compute object's instances client and GCP configuration.
    """
    termination_kwargs = dict(
        client=self.instances_client,
        gcp_config=self.gcp_config,
        instance_name=instance_name,
    )
    _terminate_instance(**termination_kwargs)


def _get_instance_status(
instances_client: compute_v1.InstancesClient,
Expand Down Expand Up @@ -740,7 +771,7 @@ def _create_firewall_rules(
network: str = "global/networks/default",
):
"""
Creates a simple firewall rule allowing for incoming HTTP and HTTPS access from the entire Internet.
Creates a simple firewall rule allowing for incoming SSH access from the entire Internet.
Args:
project_id: project ID or project number of the Cloud project you want to use.
Expand Down
117 changes: 117 additions & 0 deletions cli/dstack/_internal/backend/gcp/gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Dict, List

import google.api_core.exceptions
from google.cloud import compute_v1

import dstack._internal.backend.gcp.utils as gcp_utils

DSTACK_GATEWAY_TAG = "dstack-gateway"


def create_gateway_instance(
    instances_client: compute_v1.InstancesClient,
    firewalls_client: compute_v1.FirewallsClient,
    project_id: str,
    network: str,
    subnet: str,
    zone: str,
    instance_name: str,
    service_account: str,
    labels: Dict[str, str],
    ssh_key_pub: str,
    machine_type: str = "e2-micro",
) -> compute_v1.Instance:
    """Create a gateway VM (and its firewall rule) and return the created instance.

    Args:
        instances_client: client used to insert and then fetch the instance.
        firewalls_client: client used to create the gateway firewall rule.
        project_id: project the request is issued against.
        network: network resource used for the firewall rule and the interface.
        subnet: subnetwork resource for the network interface.
        zone: zone used for the request, machine type and disk type paths.
        instance_name: name of the new instance.
        service_account: service account email attached to the instance.
        labels: labels applied to the instance.
        ssh_key_pub: public key placed in the ``ssh-keys`` metadata for ``ubuntu``.
        machine_type: machine type name within ``zone``.

    Returns:
        The instance as returned by ``instances_client.get`` after the insert
        operation has been waited for.
    """
    # Best-effort: a Conflict from rule creation is ignored (presumably the
    # rule already exists from an earlier gateway on the same network).
    try:
        create_gateway_firewall_rules(
            firewalls_client=firewalls_client,
            project_id=project_id,
            network=network,
        )
    except google.api_core.exceptions.Conflict:
        pass

    # External access configuration giving the interface a NAT address.
    external_access = compute_v1.AccessConfig()
    external_access.type_ = compute_v1.AccessConfig.Type.ONE_TO_ONE_NAT.name
    external_access.name = "External NAT"
    external_access.network_tier = external_access.NetworkTier.PREMIUM.name

    nic = compute_v1.NetworkInterface()
    # NOTE(review): the network resource is assigned to `name`; confirm that
    # `network` is not the intended field.
    nic.name = network
    nic.subnetwork = subnet
    nic.access_configs = [external_access]

    gateway_instance = compute_v1.Instance()
    gateway_instance.name = instance_name
    gateway_instance.machine_type = f"zones/{zone}/machineTypes/{machine_type}"
    gateway_instance.network_interfaces = [nic]
    gateway_instance.disks = gateway_disks(zone)
    gateway_instance.metadata = compute_v1.Metadata(
        items=[
            compute_v1.Items(key="ssh-keys", value=f"ubuntu:{ssh_key_pub}"),
            compute_v1.Items(key="user-data", value=gateway_user_data_script()),
        ]
    )
    gateway_instance.labels = labels
    # The tag is what the gateway firewall rule targets.
    gateway_instance.tags = compute_v1.Tags(items=[DSTACK_GATEWAY_TAG])
    gateway_instance.service_accounts = [
        compute_v1.ServiceAccount(
            email=service_account,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
    ]

    insert_request = compute_v1.InsertInstanceRequest()
    insert_request.project = project_id
    insert_request.zone = zone
    insert_request.instance_resource = gateway_instance
    insert_operation = instances_client.insert(request=insert_request)
    gcp_utils.wait_for_extended_operation(insert_operation, "instance creation")

    return instances_client.get(project=project_id, zone=zone, instance=instance_name)


def create_gateway_firewall_rules(
    firewalls_client: compute_v1.FirewallsClient,
    project_id: str,
    network: str,
):
    """Insert an ingress firewall rule for gateway instances and wait for it.

    The rule allows TCP ports 22, 80 and 443 from ``0.0.0.0/0`` on ``network``
    and targets instances tagged with ``DSTACK_GATEWAY_TAG``. Its name is
    derived from ``network`` with ``/`` replaced by ``-``.
    """
    tcp_ports = compute_v1.Allowed()
    tcp_ports.I_p_protocol = "tcp"
    tcp_ports.ports = ["22", "80", "443"]

    rule = compute_v1.Firewall()
    rule.name = f"dstack-gateway-in-{network.replace('/', '-')}"
    rule.description = "Allowing TCP traffic on ports 22, 80, and 443 from Internet."
    rule.direction = "INGRESS"
    rule.network = network
    rule.source_ranges = ["0.0.0.0/0"]
    rule.target_tags = [DSTACK_GATEWAY_TAG]
    rule.allowed = [tcp_ports]

    insert_operation = firewalls_client.insert(project=project_id, firewall_resource=rule)
    gcp_utils.wait_for_extended_operation(insert_operation, "firewall rule creation")


def gateway_disks(zone: str) -> List[compute_v1.AttachedDisk]:
    """Return the single boot disk specification used for a gateway instance.

    The disk is a 10 GB ``pd-balanced`` disk in ``zone``, initialised from a
    pinned Ubuntu 22.04 image, marked as the boot disk and set to be deleted
    automatically.
    """
    disk_params = compute_v1.AttachedDiskInitializeParams()
    disk_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
    disk_params.disk_size_gb = 10
    disk_params.source_image = (
        "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"
    )

    boot_disk = compute_v1.AttachedDisk()
    boot_disk.boot = True
    boot_disk.auto_delete = True
    boot_disk.initialize_params = disk_params
    return [boot_disk]


def gateway_user_data_script() -> str:
    """Return the shell script stored in the gateway instance's ``user-data``.

    The script refreshes the package index and installs nginx
    non-interactively.

    Returns:
        The script text, without a trailing newline.
    """
    script_lines = [
        "#!/bin/sh",
        "sudo apt-get update",
        # The variable must be given to sudo as an argument: a prefix assignment
        # (`VAR=x sudo cmd`) only sets it for sudo itself, and sudo's default
        # env_reset policy drops it before apt-get runs.
        "sudo DEBIAN_FRONTEND=noninteractive apt-get install -y -q nginx",
    ]
    return "\n".join(script_lines)
Loading

0 comments on commit 9bb93b6

Please sign in to comment.