Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gateway command bugs #645

Merged
merged 4 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion cli/dstack/_internal/backend/aws/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create_gateway_instance(
},
}
],
ImageId="ami-0cffefff2d52e0a23", # Ubuntu 22.04 LTS
ImageId=gateway_image_id(ec2_client),
InstanceType=machine_type,
MinCount=1,
MaxCount=1,
Expand Down Expand Up @@ -122,6 +122,23 @@ def gateway_security_group_id(
return security_group_id


def gateway_image_id(ec2_client: "BaseClient") -> str:
    """Return the ID of the most recently created Ubuntu 22.04 gateway image.

    Calls ``describe_images`` on the given EC2 client, filtering by the
    Ubuntu jammy 22.04 amd64 server name pattern and the ``amazon`` owner
    alias, then picks the image with the greatest ``CreationDate``.

    Args:
        ec2_client: EC2 client exposing ``describe_images``; the lookup runs
            against whatever region the client is configured for.

    Returns:
        The ``ImageId`` of the newest matching image.

    Raises:
        IndexError: If no image matches the filters. The exception type is
            the one an empty result has always produced; it now carries a
            message that explains what was missing.
    """
    response = ec2_client.describe_images(
        Filters=[
            {
                "Name": "name",
                "Values": ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"],
            },
            {
                "Name": "owner-alias",
                "Values": ["amazon"],
            },
        ],
    )
    images = response["Images"]
    if not images:
        # Fail with an actionable message instead of "list index out of range".
        raise IndexError("No Ubuntu 22.04 image found for the gateway instance")
    # CreationDate values are compared as returned by the API (same ordering
    # as before); max() avoids sorting the whole list to take one element.
    newest = max(images, key=lambda image: image["CreationDate"])
    return newest["ImageId"]


def wait_till_running(
ec2_client: BaseClient, instance: dict, delay: int = 5, attempts: int = 30
) -> dict:
Expand Down
18 changes: 10 additions & 8 deletions cli/dstack/_internal/backend/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,14 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
subscription_id=self.azure_config.subscription_id,
location=self.azure_config.location,
resource_group=self.azure_config.resource_group,
network=self.azure_config.network,
subnet=self.azure_config.subnet,
network=azure_utils.get_default_network_name(
storage_account=self.azure_config.storage_account,
location=self.azure_config.location,
),
subnet=azure_utils.get_default_subnet_name(
storage_account=self.azure_config.storage_account,
location=self.azure_config.location,
),
instance_name=instance_name,
ssh_key_pub=ssh_key_pub,
)
Expand All @@ -176,14 +182,10 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
"resource_name"
],
)
public_ip = gateway.get_public_ip(
network_client=self._network_client,
resource_group=self.azure_config.resource_group,
public_ip=interface.ip_configurations[0].public_ip_address.name,
)
public_ip = interface.ip_configurations[0].public_ip_address.ip_address
return GatewayHead(
instance_name=instance_name,
external_ip=public_ip.ip_address,
external_ip=public_ip,
internal_ip=interface.ip_configurations[0].private_ip_address,
)

Expand Down
10 changes: 3 additions & 7 deletions cli/dstack/_internal/backend/azure/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,9 @@ def gateway_network_security_group(
def get_network_interface(
network_client: NetworkManagementClient, resource_group: str, interface: str
) -> NetworkInterface:
return network_client.network_interfaces.get(resource_group, interface)


def get_public_ip(
network_client: NetworkManagementClient, resource_group: str, public_ip: str
) -> PublicIPAddress:
return network_client.public_ip_addresses.get(resource_group, public_ip)
return network_client.network_interfaces.get(
resource_group, interface, expand="IPConfigurations/PublicIPAddress"
)


def gateway_user_data_script() -> str:
Expand Down
42 changes: 25 additions & 17 deletions cli/dstack/_internal/cli/commands/gateway/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from argparse import Namespace
from typing import List
from typing import Dict, List

from rich.prompt import Confirm
from rich.table import Table
Expand Down Expand Up @@ -33,6 +33,7 @@ def register(self):
"create", help="Create a gateway", formatter_class=RichHelpFormatter
)
add_project_argument(create_parser)
create_parser.add_argument("--backend", choices=["aws", "gcp", "azure"], required=True)
create_parser.set_defaults(sub_func=self.create_gateway)

