Skip to content

Commit

Permalink
fix: make gateway load balancer stream results (#6024)
Browse files Browse the repository at this point in the history
Co-authored-by: Jina Dev Bot <[email protected]>
Co-authored-by: Joan Fontanals <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2023
1 parent 36e67f5 commit 0a33cde
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 24 deletions.
29 changes: 22 additions & 7 deletions jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING: # pragma: no cover
from types import SimpleNamespace

import grpc

from jina.logging.logger import JinaLogger
Expand Down Expand Up @@ -185,12 +186,23 @@ async def _load_balance(self, request):
async with aiohttp.ClientSession() as session:
if request.method == 'GET':
async with session.get(target_url) as response:
content = await response.read()
return web.Response(
body=content,
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
content_type=response.content_type,
headers=response.headers,
)

# Prepare the response to send headers
await stream_response.prepare(request)

# Stream the response from the target server to the client
async for chunk in response.content.iter_any():
await stream_response.write(chunk)

# Close the stream response once all chunks are sent
await stream_response.write_eof()
return stream_response

elif request.method == 'POST':
d = await request.read()
import json
Expand Down Expand Up @@ -282,7 +294,7 @@ async def stream(
yield resp

async def stream_doc(
self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
) -> SingleDocumentRequest:
"""
Process the received requests and return the result as a new request
Expand All @@ -293,7 +305,7 @@ async def stream_doc(
"""
self.logger.debug('recv a stream_doc request')
async for result in self.streamer.rpc_stream_doc(
request=request,
request=request,
):
yield result

Expand All @@ -317,6 +329,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
:returns: the response request
"""
from google.protobuf import json_format

self.logger.debug('got an endpoint discovery request')
response = jina_pb2.EndpointsProto()
await self.streamer._get_endpoints_input_output_models(is_cancel=None)
Expand All @@ -332,7 +345,9 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
response.endpoints.extend(schema_maps.keys())
json_format.ParseDict(schema_maps, response.schemas)
else:
endpoints = await self.streamer.topology_graph._get_all_endpoints(self.streamer._connection_pool, retry_forever=True, is_cancel=None)
endpoints = await self.streamer.topology_graph._get_all_endpoints(
self.streamer._connection_pool, retry_forever=True, is_cancel=None
)
response.endpoints.extend(list(endpoints))
return response

Expand Down
66 changes: 54 additions & 12 deletions tests/integration/docarray_v2/test_streaming.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import time
from typing import AsyncGenerator, Generator, Optional

import pytest
from docarray import BaseDoc, DocList

from jina import Client, Executor, requests, Flow, Deployment
from docarray import DocList, BaseDoc
from jina import Client, Deployment, Executor, Flow, requests
from jina.helper import random_port


Expand Down Expand Up @@ -67,20 +69,21 @@ async def test_streaming_deployment(protocol, include_gateway):
assert doc.text == f'hello world {i}'
i += 1


@pytest.mark.asyncio
@pytest.mark.parametrize('protocol', ['http', 'grpc'])
async def test_streaming_flow(protocol):
port = random_port()

with Flow(protocol=protocol, port=port, cors=True).add(
uses=MyExecutor,
uses=MyExecutor,
):
client = Client(port=port, protocol=protocol, asyncio=True)
i = 10
async for doc in client.stream_doc(
on='/hello',
inputs=MyDocument(text='hello world', number=i),
return_type=MyDocument,
on='/hello',
inputs=MyDocument(text='hello world', number=i),
return_type=MyDocument,
):
assert doc.text == f'hello world {i}'
i += 1
Expand Down Expand Up @@ -111,27 +114,66 @@ async def test_streaming_custom_response(protocol, endpoint, include_gateway):
i += 1


class WaitStreamExecutor(Executor):
    """Executor that streams five documents, pausing 0.5s after each yield."""

    @requests(on='/hello')
    async def task(self, doc: MyDocument, **kwargs) -> MyDocument:
        # Emit one document per step; the sleep after each yield lets a
        # client observe that results arrive incrementally, not batched.
        for step in range(5):
            out = MyDocument(text=f'{doc.text} {doc.number + step}')
            yield out
            await asyncio.sleep(0.5)


@pytest.mark.asyncio
@pytest.mark.parametrize('protocol', ['http', 'grpc'])
@pytest.mark.parametrize('include_gateway', [False, True])
async def test_streaming_delay(protocol, include_gateway):
    """Verify streamed documents arrive incrementally rather than buffered.

    WaitStreamExecutor sleeps 0.5s after each of its 5 yields; if the
    gateway streamed correctly, the whole loop must finish within the
    per-item delay budget plus a tolerance interval.
    """
    # Deployment is already imported at module level; the previous
    # function-local `from jina import Deployment` was redundant.
    port = random_port()

    with Deployment(
        uses=WaitStreamExecutor,
        timeout_ready=-1,
        protocol=protocol,
        port=port,
        include_gateway=include_gateway,
    ):
        client = Client(port=port, protocol=protocol, asyncio=True)
        i = 0
        start_time = time.time()
        async for doc in client.stream_doc(
            on='/hello',
            inputs=MyDocument(text='hello world', number=i),
            return_type=MyDocument,
        ):
            assert doc.text == f'hello world {i}'
            i += 1

        # 0.5 seconds between each request + 0.5 seconds tolerance interval
        assert time.time() - start_time < (0.5 * i) + 0.5


@pytest.mark.asyncio
@pytest.mark.parametrize('protocol', ['http', 'grpc'])
@pytest.mark.parametrize('endpoint', ['task1', 'task2', 'task3'])
async def test_streaming_custom_response_flow_one_executor(protocol, endpoint):
port = random_port()

with Flow(
protocol=protocol,
cors=True,
port=port,
protocol=protocol,
cors=True,
port=port,
).add(uses=CustomResponseExecutor):
client = Client(port=port, protocol=protocol, cors=True, asyncio=True)
i = 0
async for doc in client.stream_doc(
on=f'/{endpoint}',
inputs=MyDocument(text='hello world', number=5),
return_type=OutputDocument,
on=f'/{endpoint}',
inputs=MyDocument(text='hello world', number=5),
return_type=OutputDocument,
):
assert doc.text == f'hello world 5-{i}-{endpoint}'
i += 1


class Executor1(Executor):
@requests
def generator(self, doc: MyDocument, **kwargs) -> MyDocument:
Expand Down
46 changes: 41 additions & 5 deletions tests/integration/streaming/test_streaming.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import asyncio
import time

import pytest

from jina import Client, Deployment, Executor, requests
Expand Down Expand Up @@ -29,11 +32,10 @@ async def test_streaming_deployment(protocol, include_gateway):
uses=MyExecutor,
timeout_ready=-1,
protocol=protocol,
cors=True,
port=port,
include_gateway=include_gateway,
):
client = Client(port=port, protocol=protocol, cors=True, asyncio=True)
client = Client(port=port, protocol=protocol, asyncio=True)
i = 0
async for doc in client.stream_doc(
on='/hello', inputs=Document(text='hello world')
Expand All @@ -42,6 +44,42 @@ async def test_streaming_deployment(protocol, include_gateway):
i += 1


class WaitStreamExecutor(Executor):
    """Executor that streams five documents, pausing 0.5s after each yield."""

    @requests(on='/hello')
    async def task(self, doc: Document, **kwargs):
        # Yield one document per iteration, then wait; the delay lets a
        # client verify that results are delivered as they are produced.
        for step in range(5):
            out = Document(text=f'{doc.text} {step}')
            yield out
            await asyncio.sleep(0.5)


@pytest.mark.asyncio
@pytest.mark.parametrize('protocol', ['http', 'grpc'])
@pytest.mark.parametrize('include_gateway', [False, True])
async def test_streaming_delay(protocol, include_gateway):
    """Verify streamed documents arrive incrementally rather than buffered.

    WaitStreamExecutor sleeps 0.5s after each of its 5 yields; if the
    gateway streamed correctly, the whole loop must finish within the
    per-item delay budget plus a tolerance interval.
    """
    # Deployment is already imported at module level; the previous
    # function-local `from jina import Deployment` was redundant.
    port = random_port()

    with Deployment(
        uses=WaitStreamExecutor,
        timeout_ready=-1,
        protocol=protocol,
        port=port,
        include_gateway=include_gateway,
    ):
        client = Client(port=port, protocol=protocol, asyncio=True)
        i = 0
        start_time = time.time()
        async for doc in client.stream_doc(
            on='/hello', inputs=Document(text='hello world')
        ):
            assert doc.text == f'hello world {i}'
            i += 1

        # 0.5 seconds between each request + 0.5 seconds tolerance interval
        assert time.time() - start_time < (0.5 * i) + 0.5


@pytest.mark.asyncio
@pytest.mark.parametrize('protocol', ['grpc'])
async def test_streaming_client_non_gen_endpoint(protocol):
Expand All @@ -53,11 +91,10 @@ async def test_streaming_client_non_gen_endpoint(protocol):
uses=MyExecutor,
timeout_ready=-1,
protocol=protocol,
cors=True,
port=port,
include_gateway=False,
):
client = Client(port=port, protocol=protocol, cors=True, asyncio=True)
client = Client(port=port, protocol=protocol, asyncio=True)
i = 0
with pytest.raises(BadServer):
async for _ in client.stream_doc(
Expand All @@ -67,7 +104,6 @@ async def test_streaming_client_non_gen_endpoint(protocol):


def test_invalid_executor():

with pytest.raises(RuntimeError) as exc_info:

class InvalidExecutor3(Executor):
Expand Down

0 comments on commit 0a33cde

Please sign in to comment.