Add sample_optional, and fix pytest for parameters, optional
kimdwkimdw committed Aug 31, 2023
1 parent f071bc5 commit d82974e
Showing 9 changed files with 151 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit_pytest.yml
@@ -37,7 +37,7 @@ jobs:
    runs-on: ubuntu-latest
    needs: pre-commit
    container:
-      image: nvcr.io/nvidia/tritonserver:23.03-pyt-python-py3
+      image: nvcr.io/nvidia/tritonserver:23.08-pyt-python-py3
      options:
        --shm-size=1g
    steps:
1 change: 1 addition & 0 deletions README.md
@@ -55,6 +55,7 @@ if __name__ == "__main__":

## Release Notes

- 23.08.30 Support `optional` model inputs and `parameters` in config.pbtxt
- 23.06.16 Support tritonclient>=2.34.0
- Loosely modified the requirements related to tritonclient

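For context, a minimal client-side sketch of the two features this release note refers to, mirroring tests/test_model_call.py further down; it assumes a server started with bin/run_triton_tritony_sample.sh, so gRPC is exposed on port 8101.

```python
import numpy as np

from tritony import InferenceClient

# gRPC port 8101 comes from the port mapping in bin/run_triton_tritony_sample.sh
client = InferenceClient.create_with("sample_optional", "localhost:8101", protocol="grpc")

sample = np.random.rand(1, 100).astype(np.float32)

# Override the "add" parameter declared in config.pbtxt on a per-request basis
result = client({"model_in": sample}, parameters={"add": "1"})

# The optional input is simply passed (or omitted) by name
result = client({"model_in": sample, "optional_model_sub": np.ones_like(sample)})
```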
18 changes: 18 additions & 0 deletions bin/run_triton_tritony_sample.sh
@@ -0,0 +1,18 @@
#!/bin/bash

HERE=$(dirname "$(readlink -f "$0")")
PARENT_DIR=$(dirname "$HERE")

docker run -it --rm --name triton_tritony \
-p8100:8000 \
-p8101:8001 \
-p8102:8002 \
-v "${PARENT_DIR}"/model_repository:/models:ro \
-e OMP_NUM_THREADS=2 \
-e OPENBLAS_NUM_THREADS=2 \
--shm-size=1g \
nvcr.io/nvidia/tritonserver:23.08-pyt-python-py3 \
tritonserver --model-repository=/models \
--exit-timeout-secs 15 \
--min-supported-compute-capability 7.0 \
--log-verbose 0 # 0-nothing, 1-info, 2-debug, 3-trace
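A quick way to confirm the container is actually serving; this is a sketch that assumes `tritonclient[http]` is installed on the host and uses the 8100→8000 HTTP mapping from the command above.

```python
import tritonclient.http as httpclient

# HTTP port 8100 comes from the -p8100:8000 mapping above
client = httpclient.InferenceServerClient(url="localhost:8100")

print("server ready:", client.is_server_ready())
print("sample_optional ready:", client.is_model_ready("sample_optional"))
```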
4 changes: 1 addition & 3 deletions model_repository/sample/1/model.py
@@ -13,13 +13,11 @@ def initialize(self, args):
            pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs
        ]

-        parameters = self.model_config["parameters"]

    def execute(self, requests):
        responses = [None for _ in requests]
        for idx, request in enumerate(requests):
            current_add_value = int(json.loads(request.parameters()).get("add", 0))
-            in_tensor = [item.as_numpy() + current_add_value for item in request.inputs()]
+            in_tensor = [item.as_numpy() + current_add_value for item in request.inputs() if item.name() == "model_in"]
            out_tensor = [
                pb_utils.Tensor(output_name, x.astype(output_dtype))
                for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list)
7 changes: 0 additions & 7 deletions model_repository/sample/config.pbtxt
@@ -2,13 +2,6 @@ name: "sample"
backend: "python"
max_batch_size: 0

parameters [
{
key: "add",
value: { string_value: "0" }
}
]
input [
{
name: "model_in"
37 changes: 37 additions & 0 deletions model_repository/sample_optional/1/model.py
@@ -0,0 +1,37 @@
import json

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.model_config = model_config = json.loads(args["model_config"])
        output_configs = model_config["output"]

        self.output_name_list = [output_config["name"] for output_config in output_configs]
        self.output_dtype_list = [
            pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs
        ]

    def execute(self, requests):
        responses = [None for _ in requests]
        for idx, request in enumerate(requests):
            current_add_value = int(json.loads(request.parameters()).get("add", 0))
            optional_in_tensor = pb_utils.get_input_tensor_by_name(request, "optional_model_sub")
            if optional_in_tensor:
                optional_in_tensor = optional_in_tensor.as_numpy()
            else:
                optional_in_tensor = 0

            in_tensor = [
                item.as_numpy() + current_add_value - optional_in_tensor
                for item in request.inputs()
                if item.name() == "model_in"
            ]
            out_tensor = [
                pb_utils.Tensor(output_name, x.astype(output_dtype))
                for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list)
            ]
            inference_response = pb_utils.InferenceResponse(output_tensors=out_tensor)
            responses[idx] = inference_response
        return responses
54 changes: 54 additions & 0 deletions model_repository/sample_optional/config.pbtxt
@@ -0,0 +1,54 @@
name: "sample_optional"
backend: "python"
max_batch_size: 0

parameters [
{
key: "add",
value: { string_value: "0" }
}
]
input [
{
name: "model_in"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "optional_model_sub"
data_type: TYPE_FP32
optional: true
dims: [ -1 ]
}
]
output [
{
name: "model_out"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [{ kind: KIND_CPU, count: 1 }]
model_warmup {
name: "RandomSampleInput"
batch_size: 1
inputs [{
key: "model_in"
value: {
data_type: TYPE_FP32
dims: [ 10 ]
random_data: true
}
}, {
key: "model_in"
    value: {
      data_type: TYPE_FP32
      dims: [ 10 ]
      zero_data: true
    }
  }]
}
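To sanity-check that Triton registered `optional_model_sub` as optional, the loaded config can be fetched back from the server; a sketch assuming `tritonclient[http]` is installed and the 8100 HTTP port from the run script above.

```python
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8100")

# get_model_config returns the parsed config as a dict; "optional" may be omitted when false
config = client.get_model_config("sample_optional")
print([(inp["name"], inp.get("optional", False)) for inp in config["input"]])
```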
43 changes: 34 additions & 9 deletions tests/test_model_call.py
@@ -5,47 +5,72 @@

from tritony import InferenceClient

-MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000")
TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001")


EPSILON = 1e-8


@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
def protocol_and_port(request):
    return request.param


-def get_client(protocol, port):
+def get_client(protocol, port, model_name):
    print(f"Testing {protocol}", flush=True)
-    return InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+    return InferenceClient.create_with(model_name, f"{TRITON_HOST}:{port}", protocol=protocol)


def test_swithcing(protocol_and_port):
-    client = get_client(*protocol_and_port)
+    client = get_client(*protocol_and_port, model_name="sample")

    sample = np.random.rand(1, 100).astype(np.float32)
    result = client(sample)
    assert {np.isclose(result, sample).all()}

    sample_batched = np.random.rand(100, 100).astype(np.float32)
    client(sample_batched, model_name="sample_autobatching")
-    assert {np.isclose(result, sample).all()}
+    assert np.isclose(result, sample).all()


def test_with_input_name(protocol_and_port):
-    client = get_client(*protocol_and_port)
+    client = get_client(*protocol_and_port, model_name="sample")

    sample = np.random.rand(100, 100).astype(np.float32)
    result = client({client.default_model_spec.model_input[0].name: sample})
-    assert {np.isclose(result, sample).all()}
+    assert np.isclose(result, sample).all()


def test_with_parameters(protocol_and_port):
-    client = get_client(*protocol_and_port)
+    client = get_client(*protocol_and_port, model_name="sample")

    sample = np.random.rand(1, 100).astype(np.float32)
    ADD_VALUE = 1
    result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": f"{ADD_VALUE}"})

-    assert {np.isclose(result[0], sample[0] + ADD_VALUE).all()}
+    assert np.isclose(result[0], sample[0] + ADD_VALUE).all()


def test_with_optional(protocol_and_port):
    client = get_client(*protocol_and_port, model_name="sample_optional")

    sample = np.random.rand(1, 100).astype(np.float32)

    result = client({client.default_model_spec.model_input[0].name: sample})
    assert np.isclose(result[0], sample[0], rtol=EPSILON).all()

    OPTIONAL_SUB_VALUE = np.zeros_like(sample) + 3
    result = client(
        {
            client.default_model_spec.model_input[0].name: sample,
            "optional_model_sub": OPTIONAL_SUB_VALUE,
        }
    )
    assert np.isclose(result[0], sample[0] - OPTIONAL_SUB_VALUE, rtol=EPSILON).all()


if __name__ == "__main__":
    test_with_parameters(("grpc", "8101"))
    test_with_optional(("grpc", "8101"))
6 changes: 5 additions & 1 deletion tritony/tools.py
@@ -115,7 +115,11 @@ async def request_async(protocol: TritonProtocol, model_input: Dict, triton_clie
    loop = asyncio.get_running_loop()

    if "parameters" in grpc_get_inference_request.__code__.co_varnames:
-        model_input["parameters"] = None
+        # check tritonclient[all]>=2.34.0, NGC 23.04
+        model_input["parameters"] = model_input.get("parameters", None)
+    else:
+        logger.warning("tritonclient[all]<2.34.0, NGC 21.04")
+        model_input.pop("parameters")
    request = grpc_get_inference_request(
        **model_input,
        priority=0,
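The branch above detects `parameters` support by introspecting the wrapped tritonclient helper rather than comparing version strings. A standalone sketch of that pattern follows, with illustrative function names that are not part of tritony or tritonclient:

```python
# Feature detection via __code__.co_varnames, mirroring the check above.
# co_varnames starts with the callable's argument names, so membership is a
# cheap proxy for "does this build accept a `parameters` keyword?".
def old_get_inference_request(inputs, outputs):  # stand-in for tritonclient < 2.34.0
    ...


def new_get_inference_request(inputs, outputs, parameters=None):  # stand-in for >= 2.34.0
    ...


def supports_parameters(func) -> bool:
    return "parameters" in func.__code__.co_varnames


assert not supports_parameters(old_get_inference_request)
assert supports_parameters(new_get_inference_request)
```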
