Skip to content

Commit

Permalink
Simplify regions configuration (#659)
Browse files: browse the repository at this point in the history
* Drop storage region for AWS

* Drop storage region for GCP

* Drop storage region for Azure

* Fixes

* #656 Refactor region fields; replace columns in the backend table

* Fix Azure and GCP configuration

* Fix GCP zones

* Fix tests

---------

Co-authored-by: Oleg Vavilov <[email protected]>
  • Loading branch information
r4victor and olgenn authored Aug 18, 2023
1 parent c9e1510 commit 33f1238
Show file tree
Hide file tree
Showing 26 changed files with 285 additions and 727 deletions.
4 changes: 2 additions & 2 deletions cli/dstack/_internal/backend/aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def __init__(
self.backend_config = backend_config
if self.backend_config.credentials is not None:
self._session = Session(
region_name=self.backend_config.region_name,
region_name=self.backend_config.region,
aws_access_key_id=self.backend_config.credentials.get("access_key"),
aws_secret_access_key=self.backend_config.credentials.get("secret_key"),
)
else:
self._session = Session(region_name=self.backend_config.region_name)
self._session = Session(region_name=self.backend_config.region)
self._storage = AWSStorage(
s3_client=aws_utils.get_s3_client(self._session),
bucket_name=self.backend_config.bucket_name,
Expand Down
6 changes: 3 additions & 3 deletions cli/dstack/_internal/backend/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]:

def get_supported_instances(self) -> List[InstanceType]:
instances = {}
for region in [self.backend_config.region_name, *self.backend_config.extra_regions]:
for region in self.backend_config.regions:
for i in runners._get_instance_types(self._get_ec2_client(region=region)):
if i.instance_name not in instances:
instances[i.instance_name] = i
Expand All @@ -48,13 +48,13 @@ def get_supported_instances(self) -> List[InstanceType]:
return list(instances.values())

def run_instance(
self, job: Job, instance_type: InstanceType, region: Optional[str] = None
self, job: Job, instance_type: InstanceType, region: str
) -> LaunchedInstanceInfo:
return runners.run_instance(
session=self.session,
iam_client=self.iam_client,
bucket_name=self.backend_config.bucket_name,
region_name=region or self.backend_config.region_name,
region_name=region,
subnet_id=self.backend_config.subnet_id,
runner_id=job.runner_id,
instance_type=instance_type,
Expand Down
16 changes: 7 additions & 9 deletions cli/dstack/_internal/backend/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,23 @@

from dstack._internal.backend.base.config import BackendConfig

DEFAULT_REGION_NAME = "us-east-1"
DEFAULT_REGION = "us-east-1"


class AWSConfig(BackendConfig, BaseModel):
bucket_name: str
region_name: Optional[str] = DEFAULT_REGION_NAME
extra_regions: List[str] = []
regions: List[str]
subnet_id: Optional[str] = None
credentials: Optional[Dict] = None
# dynamically set
region: Optional[str] = DEFAULT_REGION

def serialize(self) -> Dict:
config_data = {
"backend": "aws",
"bucket": self.bucket_name,
"regions": self.regions,
}
if self.region_name:
config_data["region"] = self.region_name
if self.extra_regions:
config_data["extra_regions"] = self.extra_regions
if self.subnet_id:
config_data["subnet"] = self.subnet_id
return config_data
Expand All @@ -37,7 +35,7 @@ def deserialize(cls, config_data: Dict) -> Optional["AWSConfig"]:
return None
return cls(
bucket_name=bucket_name,
region_name=config_data.get("region"),
extra_regions=config_data.get("extra_regions", []),
regions=config_data.get("regions", []),
subnet_id=config_data.get("subnet"),
region_name=config_data.get("region"),
)
6 changes: 3 additions & 3 deletions cli/dstack/_internal/backend/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]:

def get_supported_instances(self) -> List[InstanceType]:
instances = {}
for location in [self.azure_config.location, *self.azure_config.extra_locations]:
for location in self.azure_config.locations:
for i in _get_instance_types(client=self._compute_client, location=location):
if i.instance_name not in instances:
instances[i.instance_name] = i
Expand All @@ -119,14 +119,14 @@ def get_supported_instances(self) -> List[InstanceType]:
return list(instances.values())

def run_instance(
self, job: Job, instance_type: InstanceType, region: Optional[str] = None
self, job: Job, instance_type: InstanceType, region: str
) -> LaunchedInstanceInfo:
return _run_instance(
compute_client=self._compute_client,
azure_config=self.azure_config,
job=job,
instance_type=instance_type,
location=region or self.azure_config.location,
location=region,
)

def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
Expand Down
7 changes: 3 additions & 4 deletions cli/dstack/_internal/backend/azure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ class AzureConfig(BackendConfig, BaseModel):
backend: Literal["azure"] = "azure"
tenant_id: str
subscription_id: str
location: str
resource_group: str
storage_account: str
vault_url: str
extra_locations: List[str]
# network and subnet are location-dependent.
# Hub selects them dynamically when provisioning.
locations: List[str]
# set dynamically
location: Optional[str]
network: Optional[str]
subnet: Optional[str]
credentials: Optional[Dict] = None
Expand Down
2 changes: 1 addition & 1 deletion cli/dstack/_internal/backend/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_supported_instances(self) -> List[InstanceType]:

@abstractmethod
def run_instance(
self, job: Job, instance_type: InstanceType, region: Optional[str] = None
self, job: Job, instance_type: InstanceType, region: str
) -> LaunchedInstanceInfo:
pass

Expand Down
27 changes: 10 additions & 17 deletions cli/dstack/_internal/backend/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def get_supported_instances(self) -> List[InstanceType]:
zones = _get_zones(
regions_client=self.regions_client,
project_id=self.gcp_config.project_id,
primary_region=self.gcp_config.region,
primary_zone=self.gcp_config.zone,
extra_regions=self.gcp_config.extra_regions,
configured_regions=self.gcp_config.regions,
)
for zone in zones:
region = zone[:-2]
Expand Down Expand Up @@ -170,14 +168,12 @@ def get_supported_instances(self) -> List[InstanceType]:
return self._supported_instances_cache

def run_instance(
self, job: Job, instance_type: InstanceType, region: Optional[str] = None
self, job: Job, instance_type: InstanceType, region: str
) -> LaunchedInstanceInfo:
zones = _get_zones(
regions_client=self.regions_client,
project_id=self.gcp_config.project_id,
primary_region=region or self.gcp_config.region,
primary_zone=self.gcp_config.zone, # doesn't matter if zone is not from the region
extra_regions=[], # regions are managed at the project level
configured_regions=[region],
)
# Note: not all zones in the region may offer the chosen instance type,
# for now, just treat NotFound error as NoCapacity
Expand Down Expand Up @@ -663,18 +659,15 @@ def _run_instance(
def _get_zones(
regions_client: compute_v1.RegionsClient,
project_id: str,
primary_region: str,
primary_zone: str,
extra_regions: List[str],
configured_regions: List[str],
) -> List[str]:
regions = regions_client.list(project=project_id)
region_name_to_zones_map = {
r.name: [gcp_utils.get_resource_name(z) for z in r.zones] for r in regions
}
zones = region_name_to_zones_map[primary_region]
zones = sorted(zones, key=lambda x: x != primary_zone)
for extra_region in extra_regions:
zones += region_name_to_zones_map[extra_region]
zones = [
gcp_utils.get_resource_name(z)
for r in regions
for z in r.zones
if r.name in configured_regions
]
return zones


Expand Down
13 changes: 7 additions & 6 deletions cli/dstack/_internal/backend/gcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,26 @@
class GCPConfig(BackendConfig, BaseModel):
backend: Literal["gcp"] = "gcp"
project_id: str
region: str
zone: str
regions: List[str]
bucket_name: str
vpc: str
subnet: str
extra_regions: List[str] = []
credentials_file: Optional[str] = None
credentials: Optional[Dict] = None
# dynamically set
region: Optional[str]
zone: Optional[str]

def serialize(self) -> Dict:
res = {
"backend": "gcp",
"project": self.project_id,
"region": self.region,
"zone": self.zone,
"regions": self.regions,
"bucket": self.bucket_name,
"vpc": self.vpc,
"subnet": self.subnet,
"extra_regions": self.extra_regions,
"region": self.region,
"zone": self.zone,
}
if self.credentials_file is not None:
res["credentials_file"] = self.credentials_file
Expand Down
2 changes: 1 addition & 1 deletion cli/dstack/_internal/backend/lambdalabs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def get_supported_instances(self) -> List[InstanceType]:
return _list_instance_types(self.api_client, self.lambda_config.regions)

def run_instance(
self, job: Job, instance_type: InstanceType, region: Optional[str] = None
self, job: Job, instance_type: InstanceType, region: str
) -> LaunchedInstanceInfo:
instance_id = _run_instance(
api_client=self.api_client,
Expand Down
2 changes: 1 addition & 1 deletion cli/dstack/_internal/backend/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_supported_instances(self) -> List[InstanceType]:
return [InstanceType(instance_name="", resources=resources, available_regions=[""])]

def run_instance(
self, job: Job, instance_type: InstanceType, region: Optional[str] = None
self, job: Job, instance_type: InstanceType, region: str
) -> LaunchedInstanceInfo:
pid = runners.start_runner_process(self.backend_config, job.runner_id)
return LaunchedInstanceInfo(request_id=pid, location=None)
Expand Down
35 changes: 9 additions & 26 deletions cli/dstack/_internal/hub/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,15 @@ class LocalBackendConfig(BaseModel):

class AWSBackendConfigPartial(BaseModel):
type: Literal["aws"] = "aws"
region_name: Optional[str]
region_name_title: Optional[str]
extra_regions: Optional[List[str]]
s3_bucket_name: Optional[str]
regions: Optional[List[str]]
ec2_subnet_id: Optional[str]


class AWSBackendConfig(BaseModel):
type: Literal["aws"] = "aws"
region_name: str
region_name_title: Optional[str]
extra_regions: List[str] = []
s3_bucket_name: str
regions: List[str]
ec2_subnet_id: Optional[str]


Expand Down Expand Up @@ -85,24 +81,18 @@ class AWSBackendConfigWithCreds(AWSBackendConfig):

class GCPBackendConfigPartial(BaseModel):
type: Literal["gcp"] = "gcp"
area: Optional[str]
region: Optional[str]
zone: Optional[str]
bucket_name: Optional[str]
regions: Optional[List[str]]
vpc: Optional[str]
subnet: Optional[str]
extra_regions: Optional[List[str]]


class GCPBackendConfig(BaseModel):
type: Literal["gcp"] = "gcp"
area: str
region: str
zone: str
bucket_name: str
regions: List[str]
vpc: str
subnet: str
extra_regions: List[str] = []


class GCPBackendDefaultCreds(BaseModel):
Expand Down Expand Up @@ -133,9 +123,8 @@ class AzureBackendConfigPartial(BaseModel):
type: Literal["azure"] = "azure"
tenant_id: Optional[str]
subscription_id: Optional[str]
location: Optional[str]
storage_account: Optional[str]
extra_locations: Optional[List[str]]
locations: Optional[List[str]]


class AzureBackendClientCreds(BaseModel):
Expand All @@ -162,9 +151,8 @@ class AzureBackendConfig(BaseModel):
type: Literal["azure"] = "azure"
tenant_id: str
subscription_id: str
location: str
storage_account: str
extra_locations: List[str] = []
locations: List[str]


class AzureBackendConfigWithCreds(AzureBackendConfig):
Expand Down Expand Up @@ -274,8 +262,7 @@ class AWSBucketBackendElement(BaseModel):
class AWSBackendValues(BaseModel):
type: Literal["aws"] = "aws"
default_credentials: bool = False
region_name: Optional[BackendElement]
extra_regions: Optional[BackendMultiElement]
regions: Optional[BackendMultiElement]
s3_bucket_name: Optional[AWSBucketBackendElement]
ec2_subnet_id: Optional[BackendElement]

Expand All @@ -294,22 +281,18 @@ class GCPVPCSubnetBackendElement(BaseModel):
class GCPBackendValues(BaseModel):
type: Literal["gcp"] = "gcp"
default_credentials: bool = False
area: Optional[BackendElement]
region: Optional[BackendElement]
zone: Optional[BackendElement]
bucket_name: Optional[BackendElement]
regions: Optional[BackendMultiElement]
vpc_subnet: Optional[GCPVPCSubnetBackendElement]
extra_regions: Optional[BackendMultiElement]


class AzureBackendValues(BaseModel):
type: Literal["azure"] = "azure"
default_credentials: bool = False
tenant_id: Optional[BackendElement]
subscription_id: Optional[BackendElement]
location: Optional[BackendElement]
storage_account: Optional[BackendElement]
extra_locations: Optional[BackendMultiElement]
locations: Optional[BackendMultiElement]


class AWSStorageBackendValues(BaseModel):
Expand Down
Loading

0 comments on commit 33f1238

Please sign in to comment.