[Serve] use RequestProtocol to determine the protocol (ray-project#38471)

- Added `RequestProtocol` enum
- Removed the `_is_for_http_requests` boolean flag and replaced it with `_request_protocol` on the serve handle and `RequestMetadata` (see the sketch after this list)
- Renamed `proxy_name` on the proxy to `protocol`
- Fixed a small bug where metrics were not logged properly when chaining handles (by passing the metrics counter through the handle chain)
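
The sketch below is a minimal, self-contained illustration of the pattern this change introduces. It uses simplified stand-in classes (not the actual Ray Serve handle, options, or metadata classes) to show how a single `RequestProtocol` enum field replaces the old `_is_for_http_requests` boolean, with `is_http_request`/`is_grpc_request` derived as properties.

from dataclasses import dataclass, asdict
from enum import Enum
from typing import Optional


class RequestProtocol(str, Enum):
    UNDEFINED = "UNDEFINED"
    HTTP = "HTTP"
    GRPC = "gRPC"


@dataclass
class RequestMetadata:
    request_id: str
    endpoint: str
    # One protocol field replaces the old `is_http_request` boolean flag.
    _request_protocol: RequestProtocol = RequestProtocol.UNDEFINED

    @property
    def is_http_request(self) -> bool:
        return self._request_protocol == RequestProtocol.HTTP

    @property
    def is_grpc_request(self) -> bool:
        return self._request_protocol == RequestProtocol.GRPC


@dataclass
class HandleOptions:
    method_name: str = "__call__"
    _request_protocol: RequestProtocol = RequestProtocol.UNDEFINED


class Handle:
    def __init__(self, options: Optional[HandleOptions] = None):
        self.handle_options = options or HandleOptions()

    def _set_request_protocol(self, protocol: RequestProtocol) -> None:
        # Rebuild the options with the new protocol, mirroring how the proxy
        # stamps its protocol onto the handle it uses for ingress requests.
        self.handle_options = HandleOptions(
            **{**asdict(self.handle_options), "_request_protocol": protocol}
        )


if __name__ == "__main__":
    handle = Handle()
    assert handle.handle_options._request_protocol is RequestProtocol.UNDEFINED

    # An HTTP proxy would call this once when it creates the handle.
    handle._set_request_protocol(RequestProtocol.HTTP)
    metadata = RequestMetadata(
        "req-1",
        "my_deployment",
        _request_protocol=handle.handle_options._request_protocol,
    )
    assert metadata.is_http_request and not metadata.is_grpc_request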

Signed-off-by: e428265 <[email protected]>
GeneDer authored and arvind-chandra committed Aug 31, 2023
1 parent b8f0b70 commit de409c6
Showing 7 changed files with 133 additions and 42 deletions.
15 changes: 2 additions & 13 deletions python/ray/serve/_private/client.py
@@ -478,7 +478,6 @@ def get_handle(
app_name: Optional[str] = "default",
missing_ok: Optional[bool] = False,
sync: bool = True,
_is_for_http_requests: bool = False,
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service deployment to invoke it from Python.
@@ -489,8 +488,6 @@ def get_handle(
sync: If true, then Serve will return a ServeHandle that
works everywhere. Otherwise, Serve will return a ServeHandle
that's only usable in asyncio loop.
_is_for_http_requests: Indicates that this handle will be used
to send HTTP requests from the proxy to ingress deployment replicas.
Returns:
RayServeHandle
@@ -510,17 +507,9 @@ def get_handle(
)

if sync:
handle = RayServeSyncHandle(
deployment_name,
app_name,
_is_for_http_requests=_is_for_http_requests,
)
handle = RayServeSyncHandle(deployment_name, app_name)
else:
handle = RayServeHandle(
deployment_name,
app_name,
_is_for_http_requests=_is_for_http_requests,
)
handle = RayServeHandle(deployment_name, app_name)

self.handle_cache[cache_key] = handle
if cache_key in self._evicted_handle_keys:
6 changes: 6 additions & 0 deletions python/ray/serve/_private/common.py
@@ -452,3 +452,9 @@ class StreamingHTTPRequest:

pickled_asgi_scope: bytes
http_proxy_handle: ActorHandle


class RequestProtocol(str, Enum):
UNDEFINED = "UNDEFINED"
HTTP = "HTTP"
GRPC = "gRPC"
31 changes: 16 additions & 15 deletions python/ray/serve/_private/http_proxy.py
@@ -39,6 +39,7 @@
EndpointInfo,
EndpointTag,
NodeId,
RequestProtocol,
StreamingHTTPRequest,
)
from ray.serve._private.constants import (
@@ -212,7 +213,7 @@ class GenericProxy(ABC):
It contains all the common setup and methods required for running a proxy.
The proxy subclass needs to implement the following methods:
- `proxy_name()`
- `protocol()`
- `not_found()`
- `draining_response()`
- `timeout_response()`
@@ -260,7 +261,6 @@ def get_handle(deployment_name, app_name):
app_name,
sync=False,
missing_ok=True,
_is_for_http_requests=True,
)

self.prefix_router = LongestPrefixRouter(get_handle)
@@ -272,14 +272,14 @@ def get_handle(deployment_name, app_name):
call_in_event_loop=get_or_create_event_loop(),
)
self.request_counter = metrics.Counter(
f"serve_num_{self.proxy_name.lower()}_requests",
description=f"The number of {self.proxy_name} requests processed.",
f"serve_num_{self.protocol.lower()}_requests",
description=f"The number of {self.protocol} requests processed.",
tag_keys=("route", "method", "application", "status_code"),
)

self.request_error_counter = metrics.Counter(
f"serve_num_{self.proxy_name.lower()}_error_requests",
description=f"The number of non-200 {self.proxy_name} responses.",
f"serve_num_{self.protocol.lower()}_error_requests",
description=f"The number of non-200 {self.protocol} responses.",
tag_keys=(
"route",
"error_code",
@@ -288,9 +288,9 @@ def get_handle(deployment_name, app_name):
)

self.deployment_request_error_counter = metrics.Counter(
f"serve_num_deployment_{self.proxy_name.lower()}_error_requests",
f"serve_num_deployment_{self.protocol.lower()}_error_requests",
description=(
f"The number of non-200 {self.proxy_name} responses returned by "
f"The number of non-200 {self.protocol} responses returned by "
"each deployment."
),
tag_keys=(
@@ -303,10 +303,10 @@ def get_handle(deployment_name, app_name):
)

self.processing_latency_tracker = metrics.Histogram(
f"serve_{self.proxy_name.lower()}_request_latency_ms",
f"serve_{self.protocol.lower()}_request_latency_ms",
description=(
f"The end-to-end latency of {self.proxy_name} requests "
f"(measured from the Serve {self.proxy_name} proxy)."
f"The end-to-end latency of {self.protocol} requests "
f"(measured from the Serve {self.protocol} proxy)."
),
boundaries=DEFAULT_LATENCY_BUCKET_MS,
tag_keys=(
@@ -338,8 +338,8 @@ def get_handle(deployment_name, app_name):

@property
@abstractmethod
def proxy_name(self) -> str:
"""Proxy name used for metrics.
def protocol(self) -> RequestProtocol:
"""Protocol used for metrics.
Each proxy needs to implement its own logic for setting up the proxy name.
"""
@@ -667,8 +667,8 @@ class HTTPProxy(GenericProxy):
"""

@property
def proxy_name(self) -> str:
return "HTTP"
def protocol(self) -> RequestProtocol:
return RequestProtocol.HTTP

async def not_found(self, scope, receive, send):
current_path = scope["path"]
@@ -911,6 +911,7 @@ def setup_request_context_and_handle(
Unpack HTTP request headers and extract info to set up request context and
handle.
"""
handle._set_request_protocol(RequestProtocol.HTTP)
request_context_info = {
"route": route_path,
"app_name": app_name,
21 changes: 17 additions & 4 deletions python/ray/serve/_private/router.py
@@ -31,7 +31,12 @@
from ray.util import metrics
from ray._private.utils import make_asyncio_event_version_compat, load_class

from ray.serve._private.common import RunningReplicaInfo, DeploymentInfo, DeploymentID
from ray.serve._private.common import (
DeploymentID,
DeploymentInfo,
RequestProtocol,
RunningReplicaInfo,
)
from ray.serve._private.constants import (
SERVE_LOGGER_NAME,
HANDLE_METRIC_PUSH_INTERVAL_S,
@@ -61,9 +66,6 @@ class RequestMetadata:
endpoint: str
call_method: str = "__call__"

# This flag is set if the request is made from the HTTP proxy to a replica.
is_http_request: bool = False

# HTTP route path of the request.
route: str = ""

@@ -76,6 +78,17 @@
# If this request expects a streaming response.
is_streaming: bool = False

# The protocol to serve this request
_request_protocol: RequestProtocol = RequestProtocol.UNDEFINED

@property
def is_http_request(self) -> bool:
return self._request_protocol == RequestProtocol.HTTP

@property
def is_grpc_request(self) -> bool:
return self._request_protocol == RequestProtocol.GRPC


@dataclass
class Query:
22 changes: 13 additions & 9 deletions python/ray/serve/handle.py
@@ -1,6 +1,6 @@
import asyncio
import concurrent.futures
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from functools import wraps
import inspect
import threading
@@ -10,7 +10,7 @@
from ray._private.utils import get_or_create_event_loop

from ray import serve
from ray.serve._private.common import EndpointTag
from ray.serve._private.common import EndpointTag, RequestProtocol
from ray.serve._private.constants import (
RAY_SERVE_ENABLE_NEW_ROUTING,
)
@@ -62,6 +62,7 @@ class _HandleOptions:
multiplexed_model_id: str = ""
stream: bool = False
_router_cls: str = ""
_request_protocol: str = RequestProtocol.UNDEFINED

def copy_and_update(
self,
@@ -83,6 +84,7 @@ def copy_and_update(
_router_cls=self._router_cls
if _router_cls == DEFAULT.VALUE
else _router_cls,
_request_protocol=self._request_protocol,
)


@@ -134,13 +136,12 @@ def __init__(
*,
handle_options: Optional[_HandleOptions] = None,
_router: Optional[Router] = None,
_is_for_http_requests: bool = False,
_request_counter: Optional[metrics.Counter] = None,
):
self.deployment_id = EndpointTag(deployment_name, app_name)
self.handle_options = handle_options or _HandleOptions()
self._is_for_http_requests = _is_for_http_requests

self.request_counter = metrics.Counter(
self.request_counter = _request_counter or metrics.Counter(
"serve_handle_request_counter",
description=(
"The number of handle.remote() calls that have been "
@@ -160,6 +161,11 @@ def __init__(

self._router: Optional[Router] = _router

def _set_request_protocol(self, request_protocol: RequestProtocol):
self.handle_options = _HandleOptions(
**{**asdict(self.handle_options), **{"_request_protocol": request_protocol}}
)

def _get_or_create_router(self) -> Router:
if self._router is None:
self._router = Router(
@@ -211,7 +217,7 @@ def _options(
self.app_name,
handle_options=new_handle_options,
_router=None if _router_cls != DEFAULT.VALUE else self._router,
_is_for_http_requests=self._is_for_http_requests,
_request_counter=self.request_counter,
)

def options(
@@ -247,11 +253,11 @@ def _remote(self, deployment_id, handle_options, args, kwargs) -> Coroutine:
_request_context.request_id,
deployment_id.name,
call_method=handle_options.method_name,
is_http_request=self._is_for_http_requests,
route=_request_context.route,
app_name=_request_context.app_name,
multiplexed_model_id=handle_options.multiplexed_model_id,
is_streaming=handle_options.stream,
_request_protocol=handle_options._request_protocol,
)
self.request_counter.inc(
tags={
@@ -295,7 +301,6 @@ def __reduce__(self):
"deployment_name": self.deployment_name,
"app_name": self.app_name,
"handle_options": self.handle_options,
"_is_for_http_requests": self._is_for_http_requests,
}
return RayServeHandle._deserialize, (serialized_data,)

@@ -401,7 +406,6 @@ def __reduce__(self):
"deployment_name": self.deployment_name,
"app_name": self.app_name,
"handle_options": self.handle_options,
"_is_for_http_requests": self._is_for_http_requests,
}
return RayServeSyncHandle._deserialize, (serialized_data,)

40 changes: 40 additions & 0 deletions python/ray/serve/tests/test_handle.py
@@ -15,13 +15,15 @@
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
SERVE_DEFAULT_APP_NAME,
)
from ray.serve._private.common import RequestProtocol


def test_handle_options():
default_options = _HandleOptions()
assert default_options.method_name == "__call__"
assert default_options.multiplexed_model_id == ""
assert default_options.stream is False
assert default_options._request_protocol == RequestProtocol.UNDEFINED

# Test setting method name.
only_set_method = default_options.copy_and_update(method_name="hi")
@@ -33,6 +35,7 @@ def test_handle_options():
assert default_options.method_name == "__call__"
assert default_options.multiplexed_model_id == ""
assert default_options.stream is False
assert default_options._request_protocol == RequestProtocol.UNDEFINED

# Test setting model ID.
only_set_model_id = default_options.copy_and_update(multiplexed_model_id="hi")
@@ -44,6 +47,7 @@ def test_handle_options():
assert default_options.method_name == "__call__"
assert default_options.multiplexed_model_id == ""
assert default_options.stream is False
assert default_options._request_protocol == RequestProtocol.UNDEFINED

# Test setting stream.
only_set_stream = default_options.copy_and_update(stream=True)
@@ -55,12 +59,14 @@ def test_handle_options():
assert default_options.method_name == "__call__"
assert default_options.multiplexed_model_id == ""
assert default_options.stream is False
assert default_options._request_protocol == RequestProtocol.UNDEFINED

# Test setting multiple.
set_multiple = default_options.copy_and_update(method_name="hi", stream=True)
assert set_multiple.method_name == "hi"
assert set_multiple.multiplexed_model_id == ""
assert set_multiple.stream is True
assert default_options._request_protocol == RequestProtocol.UNDEFINED


@pytest.mark.asyncio
@@ -205,13 +211,16 @@ def __call__(self):
return "__call__"

handle1 = serve.run(MultiMethod.bind())
metrics = handle1.request_counter
assert ray.get(handle1.remote()) == "__call__"

handle2 = handle1.options(method_name="method_a")
assert ray.get(handle2.remote()) == "method_a"
assert handle2.request_counter == metrics

handle3 = handle1.options(method_name="method_b")
assert ray.get(handle3.remote()) == "method_b"
assert handle3.request_counter == metrics


def test_repeated_get_handle_cached(serve_instance):
Expand Down Expand Up @@ -382,6 +391,37 @@ def echo(name: str):
), handle2._router._replica_scheduler


def test_set_request_protocol(serve_instance):
"""Test setting request protocol for a handle.
When a handle is created, its _request_protocol is undefined. When calling
`_set_request_protocol()`, _request_protocol is set to the specified protocol.
When chaining options, the _request_protocol on the new handle is copied over.
When calling `_set_request_protocol()` on the new handle, _request_protocol
on the new handle is changed accordingly, while _request_protocol on the
original handle remains unchanged.
"""

@serve.deployment
def echo(name: str):
return f"Hi {name}"

handle = serve.run(echo.bind())
assert handle.handle_options._request_protocol == RequestProtocol.UNDEFINED

handle._set_request_protocol(RequestProtocol.HTTP)
assert handle.handle_options._request_protocol == RequestProtocol.HTTP

multiplexed_model_id = "fake-multiplexed_model_id"
new_handle = handle.options(multiplexed_model_id=multiplexed_model_id)
assert new_handle.handle_options.multiplexed_model_id == multiplexed_model_id
assert new_handle.handle_options._request_protocol == RequestProtocol.HTTP

new_handle._set_request_protocol(RequestProtocol.GRPC)
assert new_handle.handle_options._request_protocol == RequestProtocol.GRPC
assert handle.handle_options._request_protocol == RequestProtocol.HTTP


if __name__ == "__main__":
import sys
import pytest