Add proto definitions for GPU fallbacks (#2182)
irfansharif authored Sep 3, 2024
1 parent c1abcd8 commit 1a1dbe2
Showing 4 changed files with 34 additions and 66 deletions.
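
This commit replaces the repeated `Resources` field `_experimental_resources` with a repeated `TaskTemplate` message, so each GPU fallback option can carry its own priority and concurrency settings. On the client side the feature is driven by the `_experimental_gpus` parameter. A minimal usage sketch, assuming `@app.function` accepts the parameter the way the test file below exercises it (the GPU strings are copied from this repository's tests):

import modal

app = modal.App()

# Most-preferred option first; each entry becomes one TaskTemplate on the
# wire, and later entries are fallbacks for when the preferred GPU type
# is unavailable.
@app.function(_experimental_gpus=["a10g:2", "t4:4"])
def f():
    ...
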
13 changes: 8 additions & 5 deletions modal/functions.py
@@ -828,14 +828,17 @@ async def _load(self: _Function, resolver: Resolver, existing_object_id: Optiona
             scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
             is_class=info.is_service_class(),
             class_parameter_info=info.class_parameter_info(),
-            _experimental_resources=[
-                convert_fn_config_to_resources_config(
-                    cpu=cpu, memory=memory, gpu=_experimental_gpu, ephemeral_disk=ephemeral_disk
+            i6pn_enabled=config.get("i6pn_enabled"),
+            _experimental_concurrent_cancellations=True,
+            _experimental_task_templates=[
+                api_pb2.TaskTemplate(
+                    priority=1,
+                    resources=convert_fn_config_to_resources_config(
+                        cpu=cpu, memory=memory, gpu=_experimental_gpu, ephemeral_disk=ephemeral_disk
+                    ),
                 )
                 for _experimental_gpu in _experimental_gpus
             ],
-            i6pn_enabled=config.get("i6pn_enabled"),
-            _experimental_concurrent_cancellations=True,
         )
         assert resolver.app_id
         request = api_pb2.FunctionCreateRequest(
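
For clarity, the comprehension above builds one `TaskTemplate` per entry of `_experimental_gpus`. A hand-rolled sketch of the equivalent protobuf payload for `["a10g:2", "t4:4"]`, using only message types visible in this commit (note that every template gets `priority=1`, so list position is what distinguishes the preferred option from the fallbacks):

from modal_proto import api_pb2

# Preferred option: 2x A10G; fallback: 4x T4. Field names come from the
# TaskTemplate message added to api.proto in this commit.
templates = [
    api_pb2.TaskTemplate(
        priority=1,
        resources=api_pb2.Resources(
            gpu_config=api_pb2.GPUConfig(type=api_pb2.GPU_TYPE_A10G, count=2),
        ),
    ),
    api_pb2.TaskTemplate(
        priority=1,
        resources=api_pb2.Resources(
            gpu_config=api_pb2.GPUConfig(type=api_pb2.GPU_TYPE_T4, count=4),
        ),
    ),
]
assert [t.resources.gpu_config.count for t in templates] == [2, 4]
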
6 changes: 0 additions & 6 deletions modal/sandbox.py
@@ -128,12 +128,6 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona
             volume_mounts=volume_mounts,
             pty_info=pty_info,
             scheduler_placement=scheduler_placement.proto if scheduler_placement else None,
-            _experimental_resources=[
-                convert_fn_config_to_resources_config(
-                    cpu=cpu, memory=memory, gpu=_experimental_gpu, ephemeral_disk=ephemeral_disk
-                )
-                for _experimental_gpu in _experimental_gpus
-            ],
             worker_id=config.get("worker_id"),
         )
14 changes: 10 additions & 4 deletions modal_proto/api.proto
@@ -949,7 +949,6 @@ message FileEntry {
   uint64 size = 4;
 }
 
-
 message Function {
   string module_name = 1;
   string function_name = 2;
@@ -1051,15 +1050,17 @@ message Function {
 
   ClassParameterInfo class_parameter_info = 56;
 
-  repeated Resources _experimental_resources = 57; // overrides `resources` field above
-
+  reserved 57; // _experimental_resources
   reserved 58;
   reserved 59;
   uint32 batch_max_size = 60; // Maximum number of inputs to fetch at once
   uint64 batch_linger_ms = 61; // Milliseconds to block before a response is needed
   bool i6pn_enabled = 62;
   bool _experimental_concurrent_cancellations = 63;
   uint32 max_concurrent_inputs = 64;
+
+  bool _experimental_task_templates_enabled = 65; // forces going through the new gpu-fallbacks integration path, even if no fallback options are specified
+  repeated TaskTemplate _experimental_task_templates = 66; // for fallback options, where the first/most-preferred "template" is derived from fields above
 }
 
 message FunctionBindParamsRequest {
@@ -1769,7 +1770,7 @@ message Sandbox {
   // to look at fine-grained placement constraints.
   reserved 16; // _experimental_scheduler
   optional SchedulerPlacement scheduler_placement = 17;
-  repeated Resources _experimental_resources = 18; // overrides `resources` field above
+  reserved 18; // _experimental_resources
 
   string worker_id = 19; // for internal debugging use only
   oneof open_ports_oneof {
@@ -2079,6 +2080,11 @@ message TaskStats {
   double started_at = 4;
 }
 
+message TaskTemplate {
+  uint32 priority = 1;
+  Resources resources = 2;
+  uint32 concurrent_inputs = 3;
+}
 
 message TokenFlowCreateRequest {
   string utm_source = 3;
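
The `reserved 57` and `reserved 18` statements retire the old `_experimental_resources` field numbers so a future field cannot reuse them and silently break old clients. A quick sanity check of the new `TaskTemplate` message, runnable once the Python stubs are regenerated from api.proto; `GPU_TYPE_H100` is assumed to exist in the generated GPU type enum, as the test file below suggests:

from modal_proto import api_pb2

tmpl = api_pb2.TaskTemplate(
    priority=1,
    resources=api_pb2.Resources(
        gpu_config=api_pb2.GPUConfig(type=api_pb2.GPU_TYPE_H100, count=2),
    ),
    concurrent_inputs=8,
)
# Protobuf messages round-trip losslessly through their wire encoding.
assert api_pb2.TaskTemplate.FromString(tmpl.SerializeToString()) == tmpl
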
67 changes: 16 additions & 51 deletions test/experimental_gpus_test.py
@@ -1,10 +1,8 @@
 # Copyright Modal Labs 2024
 import modal
-from modal import App, Sandbox
+from modal import App
 from modal_proto import api_pb2
 
-from .sandbox_test import skip_non_linux
-
 app = App()
 
 
@@ -23,7 +21,7 @@ def f3():
     pass
 
 
-def test_experimental_resources(servicer, client):
+def test_experimental_task_templates(servicer, client):
     with app.run(client=client):
         assert len(servicer.app_functions) == 3
 
@@ -54,54 +52,21 @@ def test_experimental_resources(servicer, client):
         )
 
         fn1 = servicer.app_functions["fu-1"] # f1
-        assert len(fn1._experimental_resources) == 1
-        assert fn1._experimental_resources[0].gpu_config.type == a10_1.gpu_config.type
-        assert fn1._experimental_resources[0].gpu_config.count == a10_1.gpu_config.count
+        assert len(fn1._experimental_task_templates) == 1
+        assert fn1._experimental_task_templates[0].resources.gpu_config.type == a10_1.gpu_config.type
+        assert fn1._experimental_task_templates[0].resources.gpu_config.count == a10_1.gpu_config.count
 
         fn2 = servicer.app_functions["fu-2"] # f2
-        assert len(fn2._experimental_resources) == 2
-        assert fn2._experimental_resources[0].gpu_config.type == a10_1.gpu_config.type
-        assert fn2._experimental_resources[0].gpu_config.count == a10_1.gpu_config.count
-        assert fn2._experimental_resources[1].gpu_config.type == t4_2.gpu_config.type
-        assert fn2._experimental_resources[1].gpu_config.count == t4_2.gpu_config.count
+        assert len(fn2._experimental_task_templates) == 2
+        assert fn2._experimental_task_templates[0].resources.gpu_config.type == a10_1.gpu_config.type
+        assert fn2._experimental_task_templates[0].resources.gpu_config.count == a10_1.gpu_config.count
+        assert fn2._experimental_task_templates[1].resources.gpu_config.type == t4_2.gpu_config.type
+        assert fn2._experimental_task_templates[1].resources.gpu_config.count == t4_2.gpu_config.count
 
         fn3 = servicer.app_functions["fu-3"] # f3
-        assert len(fn3._experimental_resources) == 2
-        assert fn3._experimental_resources[0].gpu_config.type == h100_2.gpu_config.type
-        assert fn3._experimental_resources[0].gpu_config.count == h100_2.gpu_config.count
-        assert fn3._experimental_resources[1].gpu_config.type == a100_80gb_2.gpu_config.type
-        assert fn3._experimental_resources[1].gpu_config.count == a100_80gb_2.gpu_config.count
-        assert fn3._experimental_resources[1].gpu_config.memory == a100_80gb_2.gpu_config.memory
-
-
-@skip_non_linux
-def test_sandbox_experimental_resources(client, servicer):
-    Sandbox.create(
-        "bash",
-        "-c",
-        "echo bye >&2 && sleep 1 && echo hi && exit 42",
-        timeout=600,
-        _experimental_gpus=["a10g:2", "t4:4"],
-        client=client,
-    )
-
-    a10_2 = api_pb2.Resources(
-        gpu_config=api_pb2.GPUConfig(
-            type=api_pb2.GPU_TYPE_A10G,
-            count=2,
-        )
-    )
-    t4_4 = api_pb2.Resources(
-        gpu_config=api_pb2.GPUConfig(
-            type=api_pb2.GPU_TYPE_T4,
-            count=4,
-        )
-    )
-
-    assert len(servicer.sandbox_defs) == 1
-    sb_def = servicer.sandbox_defs[0]
-    assert len(sb_def._experimental_resources) == 2
-    assert sb_def._experimental_resources[0].gpu_config.type == a10_2.gpu_config.type
-    assert sb_def._experimental_resources[0].gpu_config.count == a10_2.gpu_config.count
-    assert sb_def._experimental_resources[1].gpu_config.type == t4_4.gpu_config.type
-    assert sb_def._experimental_resources[1].gpu_config.count == t4_4.gpu_config.count
+        assert len(fn3._experimental_task_templates) == 2
+        assert fn3._experimental_task_templates[0].resources.gpu_config.type == h100_2.gpu_config.type
+        assert fn3._experimental_task_templates[0].resources.gpu_config.count == h100_2.gpu_config.count
+        assert fn3._experimental_task_templates[1].resources.gpu_config.type == a100_80gb_2.gpu_config.type
+        assert fn3._experimental_task_templates[1].resources.gpu_config.count == a100_80gb_2.gpu_config.count
+        assert fn3._experimental_task_templates[1].resources.gpu_config.memory == a100_80gb_2.gpu_config.memory
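
Nothing in this commit shows how the backend consumes the templates, but the proto comments (a `priority` field plus "first/most-preferred" list order) suggest selection logic roughly like the sketch below. This is purely illustrative and not Modal's actual scheduler; `has_capacity` is a hypothetical stand-in for the real capacity check:

from typing import Callable, Sequence

from modal_proto import api_pb2

def pick_template(
    templates: Sequence[api_pb2.TaskTemplate],
    has_capacity: Callable[[api_pb2.GPUConfig], bool],  # hypothetical
) -> api_pb2.TaskTemplate:
    # sorted() is stable, so templates with equal priority (all 1 in this
    # commit) keep their original list order, i.e. most-preferred first.
    for tmpl in sorted(templates, key=lambda t: t.priority):
        if has_capacity(tmpl.resources.gpu_config):
            return tmpl
    raise RuntimeError("no task template is currently schedulable")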
