Skip to content

Commit

Permalink
- [Doc] Fixing the gateways broken by the implementation of multiple regions (#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschmidt85 authored Aug 21, 2023
1 parent 6113a11 commit d8d2141
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 37 deletions.
1 change: 1 addition & 0 deletions cli/dstack/_internal/backend/aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def logging(self) -> AWSLogging:
def pricing(self) -> AWSPricing:
return self._pricing

# TODO: The `offer` field must be required
def run_job(
self,
job: Job,
Expand Down
16 changes: 11 additions & 5 deletions cli/dstack/_internal/backend/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
request_id=request_id,
)

def get_instance_type(self, job: Job) -> Optional[InstanceType]:
# TODO: This function is deprecated and will be deleted in 0.11.x
def get_instance_type(self, job: Job, region_name: Optional[str]) -> Optional[InstanceType]:
return runners.get_instance_type(
ec2_client=self._get_ec2_client(),
ec2_client=self._get_ec2_client(region_name),
requirements=job.requirements,
)

Expand Down Expand Up @@ -78,8 +79,10 @@ def cancel_spot_request(self, runner: Runner):
)

def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
# TODO: This must be a configurable field of the gateway
default_region_name = self.backend_config.regions[0]
instance = gateway.create_gateway_instance(
ec2_client=self._get_ec2_client(region=self.backend_config.region_name),
ec2_client=self._get_ec2_client(region=default_region_name),
subnet_id=self.backend_config.subnet_id,
bucket_name=self.backend_config.bucket_name,
instance_name=instance_name,
Expand All @@ -91,14 +94,17 @@ def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
internal_ip=instance["PrivateIpAddress"],
)

# TODO: Must be renamed to `delete_gateway_instance`
def delete_instance(self, instance_name: str):
# TODO: This must be a configurable field of the gateway
default_region_name = self.backend_config.regions[0]
try:
instance_id = gateway.get_instance_id(
ec2_client=self._get_ec2_client(region=self.backend_config.region_name),
ec2_client=self._get_ec2_client(region=default_region_name),
instance_name=instance_name,
)
runners.terminate_instance(
ec2_client=self._get_ec2_client(region=self.backend_config.region_name),
ec2_client=self._get_ec2_client(region=default_region_name),
request_id=instance_id,
)
except IndexError:
Expand Down
14 changes: 9 additions & 5 deletions cli/dstack/_internal/backend/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ def __init__(
self.azure_config.storage_account,
)

def get_instance_type(self, job: Job) -> Optional[InstanceType]:
# TODO: This function is deprecated and will be deleted in 0.11.x
def get_instance_type(self, job: Job, region_name: Optional[str]) -> Optional[InstanceType]:
instance_types = _get_instance_types(
client=self._compute_client, location=self.azure_config.location
client=self._compute_client, location=region_name or self.azure_config.locations[0]
)
return choose_instance_type(instance_types=instance_types, requirements=job.requirements)

Expand Down Expand Up @@ -158,19 +159,22 @@ def cancel_spot_request(self, runner: Runner):
self.terminate_instance(runner)

