Skip to content

Commit

Permalink
feat: Add quota check for restricted image (#3421)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiran-work authored Aug 13, 2024
1 parent cdca4a3 commit a9e59e9
Showing 1 changed file with 19 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -472,18 +472,23 @@ def get_quota(project_id: str, region: str, resource_id: str) -> int:
return -1


def get_resource_id(accelerator_type: str, is_for_training: bool) -> str:
def get_resource_id(
accelerator_type: str,
is_for_training: bool,
is_restricted_image: bool = False,
) -> str:
"""Returns the resource id for a given accelerator type and the use case.
Args:
accelerator_type: The accelerator type.
is_for_training: Whether the resource is used for training. Set false for
serving use case.
is_restricted_image: Whether the image is hosted in `vertex-ai-restricted`.
Returns:
The resource id.
"""
training_accelerator_map = {
default_training_accelerator_map = {
"NVIDIA_TESLA_V100": "custom_model_training_nvidia_v100_gpus",
"NVIDIA_L4": "custom_model_training_nvidia_l4_gpus",
"NVIDIA_TESLA_A100": "custom_model_training_nvidia_a100_gpus",
Expand All @@ -493,6 +498,9 @@ def get_resource_id(accelerator_type: str, is_for_training: bool) -> str:
"TPU_V5e": "custom_model_training_tpu_v5e",
"TPU_V3": "custom_model_training_tpu_v3",
}
restricted_image_training_accelerator_map = {
"NVIDIA_A100_80GB": "restricted_image_training_nvidia_a100_80gb_gpus",
}
serving_accelerator_map = {
"NVIDIA_TESLA_V100": "custom_model_serving_nvidia_v100_gpus",
"NVIDIA_L4": "custom_model_serving_nvidia_l4_gpus",
Expand All @@ -503,6 +511,11 @@ def get_resource_id(accelerator_type: str, is_for_training: bool) -> str:
"TPU_V5e": "custom_model_serving_tpu_v5e",
}
if is_for_training:
training_accelerator_map = (
restricted_image_training_accelerator_map
if is_restricted_image
else default_training_accelerator_map
)
if accelerator_type in training_accelerator_map:
return training_accelerator_map[accelerator_type]
else:
Expand All @@ -524,9 +537,12 @@ def check_quota(
accelerator_type: str,
accelerator_count: int,
is_for_training: bool,
is_restricted_image: bool = False,
):
"""Checks if the project and the region has the required quota."""
resource_id = get_resource_id(accelerator_type, is_for_training)
resource_id = get_resource_id(
accelerator_type, is_for_training, is_restricted_image
)
quota = get_quota(project_id, region, resource_id)
quota_request_instruction = (
"Either use "
Expand All @@ -546,4 +562,3 @@ def check_quota(
f"Quota not enough for {resource_id} in {region}: {quota} <"
f" {accelerator_count}. {quota_request_instruction}"
)

0 comments on commit a9e59e9

Please sign in to comment.