Skip to content

Commit

Permalink
Changing restricted_apis/protocols to restricted_features
Browse files Browse the repository at this point in the history
  • Loading branch information
KrishnanPrash committed Nov 12, 2024
1 parent 3703c04 commit 770dd79
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 48 deletions.
46 changes: 10 additions & 36 deletions qa/L0_python_api/test_kserve.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,51 +419,25 @@ def test_metrics_update(self, frontend, url):
utils.teardown_service(metrics_service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [HTTP_ARGS])
def test_http_restricted_features(self, frontend, client_type, url):
server = utils.setup_server()
# Specifying Restricted Features
http_rf_options = KServeHttp.Options(restricted_apis=RESTRICTED_FEATURE_ARG)
service = utils.setup_service(server, frontend, options=http_rf_options)

# Valid headers sent with inference request
headers = {"infer-key": "infer-value"}
assert utils.send_and_test_inference_identity(client_type, url, headers)

# Invalid headers sent with inference request
headers = {"fake-key": "fake-value"}
with pytest.raises(
InferenceServerException,
match=re.escape(
"[403] This API is restricted, expecting header 'infer-key'"
),
):
utils.send_and_test_inference_identity(client_type, url, headers)

utils.teardown_service(service)
utils.teardown_server(server)

@pytest.mark.parametrize("frontend, client_type, url", [GRPC_ARGS])
def test_grpc_restricted_features(self, frontend, client_type, url):
@pytest.mark.parametrize(
"frontend, client_type, url, key_prefix",
[HTTP_ARGS + ("",), GRPC_ARGS + ("triton-grpc-protocol-",)],
)
def test_restricted_features(self, frontend, client_type, url, key_prefix):
server = utils.setup_server()
# Specifying Restricted Features
grpc_rf_options = KServeGrpc.Options(
restricted_protocols=RESTRICTED_FEATURE_ARG
)
service = utils.setup_service(server, frontend, options=grpc_rf_options)
options = frontend.Options(restricted_features=RESTRICTED_FEATURE_ARG)
service = utils.setup_service(server, frontend, options=options)

# Valid headers sent with inference request
grpc_key_prefix = "triton-grpc-protocol-"
headers = {grpc_key_prefix + "infer-key": "infer-value"}
headers = {key_prefix + "infer-key": "infer-value"}
assert utils.send_and_test_inference_identity(client_type, url, headers)

# Invalid headers sent with inference request
headers = {grpc_key_prefix + "fake-key": "fake-value"}
headers = {key_prefix + "fake-key": "fake-value"}
with pytest.raises(
InferenceServerException,
match=re.escape(
"[StatusCode.UNAVAILABLE] This protocol is restricted, expecting header 'triton-grpc-protocol-infer-key'"
),
match=f"expecting header '{key_prefix}infer-key'",
):
utils.send_and_test_inference_identity(client_type, url, headers)

Expand Down
8 changes: 4 additions & 4 deletions src/python/tritonfrontend/_api/_kservegrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,18 @@ class Options:
] = Grpc_compression_level.NONE
infer_allocation_pool_size: int = Field(8, ge=0)
forward_header_pattern: str = ""
restricted_protocols: RestrictedFeatures = RestrictedFeatures()
restricted_features: RestrictedFeatures = RestrictedFeatures()

@handle_triton_error
def __post_init__(self):
if isinstance(self.infer_compression_level, Grpc_compression_level):
self.infer_compression_level = self.infer_compression_level.value

if not isinstance(self.restricted_protocols, RestrictedFeatures):
if not isinstance(self.restricted_features, RestrictedFeatures):
raise InvalidArgumentError(
"restricted_protocols needs an instance of RestrictedFeatures."
"restricted_features needs an instance of RestrictedFeatures."
)
self.restricted_protocols = repr(self.restricted_protocols)
self.restricted_features = repr(self.restricted_features)

@handle_triton_error
def __init__(self, server: tritonserver, options: "KServeGrpc.Options" = None):
Expand Down
8 changes: 4 additions & 4 deletions src/python/tritonfrontend/_api/_kservehttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ class Options:
reuse_port: bool = False
thread_count: int = Field(8, gt=0)
header_forward_pattern: str = ""
restricted_apis: RestrictedFeatures = RestrictedFeatures()
restricted_features: RestrictedFeatures = RestrictedFeatures()

@handle_triton_error
def __post_init__(self):
if not isinstance(self.restricted_apis, RestrictedFeatures):
if not isinstance(self.restricted_features, RestrictedFeatures):
raise InvalidArgumentError(
"restricted_apis needs an instance of RestrictedFeatures."
"restricted_features needs an instance of RestrictedFeatures."
)
self.restricted_apis = repr(self.restricted_apis)
self.restricted_features = repr(self.restricted_features)

@handle_triton_error
def __init__(self, server: tritonserver, options: "KServeHttp.Options" = None):
Expand Down
7 changes: 3 additions & 4 deletions src/python/tritonfrontend/_c/tritonfrontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,12 @@ class TritonFrontend {
static void _populate_restricted_features(
UnorderedMapType& data, RestrictedFeatures& rest_features)
{
std::string map_key; // Name of option in UnorderedMap
std::string key_prefix; // Prefix for header key
std::string map_key =
"restricted_features"; // Name of option in UnorderedMap
std::string key_prefix; // Prefix for header key
if (std::is_same_v<FrontendServer, triton::server::HTTPAPIServer>) {
map_key = "restricted_apis";
key_prefix = "";
} else if (std::is_same_v<FrontendServer, triton::server::grpc::Server>) {
map_key = "restricted_protocols";
key_prefix = "triton-grpc-protocol-";
} else {
// Restricted Features is not supported for this class.
Expand Down

0 comments on commit 770dd79

Please sign in to comment.