def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
# TODO: This must be a configurable field of the gateway
default_location = self.azure_config.locations[0]
vm = gateway.create_gateway(
storage_account=self.azure_config.storage_account,
compute_client=self._compute_client,
network_client=self._network_client,
subscription_id=self.azure_config.subscription_id,
location=self.azure_config.location,
location=default_location,
resource_group=self.azure_config.resource_group,
network=azure_utils.get_default_network_name(
storage_account=self.azure_config.storage_account,
location=self.azure_config.location,
location=default_location,
),
subnet=azure_utils.get_default_subnet_name(
storage_account=self.azure_config.storage_account,
location=self.azure_config.location,
location=default_location,
),
instance_name=instance_name,
ssh_key_pub=ssh_key_pub,
Expand Down
2 changes: 0 additions & 2 deletions cli/dstack/_internal/backend/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class AzureConfig(BackendConfig, BaseModel):
storage_account: str
vault_url: str
locations: List[str]
# set dynamically
location: Optional[str]
network: Optional[str]
subnet: Optional[str]
credentials: Optional[Dict] = None
Expand Down
18 changes: 14 additions & 4 deletions cli/dstack/_internal/backend/azure/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.network.models import (
NetworkInterface,
NetworkInterfaceIPConfiguration,
NetworkSecurityGroup,
PublicIPAddress,
SecurityRule,
SecurityRuleAccess,
SecurityRuleDirection,
Expand All @@ -38,6 +36,7 @@


def create_gateway(
storage_account: str,
compute_client: ComputeManagementClient,
network_client: NetworkManagementClient,
subscription_id: str,
Expand Down Expand Up @@ -73,6 +72,7 @@ def create_gateway(
network_profile=NetworkProfile(
network_api_version=NetworkManagementClient.DEFAULT_API_VERSION,
network_interface_configurations=gateway_interface_configurations(
storage_account=storage_account,
network_client=network_client,
subscription_id=subscription_id,
location=location,
Expand Down Expand Up @@ -113,6 +113,7 @@ def gateway_storage_profile() -> StorageProfile:


def gateway_interface_configurations(
storage_account: str,
network_client: NetworkManagementClient,
subscription_id: str,
location: str,
Expand All @@ -123,7 +124,9 @@ def gateway_interface_configurations(
conf = VirtualMachineNetworkInterfaceConfiguration(
name="nic_config",
network_security_group=SubResource(
id=gateway_network_security_group(network_client, location, resource_group)
id=gateway_network_security_group(
storage_account, network_client, location, resource_group
)
),
ip_configurations=[
VirtualMachineNetworkInterfaceIPConfiguration(
Expand All @@ -145,14 +148,21 @@ def gateway_interface_configurations(
return [conf]


def _get_gateway_network_security_group_name(storage_account: str, location: str) -> str:
return f"{storage_account}-{location}-gateway-security-group"


def gateway_network_security_group(
storage_account: str,
network_client: NetworkManagementClient,
location: str,
resource_group: str,
) -> str:
poller = network_client.network_security_groups.begin_create_or_update(
resource_group_name=resource_group,
network_security_group_name="dstack-gateway-network-security-group",
network_security_group_name=_get_gateway_network_security_group_name(
storage_account, location
),
parameters=NetworkSecurityGroup(
location=location,
security_rules=[
Expand Down
13 changes: 6 additions & 7 deletions cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,22 @@ def load(cls) -> Optional["Backend"]:
def name(self) -> str:
return self.NAME

@abstractmethod
def predict_instance_type(self, job: Job) -> Optional[InstanceType]:
pass

@abstractmethod
def create_job(
self,
job: Job,
):
pass

# TODO: Is this function used at all?
# TODO: This must use offers from multiple clouds
# TODO: Why does `run_job` not pass `project_private_key`?
def submit_job(self, job: Job, failed_to_start_job_new_status: JobStatus = JobStatus.FAILED):
self.create_job(job)
self.run_job(job, failed_to_start_job_new_status)

# TODO: This must use offers from multiple clouds
# TODO: Why does `run_job` not pass `project_private_key`?
def resubmit_job(self, job: Job, failed_to_start_job_new_status: JobStatus = JobStatus.FAILED):
base_jobs.update_job_submission(job)
self.run_job(job, failed_to_start_job_new_status)
Expand Down Expand Up @@ -296,9 +297,6 @@ def logging(self) -> Logging:
def pricing(self) -> Pricing:
pass

def predict_instance_type(self, job: Job) -> Optional[InstanceType]:
return base_jobs.predict_job_instance(self.compute(), job)

def create_job(self, job: Job):
base_jobs.create_job(self.storage(), job)

Expand All @@ -308,6 +306,7 @@ def get_job(self, repo_id: str, job_id: str) -> Optional[Job]:
def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:
return base_jobs.list_jobs(self.storage(), repo_id, run_name)

# TODO: The `offer` field must be required
def run_job(
self,
job: Job,
Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/_internal/backend/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class Compute(ABC):
def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
pass

# TODO: This function is deprecated and will be deleted in 0.11.x
@abstractmethod
def get_instance_type(self, job: Job) -> Optional[InstanceType]:
def get_instance_type(self, job: Job, region_name: Optional[str]) -> Optional[InstanceType]:
pass

@abstractmethod
Expand Down
11 changes: 3 additions & 8 deletions cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,7 @@ def delete_jobs(storage: Storage, repo_id: str, run_name: str):
storage.delete_object(job_key)


def predict_job_instance(
compute: Compute,
job: Job,
) -> Optional[InstanceType]:
return compute.get_instance_type(job)


# TODO: The `offer` field must be required
def run_job(
storage: Storage,
compute: Compute,
Expand Down Expand Up @@ -213,6 +207,7 @@ def update_job_submission(job: Job):
job.submitted_at = get_milliseconds_since_epoch()


# TODO: The `offer` field must be required
def _try_run_job(
storage: Storage,
compute: Compute,
Expand All @@ -228,7 +223,7 @@ def _try_run_job(
and attempt == 0
)
job.requirements.spot = spot
instance_type = compute.get_instance_type(job)
instance_type = compute.get_instance_type(job, region_name=job.location)
else:
job.requirements.spot = offer.instance_type.resources.spot
instance_type = offer.instance_type
Expand Down
12 changes: 8 additions & 4 deletions cli/dstack/_internal/backend/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
message=None,
)

def get_instance_type(self, job: Job) -> Optional[InstanceType]:
# TODO: This function is deprecated and will be deleted in 0.11.x
def get_instance_type(self, job: Job, region: Optional[str]) -> Optional[InstanceType]:
return _choose_instance_type(
machine_types_client=self.machine_types_client,
accelerator_types_client=self.accelerator_types_client,
Expand Down Expand Up @@ -190,14 +191,17 @@ def cancel_spot_request(self, runner: Runner):
)

def create_gateway(self, instance_name: str, ssh_key_pub: str) -> GatewayHead:
region = self.gcp_config.regions[0]
# TODO: This must be a configurable field of the gateway
default_region_name = self.gcp_config.regions[0]
instance = gateway.create_gateway_instance(
instances_client=self.instances_client,
firewalls_client=self.firewalls_client,
project_id=self.gcp_config.project_id,
network=_get_network_resource(self.gcp_config.vpc),
subnet=_get_subnet_resource(region, self.gcp_config.subnet),
zone=_get_zones(self.regions_client, self.gcp_config.project_id, [region])[0],
subnet=_get_subnet_resource(default_region_name, self.gcp_config.subnet),
zone=_get_zones(
self.regions_client, self.gcp_config.project_id, [default_region_name]
)[0],
instance_name=instance_name,
service_account=self.credentials.service_account_email,
labels=dict(
Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/_internal/backend/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
from dstack._internal.core.runners import Runner


# TODO: The entire backend is deprecated and will be deleted in 0.11.x
class LocalCompute(Compute):
def __init__(self, backend_config: LocalConfig):
self.backend_config = backend_config

def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
return runners.get_request_head(job, request_id)

def get_instance_type(self, job: Job) -> Optional[InstanceType]:
def get_instance_type(self, job: Job, region_name: Optional[str]) -> Optional[InstanceType]:
resources = runners.check_runner_resources(self.backend_config, job.runner_id)
instance_type = choose_instance_type(
instance_types=[InstanceType(instance_name="", resources=resources)],
Expand Down
1 change: 1 addition & 0 deletions cli/dstack/_internal/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class Job(JobHead):
instance_spot_type: Optional[str]
instance_type: Optional[str]
job_id: str
# TODO: Rename to `region_name`
location: Optional[str]
master_job: Optional[str] # not implemented
max_duration: Optional[int]
Expand Down

0 comments on commit d8d2141

Please sign in to comment.