-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce gateways for services publication (#596)
* Add service configuration type
* Use pydantic serialization for Job
* Add gateway to Job model, generate new keypair on run_job
* Deploy public key for service configuration
* Establish an SSH tunnel with the gateway, configure nginx
* Cleanup on exit, handle multiple domains
* GCP: create gateway
* Fix BaseHead tests
* Resolve merge bug
* Delete gateway, prettify commands
* Update gateway docs
- Loading branch information
Showing 41 changed files with 1,117 additions and 504 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import subprocess | ||
import time | ||
from typing import List, Optional | ||
|
||
from dstack._internal.backend.base.compute import Compute | ||
from dstack._internal.backend.base.head import ( | ||
delete_head_object, | ||
list_head_objects, | ||
put_head_object, | ||
) | ||
from dstack._internal.backend.base.storage import Storage | ||
from dstack._internal.core.error import DstackError | ||
from dstack._internal.core.gateway import GatewayHead | ||
from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH | ||
from dstack._internal.utils.common import PathLike | ||
from dstack._internal.utils.random_names import generate_name | ||
|
||
|
||
def create_gateway(compute: Compute, storage: Storage, ssh_key_pub: str) -> GatewayHead:
    """Create a gateway instance and record its head object in storage.

    A generated name prefixed with ``dstack-gateway-`` is passed, together
    with ``ssh_key_pub``, to ``compute.create_gateway``. The head it returns
    is stored via ``put_head_object`` and then returned to the caller.
    """
    # TODO: regenerate the name until it is unique; uniqueness is not checked yet.
    name_suffix = generate_name()
    gateway_head = compute.create_gateway(f"dstack-gateway-{name_suffix}", ssh_key_pub)
    put_head_object(storage, gateway_head)
    return gateway_head
|
||
|
||
def list_gateways(storage: Storage) -> List[GatewayHead]:
    """Return every ``GatewayHead`` that ``list_head_objects`` finds in ``storage``."""
    gateway_heads = list_head_objects(storage, GatewayHead)
    return gateway_heads
|
||
|
||
def delete_gateway(compute: Compute, storage: Storage, instance_name: str):
    """Delete the gateway called ``instance_name`` and its stored head.

    Every stored gateway head whose ``instance_name`` matches triggers a
    ``compute.delete_instance`` call followed by removal of that head from
    ``storage``. Nothing happens when no head matches.
    """
    for gateway_head in list_gateways(storage):
        if gateway_head.instance_name == instance_name:
            compute.delete_instance(instance_name)
            delete_head_object(storage, gateway_head)
|
||
|
||
def ssh_copy_id(
    hostname: str,
    public_key: bytes,
    user: str = "ubuntu",
    id_rsa: Optional[PathLike] = HUB_PRIVATE_KEY_PATH,
):
    """Append ``public_key`` to ``~/.ssh/authorized_keys`` of ``user`` on ``hostname``.

    The key is decoded to text and appended by running an ``echo ... >>``
    command on the remote host through ``exec_ssh_command``.

    Args:
        hostname: Host to connect to.
        public_key: Public key to authorize, as bytes.
        user: Remote user whose ``authorized_keys`` file is extended.
        id_rsa: Identity file handed to ``ssh -i``; ``None`` omits the option.

    Raises:
        SSHCommandError: Propagated from ``exec_ssh_command`` when ``ssh``
            exits with a non-zero status.
    """
    import shlex  # local import: keeps this block independent of the file's import list

    # Quote the key for the remote shell: a key (or its comment field) containing
    # a single quote would otherwise terminate the literal and be run as shell code.
    quoted_key = shlex.quote(public_key.decode())
    command = f"echo {quoted_key} >> ~/.ssh/authorized_keys"
    exec_ssh_command(hostname, command, user=user, id_rsa=id_rsa)
|
||
|
||
def exec_ssh_command(hostname: str, command: str, user: str, id_rsa: Optional[PathLike]) -> bytes:
    """Run ``command`` on ``user@hostname`` through the ``ssh`` client.

    Args:
        hostname: Host to connect to.
        command: Command string passed to ``ssh`` as the remote command.
        user: Remote user name.
        id_rsa: Identity file passed with ``-i``; the option is left out when ``None``.

    Returns:
        The captured standard output of the ``ssh`` process, as bytes.

    Raises:
        SSHCommandError: If ``ssh`` exits with a non-zero status; the message
            is the decoded standard error and ``cmd`` is the argument list.
    """
    identity_args = [] if id_rsa is None else ["-i", id_rsa]
    ssh_args = [
        "ssh",
        *identity_args,
        "-o",
        "StrictHostKeyChecking=accept-new",
        f"{user}@{hostname}",
        command,
    ]
    completed = subprocess.run(ssh_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if completed.returncode != 0:
        raise SSHCommandError(ssh_args, completed.stderr.decode())
    return completed.stdout
|
||
|
||
class SSHCommandError(DstackError):
    """Error raised by ``exec_ssh_command`` when ``ssh`` exits with a non-zero status.

    The ``cmd`` attribute keeps the argument list that was executed.
    """

    def __init__(self, cmd: List[str], message: str):
        self.cmd = cmd
        super().__init__(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import List, Type | ||
|
||
from dstack._internal.backend.base.storage import Storage | ||
from dstack._internal.core.head import BaseHead, T | ||
|
||
|
||
def put_head_object(storage: Storage, head: BaseHead) -> str:
    """Store ``head`` as an empty object keyed by ``head.encode()`` and return that key."""
    object_key = head.encode()
    # All information lives in the key itself; the object content is left empty.
    storage.put_object(object_key, content="")
    return object_key
|
||
|
||
def list_head_objects(storage: Storage, cls: Type[T]) -> List[T]:
    """Decode every object key under ``cls.prefix()`` in ``storage`` into a ``cls`` head."""
    heads: List[T] = []
    for object_key in storage.list_objects(cls.prefix()):
        heads.append(cls.decode(object_key))
    return heads
|
||
|
||
def delete_head_object(storage, head: BaseHead):
    """Delete from ``storage`` the object whose key is ``head.encode()``."""
    object_key = head.encode()
    storage.delete_object(object_key)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import Dict, List | ||
|
||
import google.api_core.exceptions | ||
from google.cloud import compute_v1 | ||
|
||
import dstack._internal.backend.gcp.utils as gcp_utils | ||
|
||
# Network tag set on gateway instances and used as the target tag of the gateway firewall rule.
DSTACK_GATEWAY_TAG = "dstack-gateway" | ||
|
||
|
||
def create_gateway_instance(
    instances_client: compute_v1.InstancesClient,
    firewalls_client: compute_v1.FirewallsClient,
    project_id: str,
    network: str,
    subnet: str,
    zone: str,
    instance_name: str,
    service_account: str,
    labels: Dict[str, str],
    ssh_key_pub: str,
    machine_type: str = "e2-micro",
) -> compute_v1.Instance:
    """Create a gateway VM in ``zone`` and return the instance fetched after creation.

    The gateway firewall rule is created first. The instance gets a network
    interface with an "External NAT" access config, the disks from
    ``gateway_disks``, ``ssh_key_pub`` registered for the ``ubuntu`` user, the
    user-data script from ``gateway_user_data_script``, the given ``labels``,
    the ``DSTACK_GATEWAY_TAG`` network tag and ``service_account`` with the
    cloud-platform scope. The call waits for the insert operation before
    fetching the instance.
    """
    try:
        create_gateway_firewall_rules(
            firewalls_client=firewalls_client,
            project_id=project_id,
            network=network,
        )
    except google.api_core.exceptions.Conflict:
        # Deliberately ignored; presumably the rule already exists for this network.
        pass

    nat_config = compute_v1.AccessConfig()
    nat_config.name = "External NAT"
    nat_config.type_ = compute_v1.AccessConfig.Type.ONE_TO_ONE_NAT.name
    nat_config.network_tier = nat_config.NetworkTier.PREMIUM.name

    # Each message is fully populated before it is attached to its parent.
    nic = compute_v1.NetworkInterface()
    nic.name = network
    nic.subnetwork = subnet
    nic.access_configs = [nat_config]

    gateway = compute_v1.Instance()
    gateway.name = instance_name
    gateway.machine_type = f"zones/{zone}/machineTypes/{machine_type}"
    gateway.disks = gateway_disks(zone)
    gateway.network_interfaces = [nic]
    gateway.metadata = compute_v1.Metadata(
        items=[
            compute_v1.Items(key="ssh-keys", value=f"ubuntu:{ssh_key_pub}"),
            compute_v1.Items(key="user-data", value=gateway_user_data_script()),
        ]
    )
    gateway.labels = labels
    # The tag is what makes the gateway firewall rule apply to this instance.
    gateway.tags = compute_v1.Tags(items=[DSTACK_GATEWAY_TAG])
    gateway.service_accounts = [
        compute_v1.ServiceAccount(
            email=service_account,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
    ]

    insert_request = compute_v1.InsertInstanceRequest()
    insert_request.project = project_id
    insert_request.zone = zone
    insert_request.instance_resource = gateway
    insert_operation = instances_client.insert(request=insert_request)
    gcp_utils.wait_for_extended_operation(insert_operation, "instance creation")

    return instances_client.get(project=project_id, zone=zone, instance=instance_name)
|
||
|
||
def create_gateway_firewall_rules(
    firewalls_client: compute_v1.FirewallsClient,
    project_id: str,
    network: str,
):
    """Insert the ingress firewall rule for gateway instances on ``network``.

    The rule allows TCP ports 22, 80 and 443 from ``0.0.0.0/0`` to instances
    tagged with ``DSTACK_GATEWAY_TAG``. Its name is derived from ``network``
    with ``/`` replaced by ``-``. The call waits for the insert operation.
    """
    tcp_ingress = compute_v1.Allowed()
    tcp_ingress.I_p_protocol = "tcp"
    tcp_ingress.ports = ["22", "80", "443"]

    network_slug = network.replace("/", "-")
    rule = compute_v1.Firewall()
    rule.name = f"dstack-gateway-in-{network_slug}"
    rule.network = network
    rule.direction = "INGRESS"
    rule.description = "Allowing TCP traffic on ports 22, 80, and 443 from Internet."
    rule.source_ranges = ["0.0.0.0/0"]
    rule.target_tags = [DSTACK_GATEWAY_TAG]
    rule.allowed = [tcp_ingress]

    insert_operation = firewalls_client.insert(project=project_id, firewall_resource=rule)
    gcp_utils.wait_for_extended_operation(insert_operation, "firewall rule creation")
|
||
|
||
def gateway_disks(zone: str) -> List[compute_v1.AttachedDisk]:
    """Return the single boot disk definition used for a gateway instance in ``zone``.

    The disk is a 10 GB ``pd-balanced`` disk initialized from a pinned Ubuntu
    22.04 image, marked as the boot disk and set to auto-delete.
    """
    params = compute_v1.AttachedDiskInitializeParams()
    params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
    params.disk_size_gb = 10
    params.source_image = (
        "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"
    )

    boot_disk = compute_v1.AttachedDisk()
    boot_disk.boot = True
    boot_disk.auto_delete = True
    boot_disk.initialize_params = params
    return [boot_disk]
|
||
|
||
def gateway_user_data_script() -> str:
    """Return the shell script passed as ``user-data`` to a new gateway instance.

    The script refreshes the apt package index and installs nginx
    non-interactively.
    """
    # DEBIAN_FRONTEND is placed after `sudo` so that it is set for apt-get itself;
    # a variable set in front of `sudo` is dropped when sudoers has `env_reset`
    # enabled (the default) and would never reach apt-get.
    return """#!/bin/sh
sudo apt-get update
sudo DEBIAN_FRONTEND=noninteractive apt-get install -y -q nginx"""
Oops, something went wrong.