diff --git a/modal/functions.py b/modal/functions.py index 08db1c19c..93146fd1f 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -785,7 +785,8 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona task_idle_timeout_secs=container_idle_timeout or 0, concurrency_limit=concurrency_limit or 0, pty_info=pty_info, - cloud_provider=cloud_provider, + cloud_provider=cloud_provider, # Deprecated at some point + cloud_provider_str=cloud.upper() if cloud else None, # Supersedes cloud_provider warm_pool_size=keep_warm or 0, runtime=config.get("function_runtime"), runtime_debug=config.get("function_runtime_debug"), diff --git a/modal/sandbox.py b/modal/sandbox.py index aae30b8db..ea460a4e7 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -165,7 +165,8 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona resources=convert_fn_config_to_resources_config( cpu=cpu, memory=memory, gpu=gpu, ephemeral_disk=ephemeral_disk ), - cloud_provider=parse_cloud_provider(cloud) if cloud else None, + cloud_provider=parse_cloud_provider(cloud) if cloud else None, # Deprecated at some point + cloud_provider_str=cloud.upper() if cloud else None, # Supersedes cloud_provider nfs_mounts=network_file_system_mount_protos(validated_network_file_systems, False), runtime_debug=config.get("function_runtime_debug"), cloud_bucket_mounts=cloud_bucket_mounts_to_proto(cloud_bucket_mounts), diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 8579b04b9..94876b80b 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -1163,7 +1163,7 @@ message Function { uint32 task_idle_timeout_secs = 25; - optional CloudProvider cloud_provider = 26; + optional CloudProvider cloud_provider = 26; // Deprecated at some point uint32 warm_pool_size = 27; @@ -1257,6 +1257,8 @@ message Function { bool method_definitions_set = 75; bool _experimental_custom_scaling = 76; + + string cloud_provider_str = 77; // Supersedes cloud_provider } message FunctionAsyncInvokeRequest { @@ -2192,7 +2194,7 @@ message Sandbox { repeated string secret_ids = 4; Resources resources = 5; - CloudProvider cloud_provider = 6; + CloudProvider cloud_provider = 6; // Deprecated at some point uint32 timeout_secs = 7; @@ -2236,6 +2238,8 @@ message Sandbox { // Used to pin gVisor version for memory-snapshottable sandboxes. // This field is set by the server, not the client. optional uint32 snapshot_version = 25; + + string cloud_provider_str = 26; // Supersedes cloud_provider } message SandboxCreateRequest { diff --git a/test/function_test.py b/test/function_test.py index b200404dc..285c596af 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -778,6 +778,7 @@ def test_default_cloud_provider(client, servicer, monkeypatch): f = servicer.app_functions[object_id] assert f.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI + assert f.cloud_provider_str == "OCI" def test_not_hydrated(): diff --git a/test/gpu_test.py b/test/gpu_test.py index 869e7cd7e..b9427285f 100644 --- a/test/gpu_test.py +++ b/test/gpu_test.py @@ -86,6 +86,7 @@ def test_cloud_provider_selection(client, servicer): assert len(servicer.app_functions) == 1 func_def = next(iter(servicer.app_functions.values())) assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_GCP + assert func_def.cloud_provider_str == "GCP" assert func_def.resources.gpu_config.count == 1 assert func_def.resources.gpu_config.type == api_pb2.GPU_TYPE_A100 diff --git a/test/image_test.py b/test/image_test.py index ddf0c0022..cd4f91c4e 100644 --- a/test/image_test.py +++ b/test/image_test.py @@ -645,6 +645,7 @@ def test_image_run_function_with_cloud_selection(servicer, client): assert len(servicer.app_functions) == 2 func_def = next(iter(servicer.app_functions.values())) assert func_def.cloud_provider == api_pb2.CLOUD_PROVIDER_OCI + assert func_def.cloud_provider_str == "OCI" def test_poetry(builder_version, servicer, client):