diff --git a/cli/dstack/_internal/backend/aws/gateway.py b/cli/dstack/_internal/backend/aws/gateway.py index 0d8f195e6..23a397887 100644 --- a/cli/dstack/_internal/backend/aws/gateway.py +++ b/cli/dstack/_internal/backend/aws/gateway.py @@ -44,7 +44,7 @@ def create_gateway_instance( }, } ], - ImageId="ami-0cffefff2d52e0a23", # Ubuntu 22.04 LTS + ImageId=gateway_image_id(ec2_client), InstanceType=machine_type, MinCount=1, MaxCount=1, @@ -122,6 +122,23 @@ def gateway_security_group_id( return security_group_id +def gateway_image_id(ec2_client: BaseClient) -> str: + response = ec2_client.describe_images( + Filters=[ + { + "Name": "name", + "Values": ["ubuntu/images/hvm-ssd/ubuntu-jammy-22.04-amd64-server-*"], + }, + { + "Name": "owner-alias", + "Values": ["amazon"], + }, + ], + ) + image = sorted(response["Images"], key=lambda i: i["CreationDate"], reverse=True)[0] + return image["ImageId"] + + def wait_till_running( ec2_client: BaseClient, instance: dict, delay: int = 5, attempts: int = 30 ) -> dict: diff --git a/cli/dstack/_internal/backend/azure/compute.py b/cli/dstack/_internal/backend/azure/compute.py index ee6a7386f..dda6f34c2 100644 --- a/cli/dstack/_internal/backend/azure/compute.py +++ b/cli/dstack/_internal/backend/azure/compute.py @@ -164,8 +164,14 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead: subscription_id=self.azure_config.subscription_id, location=self.azure_config.location, resource_group=self.azure_config.resource_group, - network=self.azure_config.network, - subnet=self.azure_config.subnet, + network=azure_utils.get_default_network_name( + storage_account=self.azure_config.storage_account, + location=self.azure_config.location, + ), + subnet=azure_utils.get_default_subnet_name( + storage_account=self.azure_config.storage_account, + location=self.azure_config.location, + ), instance_name=instance_name, ssh_key_pub=ssh_key_pub, ) @@ -176,14 +182,10 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str) -> 
GatewayHead: "resource_name" ], ) - public_ip = gateway.get_public_ip( - network_client=self._network_client, - resource_group=self.azure_config.resource_group, - public_ip=interface.ip_configurations[0].public_ip_address.name, - ) + public_ip = interface.ip_configurations[0].public_ip_address.ip_address return GatewayHead( instance_name=instance_name, - external_ip=public_ip.ip_address, + external_ip=public_ip, internal_ip=interface.ip_configurations[0].private_ip_address, ) diff --git a/cli/dstack/_internal/backend/azure/gateway.py b/cli/dstack/_internal/backend/azure/gateway.py index 44265bd47..bd7ec2e99 100644 --- a/cli/dstack/_internal/backend/azure/gateway.py +++ b/cli/dstack/_internal/backend/azure/gateway.py @@ -177,13 +177,9 @@ def gateway_network_security_group( def get_network_interface( network_client: NetworkManagementClient, resource_group: str, interface: str ) -> NetworkInterface: - return network_client.network_interfaces.get(resource_group, interface) - - -def get_public_ip( - network_client: NetworkManagementClient, resource_group: str, public_ip: str -) -> PublicIPAddress: - return network_client.public_ip_addresses.get(resource_group, public_ip) + return network_client.network_interfaces.get( + resource_group, interface, expand="IPConfigurations/PublicIPAddress" + ) def gateway_user_data_script() -> str: diff --git a/cli/dstack/_internal/cli/commands/gateway/__init__.py b/cli/dstack/_internal/cli/commands/gateway/__init__.py index 337d0a7c1..8b3f6174c 100644 --- a/cli/dstack/_internal/cli/commands/gateway/__init__.py +++ b/cli/dstack/_internal/cli/commands/gateway/__init__.py @@ -1,5 +1,5 @@ from argparse import Namespace -from typing import List +from typing import Dict, List from rich.prompt import Confirm from rich.table import Table @@ -33,6 +33,7 @@ def register(self): "create", help="Create a gateway", formatter_class=RichHelpFormatter ) add_project_argument(create_parser) + create_parser.add_argument("--backend", choices=["aws", "gcp", 
"azure"], required=True) create_parser.set_defaults(sub_func=self.create_gateway) delete_gateway_parser = subparsers.add_parser( @@ -54,31 +55,38 @@ def _command(self, args: Namespace): def create_gateway(self, hub_client: HubClient, args: Namespace): print("Creating gateway, it may take some time...") - head = hub_client.create_gateway() - print_gateways_table([head]) + head = hub_client.create_gateway(backend=args.backend) + print_gateways_table({args.backend: [head]}) def list_gateways(self, hub_client: HubClient, args: Namespace): - heads = hub_client.list_gateways() - print_gateways_table(heads) + backends = hub_client.list_gateways() + print_gateways_table(backends) def delete_gateway(self, hub_client: HubClient, args: Namespace): - heads = hub_client.list_gateways() - if args.instance_name not in [head.instance_name for head in heads]: + backends = hub_client.list_gateways() + for backend, heads in backends.items(): + for head in heads: + if args.instance_name != head.instance_name: + continue + if args.yes or Confirm.ask(f"[red]Delete the gateway '{args.instance_name}'?[/]"): + hub_client.delete_gateway(args.instance_name, backend=backend) + console.print("Gateway is deleted") + return + else: exit(f"No such gateway '{args.instance_name}'") - if args.yes or Confirm.ask(f"[red]Delete the gateway '{args.instance_name}'?[/]"): - hub_client.delete_gateway(args.instance_name) - console.print("Gateway is deleted") - exit(0) -def print_gateways_table(heads: List[GatewayHead]): +def print_gateways_table(backends: Dict[str, List[GatewayHead]]): table = Table(box=None) + table.add_column("BACKEND") table.add_column("NAME") table.add_column("ADDRESS") - for head in heads: - table.add_row( - head.instance_name, - head.external_ip, - ) + for backend, heads in backends.items(): + for i, head in enumerate(heads): + table.add_row( + backend if i == 0 else "", + head.instance_name, + head.external_ip, + ) console.print(table) console.print() diff --git 
a/cli/dstack/_internal/configurators/__init__.py b/cli/dstack/_internal/configurators/__init__.py index 1d09f1d6b..76dfce76b 100644 --- a/cli/dstack/_internal/configurators/__init__.py +++ b/cli/dstack/_internal/configurators/__init__.py @@ -107,7 +107,7 @@ def apply_args(self, args: argparse.Namespace): self.profile.resources.gpu = ProfileGPU.parse_obj(gpu) if args.max_price is not None: - self.profile.resources.max_price = args.max_price + self.profile.max_price = args.max_price if args.spot_policy is not None: self.profile.spot_policy = args.spot_policy @@ -284,7 +284,7 @@ def requirements(self) -> job.Requirements: memory_mib=self.profile.resources.memory, gpus=None, shm_size_mib=self.profile.resources.shm_size, - max_price=self.profile.resources.max_price, + max_price=self.profile.max_price, ) if self.profile.resources.gpu: r.gpus = job.GpusRequirements( diff --git a/cli/dstack/_internal/core/gateway.py b/cli/dstack/_internal/core/gateway.py index c2edde599..8f0800eef 100644 --- a/cli/dstack/_internal/core/gateway.py +++ b/cli/dstack/_internal/core/gateway.py @@ -1,4 +1,5 @@ import time +from typing import Optional from pydantic import Field @@ -10,6 +11,7 @@ class GatewayHead(BaseHead): external_ip: str internal_ip: str created_at: int = Field(default_factory=lambda: int(time.time() * 1000)) + wildcard_domain: Optional[str] @classmethod def prefix(cls) -> str: diff --git a/cli/dstack/_internal/core/profile.py b/cli/dstack/_internal/core/profile.py index e7a48237d..ecd26e67a 100644 --- a/cli/dstack/_internal/core/profile.py +++ b/cli/dstack/_internal/core/profile.py @@ -90,9 +90,6 @@ class ProfileResources(ForbidExtra): ), ] cpu: int = DEFAULT_CPU - max_price: Annotated[ - Optional[confloat(gt=0.0)], Field(description="The maximum price per hour, $") - ] _validate_mem = validator("memory", "shm_size", pre=True, allow_reuse=True)(parse_memory) @validator("gpu", pre=True) @@ -134,6 +131,9 @@ class Profile(ForbidExtra): description="The maximum duration of a run 
(e.g., 2h, 1d, etc). After it elapses, the run is forced to stop" ), ] + max_price: Annotated[ + Optional[confloat(gt=0.0)], Field(description="The maximum price per hour, $") + ] default: bool = False _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration diff --git a/cli/dstack/_internal/hub/routers/gateways.py b/cli/dstack/_internal/hub/routers/gateways.py index 429668925..e522adb71 100644 --- a/cli/dstack/_internal/hub/routers/gateways.py +++ b/cli/dstack/_internal/hub/routers/gateways.py @@ -1,11 +1,12 @@ -from typing import List +import asyncio +from typing import Dict, List from fastapi import APIRouter, Body, Depends, HTTPException, status from dstack._internal.core.gateway import GatewayHead from dstack._internal.hub.routers.util import call_backend, error_detail, get_backends, get_project +from dstack._internal.hub.schemas import GatewayDelete from dstack._internal.hub.security.permissions import ProjectAdmin, ProjectMember -from dstack._internal.hub.utils.common import run_async from dstack._internal.hub.utils.ssh import get_hub_ssh_public_key router = APIRouter( @@ -14,10 +15,12 @@ @router.post("/{project_name}/gateways/create", dependencies=[Depends(ProjectAdmin())]) -async def gateways_create(project_name: str) -> GatewayHead: +async def gateways_create(project_name: str, backend_name: str = Body()) -> GatewayHead: project = await get_project(project_name=project_name) backends = await get_backends(project) for _, backend in backends: + if backend.name != backend_name: + continue try: return await call_backend(backend.create_gateway, get_hub_ssh_public_key()) except NotImplementedError: @@ -32,18 +35,20 @@ async def gateways_create(project_name: str) -> GatewayHead: @router.get("/{project_name}/gateways") -async def gateways_list(project_name: str) -> List[GatewayHead]: +async def gateways_list(project_name: str) -> Dict[str, List[GatewayHead]]: project = await get_project(project_name=project_name) backends 
= await get_backends(project) - gateways = [] - for _, backend in backends: - gateways += await call_backend(backend.list_gateways) - return gateways + tasks = [call_backend(backend.list_gateways) for _, backend in backends] + return { + backend.name: gateways + for (_, backend), gateways in zip(backends, await asyncio.gather(*tasks)) + } @router.post("/{project_name}/gateways/delete", dependencies=[Depends(ProjectAdmin())]) -async def gateways_delete(project_name: str, instance_name: str = Body()): +async def gateways_delete(project_name: str, body: GatewayDelete = Body()): project = await get_project(project_name=project_name) backends = await get_backends(project) for _, backend in backends: - await call_backend(backend.delete_gateway, instance_name) + if backend.name == body.backend: + await call_backend(backend.delete_gateway, body.instance_name) diff --git a/cli/dstack/_internal/hub/schemas/__init__.py b/cli/dstack/_internal/hub/schemas/__init__.py index e6f88e47d..b1c6ed23f 100644 --- a/cli/dstack/_internal/hub/schemas/__init__.py +++ b/cli/dstack/_internal/hub/schemas/__init__.py @@ -484,3 +484,8 @@ class DeleteUsers(BaseModel): class FileObject(BaseModel): object_key: str + + +class GatewayDelete(BaseModel): + instance_name: str + backend: str diff --git a/cli/dstack/api/hub/_api_client.py b/cli/dstack/api/hub/_api_client.py index dcb171d37..6fcb4bc92 100644 --- a/cli/dstack/api/hub/_api_client.py +++ b/cli/dstack/api/hub/_api_client.py @@ -19,13 +19,13 @@ from dstack._internal.core.log_event import LogEvent from dstack._internal.core.plan import RunPlan from dstack._internal.core.repo import RemoteRepoCredentials, Repo, RepoHead, RepoSpec -from dstack._internal.core.run import RunHead from dstack._internal.core.secret import Secret from dstack._internal.core.tag import TagHead from dstack._internal.hub.schemas import ( AddTagRun, ArtifactsList, BackendInfo, + GatewayDelete, JobHeadList, JobsGet, JobsList, @@ -721,13 +721,14 @@ def 
delete_configuration_cache(self, configuration_path: str): return resp.raise_for_status() - def create_gateway(self) -> GatewayHead: + def create_gateway(self, backend: str) -> GatewayHead: url = _project_url(url=self.url, project=self.project, additional_path="/gateways/create") resp = _make_hub_request( requests.post, host=self.url, url=url, headers=self._headers(), + data=json.dumps(backend), ) if resp.ok: return GatewayHead.parse_obj(resp.json()) @@ -737,7 +738,7 @@ def create_gateway(self) -> GatewayHead: raise HubClientError(body["detail"]["msg"]) resp.raise_for_status() - def list_gateways(self) -> List[GatewayHead]: + def list_gateways(self) -> Dict[str, List[GatewayHead]]: url = _project_url(url=self.url, project=self.project, additional_path="/gateways") resp = _make_hub_request( requests.get, @@ -747,16 +748,16 @@ def list_gateways(self) -> List[GatewayHead]: ) if not resp.ok: resp.raise_for_status() - return parse_obj_as(List[GatewayHead], resp.json()) + return parse_obj_as(Dict[str, List[GatewayHead]], resp.json()) - def delete_gateway(self, instance_name: str): + def delete_gateway(self, instance_name: str, backend: str): url = _project_url(url=self.url, project=self.project, additional_path="/gateways/delete") resp = _make_hub_request( requests.post, host=self.url, url=url, headers=self._headers(), - data=json.dumps(instance_name), + data=GatewayDelete(instance_name=instance_name, backend=backend).json(), ) resp.raise_for_status() diff --git a/cli/dstack/api/hub/_client.py b/cli/dstack/api/hub/_client.py index cf7621160..d0b89e542 100644 --- a/cli/dstack/api/hub/_client.py +++ b/cli/dstack/api/hub/_client.py @@ -1,11 +1,9 @@ import copy -import sys import tempfile import time -import urllib.parse from datetime import datetime from pathlib import Path -from typing import Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Optional, Tuple import dstack._internal.configurators as configurators from dstack._internal.api.repos 
import get_local_repo_credentials @@ -17,10 +15,9 @@ from dstack._internal.core.plan import RunPlan from dstack._internal.core.repo import RemoteRepoCredentials, Repo, RepoHead from dstack._internal.core.repo.remote import RemoteRepo -from dstack._internal.core.run import RunHead from dstack._internal.core.secret import Secret from dstack._internal.core.tag import TagHead -from dstack._internal.hub.schemas import BackendInfo, ProjectInfo, RunInfo +from dstack._internal.hub.schemas import BackendInfo, RunInfo from dstack.api.hub._api_client import HubAPIClient from dstack.api.hub._config import HubClientConfig from dstack.api.hub._storage import HUBStorage @@ -281,11 +278,11 @@ def run_configuration( self.update_repo_last_run_at(last_run_at=int(round(time.time() * 1000))) return run_name, jobs - def create_gateway(self) -> GatewayHead: - return self._api_client.create_gateway() + def create_gateway(self, backend: str) -> GatewayHead: + return self._api_client.create_gateway(backend=backend) - def list_gateways(self) -> List[GatewayHead]: + def list_gateways(self) -> Dict[str, List[GatewayHead]]: return self._api_client.list_gateways() - def delete_gateway(self, instance_name: str): - self._api_client.delete_gateway(instance_name) + def delete_gateway(self, instance_name: str, backend: str): + self._api_client.delete_gateway(instance_name, backend=backend) diff --git a/docs/docs/reference/cli/gateway.md b/docs/docs/reference/cli/gateway.md index c4b72ab7e..a782da6a2 100644 --- a/docs/docs/reference/cli/gateway.md +++ b/docs/docs/reference/cli/gateway.md @@ -7,14 +7,14 @@ Gateway makes running jobs (`type: service`) accessible from the public internet ## dstack gateway list -The `dstack gateway list` command displays the names and addresses of the gateways configured in the selected project. +The `dstack gateway list` command displays the names and addresses of the gateways configured in the project. ### Usage
```shell -$ dstack gateway list --project gcp +$ dstack gateway list ```
@@ -34,6 +34,7 @@ Usage: dstack gateway create [-h] [--project PROJECT] Optional Arguments: -h, --help show this help message and exit --project PROJECT The name of the project + --backend {aws,gcp,azure} ``` @@ -43,6 +44,7 @@ Optional Arguments: The following arguments are optional: - `--project PROJECT` - (Optional) The name of the project to execute the command for +- `--backend {aws,gcp,azure}` - (Required) The cloud provider to use for the gateway ## dstack gateway delete diff --git a/docs/docs/reference/profiles.yml.md b/docs/docs/reference/profiles.yml.md index b17e0d552..eb3721516 100644 --- a/docs/docs/reference/profiles.yml.md +++ b/docs/docs/reference/profiles.yml.md @@ -17,12 +17,12 @@ Below is a full reference of all available properties. - `memory` - (Optional) The minimum size of GPU memory (e.g., `"16GB"`) - `shm_size` - (Optional) The size of shared memory (e.g., `"8GB"`). If you are using parallel communicating processes (e.g., dataloaders in PyTorch), you may need to configure this. - - `max_price` - (Optional) Maximum price per hour, $ - `spot_policy` - (Optional) The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. `spot` provisions a spot instance. `on-demand` provisions a on-demand instance. `auto` first tries to provision a spot instance and then tries on-demand if spot is not available. Defaults to `on-demand` for dev environments and to `auto` for tasks. - `retry_policy` - (Optional) The policy for re-submitting the run. - `retry` - (Optional) Whether to retry the run on failure or not. Default to `false` - `limit` - (Optional) The maximum period of retrying the run, e.g., `4h` or `1d`. Defaults to `1h` if `retry` is `true`. - `max_duration` - (Optional) The maximum duration of a run (e.g., `2h`, `1d`, etc). After it elapses, the run is forced to stop. Protects from running idle instances. Defaults to `6h` for dev environments and to `72h` for tasks.
Use `max_duration: off` to disable maximum run duration. + - `max_price` - (Optional) Maximum price per hour, $ [//]: # (TODO: Add examples)