diff --git a/modal/functions.py b/modal/functions.py
index b6069c255..bcab78ba9 100644
--- a/modal/functions.py
+++ b/modal/functions.py
@@ -828,14 +828,17 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
             scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
             is_class=info.is_service_class(),
             class_parameter_info=info.class_parameter_info(),
-            _experimental_resources=[
-                convert_fn_config_to_resources_config(
-                    cpu=cpu, memory=memory, gpu=_experimental_gpu, ephemeral_disk=ephemeral_disk
+            i6pn_enabled=config.get("i6pn_enabled"),
+            _experimental_concurrent_cancellations=True,
+            _experimental_task_templates=[
+                api_pb2.TaskTemplate(
+                    priority=1,
+                    resources=convert_fn_config_to_resources_config(
+                        cpu=cpu, memory=memory, gpu=_experimental_gpu, ephemeral_disk=ephemeral_disk
+                    ),
                 )
                 for _experimental_gpu in _experimental_gpus
             ],
-            i6pn_enabled=config.get("i6pn_enabled"),
-            _experimental_concurrent_cancellations=True,
         )
         assert resolver.app_id
         request = api_pb2.FunctionCreateRequest(
diff --git a/modal/sandbox.py b/modal/sandbox.py
index aed5c8df4..cb2381884 100644
--- a/modal/sandbox.py
+++ b/modal/sandbox.py
@@ -128,12 +128,6 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
             volume_mounts=volume_mounts,
             pty_info=pty_info,
             scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
-            _experimental_resources=[
-                convert_fn_config_to_resources_config(
-                    cpu=cpu, memory=memory, gpu=_experimental_gpu, ephemeral_disk=ephemeral_disk
-                )
-                for _experimental_gpu in _experimental_gpus
-            ],
             worker_id=config.get("worker_id"),
         )
 
diff --git a/modal_proto/api.proto b/modal_proto/api.proto
index 2cc536bcf..e2b9e83ab 100644
--- a/modal_proto/api.proto
+++ b/modal_proto/api.proto
@@ -949,7 +949,6 @@ message FileEntry {
   uint64 size = 4;
 }
 
-
 message Function {
   string module_name = 1;
   string function_name = 2;
@@ -1051,8 +1050,7 @@ message Function {
 
   ClassParameterInfo class_parameter_info = 56;
 
-  repeated Resources _experimental_resources = 57; // overrides `resources` field above
-
+  reserved 57; // _experimental_resources
   reserved 58;
   reserved 59;
   uint32 batch_max_size = 60; // Maximum number of inputs to fetch at once
@@ -1060,6 +1058,9 @@ message Function {
   bool i6pn_enabled = 62;
   bool _experimental_concurrent_cancellations = 63;
   uint32 max_concurrent_inputs = 64;
+
+  bool _experimental_task_templates_enabled = 65; // forces going through the new gpu-fallbacks integration path, even if no fallback options are specified
+  repeated TaskTemplate _experimental_task_templates = 66; // for fallback options, where the first/most-preferred "template" is derived from fields above
 }
 
 message FunctionBindParamsRequest {
@@ -1769,7 +1770,7 @@ message Sandbox {
   // to look at fine-grained placement constraints.
   reserved 16; // _experimental_scheduler
   optional SchedulerPlacement scheduler_placement = 17;
-  repeated Resources _experimental_resources = 18; // overrides `resources` field above
+  reserved 18; // _experimental_resources
   string worker_id = 19; // for internal debugging use only
 
   oneof open_ports_oneof {
@@ -2079,6 +2080,11 @@ message TaskStats {
   double started_at = 4;
 }
 
+message TaskTemplate {
+  uint32 priority = 1;
+  Resources resources = 2;
+  uint32 concurrent_inputs = 3;
+}
 
 message TokenFlowCreateRequest {
   string utm_source = 3;
diff --git a/test/experimental_gpus_test.py b/test/experimental_gpus_test.py
index 46301391a..d4b244d0f 100644
--- a/test/experimental_gpus_test.py
+++ b/test/experimental_gpus_test.py
@@ -1,10 +1,8 @@
 # Copyright Modal Labs 2024
 import modal
-from modal import App, Sandbox
+from modal import App
 from modal_proto import api_pb2
 
-from .sandbox_test import skip_non_linux
-
 app = App()
 
 
@@ -23,7 +21,7 @@ def f3():
     pass
 
 
-def test_experimental_resources(servicer, client):
+def test_experimental_task_templates(servicer, client):
     with app.run(client=client):
         assert len(servicer.app_functions) == 3
 
@@ -54,54 +52,21 @@
         )
 
         fn1 = servicer.app_functions["fu-1"]  # f1
-        assert len(fn1._experimental_resources) == 1
-        assert fn1._experimental_resources[0].gpu_config.type == a10_1.gpu_config.type
-        assert fn1._experimental_resources[0].gpu_config.count == a10_1.gpu_config.count
+        assert len(fn1._experimental_task_templates) == 1
+        assert fn1._experimental_task_templates[0].resources.gpu_config.type == a10_1.gpu_config.type
+        assert fn1._experimental_task_templates[0].resources.gpu_config.count == a10_1.gpu_config.count
 
         fn2 = servicer.app_functions["fu-2"]  # f2
-        assert len(fn2._experimental_resources) == 2
-        assert fn2._experimental_resources[0].gpu_config.type == a10_1.gpu_config.type
-        assert fn2._experimental_resources[0].gpu_config.count == a10_1.gpu_config.count
-        assert fn2._experimental_resources[1].gpu_config.type == t4_2.gpu_config.type
-        assert fn2._experimental_resources[1].gpu_config.count == t4_2.gpu_config.count
+        assert len(fn2._experimental_task_templates) == 2
+        assert fn2._experimental_task_templates[0].resources.gpu_config.type == a10_1.gpu_config.type
+        assert fn2._experimental_task_templates[0].resources.gpu_config.count == a10_1.gpu_config.count
+        assert fn2._experimental_task_templates[1].resources.gpu_config.type == t4_2.gpu_config.type
+        assert fn2._experimental_task_templates[1].resources.gpu_config.count == t4_2.gpu_config.count
 
         fn3 = servicer.app_functions["fu-3"]  # f3
-        assert len(fn3._experimental_resources) == 2
-        assert fn3._experimental_resources[0].gpu_config.type == h100_2.gpu_config.type
-        assert fn3._experimental_resources[0].gpu_config.count == h100_2.gpu_config.count
-        assert fn3._experimental_resources[1].gpu_config.type == a100_80gb_2.gpu_config.type
-        assert fn3._experimental_resources[1].gpu_config.count == a100_80gb_2.gpu_config.count
-        assert fn3._experimental_resources[1].gpu_config.memory == a100_80gb_2.gpu_config.memory
-
-
-@skip_non_linux
-def test_sandbox_experimental_resources(client, servicer):
-    Sandbox.create(
-        "bash",
-        "-c",
-        "echo bye >&2 && sleep 1 && echo hi && exit 42",
-        timeout=600,
-        _experimental_gpus=["a10g:2", "t4:4"],
-        client=client,
-    )
-
-    a10_2 = api_pb2.Resources(
-        gpu_config=api_pb2.GPUConfig(
-            type=api_pb2.GPU_TYPE_A10G,
-            count=2,
-        )
-    )
-    t4_4 = api_pb2.Resources(
-        gpu_config=api_pb2.GPUConfig(
-            type=api_pb2.GPU_TYPE_T4,
-            count=4,
-        )
-    )
-
-    assert len(servicer.sandbox_defs) == 1
-    sb_def = servicer.sandbox_defs[0]
-    assert len(sb_def._experimental_resources) == 2
-    assert sb_def._experimental_resources[0].gpu_config.type == a10_2.gpu_config.type
-    assert sb_def._experimental_resources[0].gpu_config.count == a10_2.gpu_config.count
-    assert sb_def._experimental_resources[1].gpu_config.type == t4_4.gpu_config.type
-    assert sb_def._experimental_resources[1].gpu_config.count == t4_4.gpu_config.count
+        assert len(fn3._experimental_task_templates) == 2
+        assert fn3._experimental_task_templates[0].resources.gpu_config.type == h100_2.gpu_config.type
+        assert fn3._experimental_task_templates[0].resources.gpu_config.count == h100_2.gpu_config.count
+        assert fn3._experimental_task_templates[1].resources.gpu_config.type == a100_80gb_2.gpu_config.type
+        assert fn3._experimental_task_templates[1].resources.gpu_config.count == a100_80gb_2.gpu_config.count
+        assert fn3._experimental_task_templates[1].resources.gpu_config.memory == a100_80gb_2.gpu_config.memory
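Note: below is a minimal sketch of the new wire shape, hand-building the fallback list that the client normally derives from `_experimental_gpus` via the comprehension in modal/functions.py above. It uses only messages and fields visible in this diff (TaskTemplate, Resources, GPUConfig); the concrete GPU choices are illustrative, not part of the change.

    # Illustrative only: the first/most-preferred template comes first in the
    # repeated field; the client sets priority=1 on every template it derives.
    from modal_proto import api_pb2

    fallback_templates = [
        api_pb2.TaskTemplate(  # preferred option: one A10G
            priority=1,
            resources=api_pb2.Resources(
                gpu_config=api_pb2.GPUConfig(type=api_pb2.GPU_TYPE_A10G, count=1),
            ),
        ),
        api_pb2.TaskTemplate(  # fallback option: two T4s
            priority=1,
            resources=api_pb2.Resources(
                gpu_config=api_pb2.GPUConfig(type=api_pb2.GPU_TYPE_T4, count=2),
            ),
        ),
    ]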