diff --git a/modal/functions.py b/modal/functions.py index 9a00e635b..b6069c255 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -107,7 +107,14 @@ def __init__(self, stub: api_grpc.ModalClientStub, function_call_id: str, client self.function_call_id = function_call_id # TODO: remove and use only input_id @staticmethod - async def create(function: "_Function", args, kwargs, *, client: _Client) -> "_Invocation": + async def create( + function: "_Function", + args, + kwargs, + *, + client: _Client, + function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType", + ) -> "_Invocation": assert client.stub function_id = function._invocation_function_id() item = await _create_input(args, kwargs, client, method_name=function._use_method_name) @@ -117,6 +124,7 @@ async def create(function: "_Function", args, kwargs, *, client: _Client) -> "_I parent_input_id=current_input_id() or "", function_call_type=api_pb2.FUNCTION_CALL_TYPE_UNARY, pipelined_inputs=[item], + function_call_invocation_type=function_call_invocation_type, ) response = await retry_transient_errors(client.stub.FunctionMap, request) function_call_id = response.function_call_id @@ -1185,7 +1193,13 @@ async def _map( yield item async def _call_function(self, args, kwargs) -> R: - invocation = await _Invocation.create(self, args, kwargs, client=self._client) + invocation = await _Invocation.create( + self, + args, + kwargs, + client=self._client, + function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY, + ) try: return await invocation.run_function() except asyncio.CancelledError: @@ -1196,19 +1210,37 @@ async def _call_function(self, args, kwargs) -> R: return # type: ignore async def _call_function_nowait(self, args, kwargs) -> _Invocation: - return await _Invocation.create(self, args, kwargs, client=self._client) + return await _Invocation.create( + self, + args, + kwargs, + client=self._client, + function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY, + ) @warn_if_generator_is_not_consumed() @live_method_gen @synchronizer.no_input_translation async def _call_generator(self, args, kwargs): - invocation = await _Invocation.create(self, args, kwargs, client=self._client) + invocation = await _Invocation.create( + self, + args, + kwargs, + client=self._client, + function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY, + ) async for res in invocation.run_generator(): yield res @synchronizer.no_io_translation async def _call_generator_nowait(self, args, kwargs): - return await _Invocation.create(self, args, kwargs, client=self._client) + return await _Invocation.create( + self, + args, + kwargs, + client=self._client, + function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY, + ) @synchronizer.no_io_translation @live_method diff --git a/modal_proto/api.proto b/modal_proto/api.proto index fef96710e..2cc536bcf 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -118,6 +118,12 @@ enum FileDescriptor { FILE_DESCRIPTOR_INFO = 3; } +enum FunctionCallInvocationType { + FUNCTION_CALL_INVOCATION_TYPE_UNSPECIFIED = 0; + FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY = 1; + FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY = 2; +} + enum FunctionCallType { FUNCTION_CALL_TYPE_UNSPECIFIED = 0; FUNCTION_CALL_TYPE_UNARY = 1; @@ -1259,6 +1265,7 @@ message FunctionMapRequest { bool return_exceptions = 3; FunctionCallType function_call_type = 4; repeated FunctionPutInputsItem pipelined_inputs = 5; + FunctionCallInvocationType function_call_invocation_type = 6; } message FunctionMapResponse {