delete_gateway_parser = subparsers.add_parser(
Expand All @@ -54,31 +55,38 @@ def _command(self, args: Namespace):

def create_gateway(self, hub_client: HubClient, args: Namespace):
print("Creating gateway, it may take some time...")
head = hub_client.create_gateway()
print_gateways_table([head])
head = hub_client.create_gateway(backend=args.backend)
print_gateways_table({args.backend: [head]})

def list_gateways(self, hub_client: HubClient, args: Namespace):
heads = hub_client.list_gateways()
print_gateways_table(heads)
backends = hub_client.list_gateways()
print_gateways_table(backends)

def delete_gateway(self, hub_client: HubClient, args: Namespace):
heads = hub_client.list_gateways()
if args.instance_name not in [head.instance_name for head in heads]:
backends = hub_client.list_gateways()
for backend, heads in backends.items():
for head in heads:
if args.instance_name != head.instance_name:
continue
if args.yes or Confirm.ask(f"[red]Delete the gateway '{args.instance_name}'?[/]"):
hub_client.delete_gateway(args.instance_name, backend=backend)
console.print("Gateway is deleted")
return
else:
exit(f"No such gateway '{args.instance_name}'")
if args.yes or Confirm.ask(f"[red]Delete the gateway '{args.instance_name}'?[/]"):
hub_client.delete_gateway(args.instance_name)
console.print("Gateway is deleted")
exit(0)


def print_gateways_table(heads: List[GatewayHead]):
def print_gateways_table(backends: Dict[str, List[GatewayHead]]):
table = Table(box=None)
table.add_column("BACKEND")
table.add_column("NAME")
table.add_column("ADDRESS")
for head in heads:
table.add_row(
head.instance_name,
head.external_ip,
)
for backend, heads in backends.items():
for i, head in enumerate(heads):
table.add_row(
backend if i == 0 else "",
head.instance_name,
head.external_ip,
)
console.print(table)
console.print()
4 changes: 2 additions & 2 deletions cli/dstack/_internal/configurators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def apply_args(self, args: argparse.Namespace):
self.profile.resources.gpu = ProfileGPU.parse_obj(gpu)

if args.max_price is not None:
self.profile.resources.max_price = args.max_price
self.profile.max_price = args.max_price

if args.spot_policy is not None:
self.profile.spot_policy = args.spot_policy
Expand Down Expand Up @@ -284,7 +284,7 @@ def requirements(self) -> job.Requirements:
memory_mib=self.profile.resources.memory,
gpus=None,
shm_size_mib=self.profile.resources.shm_size,
max_price=self.profile.resources.max_price,
max_price=self.profile.max_price,
)
if self.profile.resources.gpu:
r.gpus = job.GpusRequirements(
Expand Down
2 changes: 2 additions & 0 deletions cli/dstack/_internal/core/gateway.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing import Optional

from pydantic import Field

Expand All @@ -10,6 +11,7 @@ class GatewayHead(BaseHead):
external_ip: str
internal_ip: str
created_at: int = Field(default_factory=lambda: int(time.time() * 1000))
wildcard_domain: Optional[str]

@classmethod
def prefix(cls) -> str:
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/_internal/core/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ class ProfileResources(ForbidExtra):
),
]
cpu: int = DEFAULT_CPU
max_price: Annotated[
Optional[confloat(gt=0.0)], Field(description="The maximum price per hour, $")
]
_validate_mem = validator("memory", "shm_size", pre=True, allow_reuse=True)(parse_memory)

@validator("gpu", pre=True)
Expand Down Expand Up @@ -134,6 +131,9 @@ class Profile(ForbidExtra):
description="The maximum duration of a run (e.g., 2h, 1d, etc). After it elapses, the run is forced to stop"
),
]
max_price: Annotated[
Optional[confloat(gt=0.0)], Field(description="The maximum price per hour, $")
]
default: bool = False
_validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)(
parse_max_duration
Expand Down
25 changes: 15 additions & 10 deletions cli/dstack/_internal/hub/routers/gateways.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import List
import asyncio
from typing import Dict, List

from fastapi import APIRouter, Body, Depends, HTTPException, status

from dstack._internal.core.gateway import GatewayHead
from dstack._internal.hub.routers.util import call_backend, error_detail, get_backends, get_project
from dstack._internal.hub.schemas import GatewayDelete
from dstack._internal.hub.security.permissions import ProjectAdmin, ProjectMember
from dstack._internal.hub.utils.common import run_async
from dstack._internal.hub.utils.ssh import get_hub_ssh_public_key

router = APIRouter(
Expand All @@ -14,10 +15,12 @@


@router.post("/{project_name}/gateways/create", dependencies=[Depends(ProjectAdmin())])
async def gateways_create(project_name: str) -> GatewayHead:
async def gateways_create(project_name: str, backend_name: str = Body()) -> GatewayHead:
project = await get_project(project_name=project_name)
backends = await get_backends(project)
for _, backend in backends:
if backend.name != backend_name:
continue
try:
return await call_backend(backend.create_gateway, get_hub_ssh_public_key())
except NotImplementedError:
Expand All @@ -32,18 +35,20 @@ async def gateways_create(project_name: str) -> GatewayHead:


@router.get("/{project_name}/gateways")
async def gateways_list(project_name: str) -> List[GatewayHead]:
async def gateways_list(project_name: str) -> Dict[str, List[GatewayHead]]:
project = await get_project(project_name=project_name)
backends = await get_backends(project)
gateways = []
for _, backend in backends:
gateways += await call_backend(backend.list_gateways)
return gateways
tasks = [call_backend(backend.list_gateways) for _, backend in backends]
return {
backend.name: gateways
for (_, backend), gateways in zip(backends, await asyncio.gather(*tasks))
}


@router.post("/{project_name}/gateways/delete", dependencies=[Depends(ProjectAdmin())])
async def gateways_delete(project_name: str, instance_name: str = Body()):
async def gateways_delete(project_name: str, body: GatewayDelete = Body()):
project = await get_project(project_name=project_name)
backends = await get_backends(project)
for _, backend in backends:
await call_backend(backend.delete_gateway, instance_name)
if backend.name == body.backend:
await call_backend(backend.delete_gateway, body.instance_name)
5 changes: 5 additions & 0 deletions cli/dstack/_internal/hub/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,8 @@ class DeleteUsers(BaseModel):

class FileObject(BaseModel):
object_key: str


class GatewayDelete(BaseModel):
    # Request body of the gateway delete endpoint: identifies which gateway
    # to delete and on which backend.

    # Name of the gateway instance to delete.
    instance_name: str
    # Name of the backend that owns the gateway; the hub deletes only on the
    # backend whose name equals this value.
    backend: str
13 changes: 7 additions & 6 deletions cli/dstack/api/hub/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from dstack._internal.core.log_event import LogEvent
from dstack._internal.core.plan import RunPlan
from dstack._internal.core.repo import RemoteRepoCredentials, Repo, RepoHead, RepoSpec
from dstack._internal.core.run import RunHead
from dstack._internal.core.secret import Secret
from dstack._internal.core.tag import TagHead
from dstack._internal.hub.schemas import (
AddTagRun,
ArtifactsList,
BackendInfo,
GatewayDelete,
JobHeadList,
JobsGet,
JobsList,
Expand Down Expand Up @@ -721,13 +721,14 @@ def delete_configuration_cache(self, configuration_path: str):
return
resp.raise_for_status()

def create_gateway(self) -> GatewayHead:
def create_gateway(self, backend: str) -> GatewayHead:
url = _project_url(url=self.url, project=self.project, additional_path="/gateways/create")
resp = _make_hub_request(
requests.post,
host=self.url,
url=url,
headers=self._headers(),
data=json.dumps(backend),
)
if resp.ok:
return GatewayHead.parse_obj(resp.json())
Expand All @@ -737,7 +738,7 @@ def create_gateway(self) -> GatewayHead:
raise HubClientError(body["detail"]["msg"])
resp.raise_for_status()

def list_gateways(self) -> List[GatewayHead]:
def list_gateways(self) -> Dict[str, List[GatewayHead]]:
url = _project_url(url=self.url, project=self.project, additional_path="/gateways")
resp = _make_hub_request(
requests.get,
Expand All @@ -747,16 +748,16 @@ def list_gateways(self) -> List[GatewayHead]:
)
if not resp.ok:
resp.raise_for_status()
return parse_obj_as(List[GatewayHead], resp.json())
return parse_obj_as(Dict[str, List[GatewayHead]], resp.json())

def delete_gateway(self, instance_name: str):
def delete_gateway(self, instance_name: str, backend: str):
url = _project_url(url=self.url, project=self.project, additional_path="/gateways/delete")
resp = _make_hub_request(
requests.post,
host=self.url,
url=url,
headers=self._headers(),
data=json.dumps(instance_name),
data=GatewayDelete(instance_name=instance_name, backend=backend).json(),
)
resp.raise_for_status()

Expand Down
17 changes: 7 additions & 10 deletions cli/dstack/api/hub/_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import copy
import sys
import tempfile
import time
import urllib.parse
from datetime import datetime
from pathlib import Path
from typing import Generator, List, Optional, Tuple
from typing import Dict, Generator, List, Optional, Tuple

import dstack._internal.configurators as configurators
from dstack._internal.api.repos import get_local_repo_credentials
Expand All @@ -17,10 +15,9 @@
from dstack._internal.core.plan import RunPlan
from dstack._internal.core.repo import RemoteRepoCredentials, Repo, RepoHead
from dstack._internal.core.repo.remote import RemoteRepo
from dstack._internal.core.run import RunHead
from dstack._internal.core.secret import Secret
from dstack._internal.core.tag import TagHead
from dstack._internal.hub.schemas import BackendInfo, ProjectInfo, RunInfo
from dstack._internal.hub.schemas import BackendInfo, RunInfo
from dstack.api.hub._api_client import HubAPIClient
from dstack.api.hub._config import HubClientConfig
from dstack.api.hub._storage import HUBStorage
Expand Down Expand Up @@ -281,11 +278,11 @@ def run_configuration(
self.update_repo_last_run_at(last_run_at=int(round(time.time() * 1000)))
return run_name, jobs

def create_gateway(self) -> GatewayHead:
return self._api_client.create_gateway()
def create_gateway(self, backend: str) -> GatewayHead:
return self._api_client.create_gateway(backend=backend)

def list_gateways(self) -> List[GatewayHead]:
def list_gateways(self) -> Dict[str, List[GatewayHead]]:
return self._api_client.list_gateways()

def delete_gateway(self, instance_name: str):
self._api_client.delete_gateway(instance_name)
def delete_gateway(self, instance_name: str, backend: str):
self._api_client.delete_gateway(instance_name, backend=backend)
Loading
Loading