Support parameters on config.pbtxt
kimdwkimdw committed Aug 31, 2023
1 parent 0d343d7 commit f071bc5
Showing 7 changed files with 75 additions and 20 deletions.
5 changes: 4 additions & 1 deletion model_repository/sample/1/model.py
@@ -13,10 +13,13 @@ def initialize(self, args):
             pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs
         ]
 
+        parameters = self.model_config["parameters"]
+
     def execute(self, requests):
         responses = [None for _ in requests]
         for idx, request in enumerate(requests):
-            in_tensor = [item.as_numpy() for item in request.inputs()]
+            current_add_value = int(json.loads(request.parameters()).get("add", 0))
+            in_tensor = [item.as_numpy() + current_add_value for item in request.inputs()]
             out_tensor = [
                 pb_utils.Tensor(output_name, x.astype(output_dtype))
                 for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list)
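The request-level handling added above is self-contained; a standalone sketch of the same logic (not part of this commit, assuming the standard Triton Python backend behavior where request.parameters() returns a JSON string) is:

import json


def add_value_from(raw_parameters: str) -> int:
    # raw_parameters is what request.parameters() returns, e.g. '{"add": "3"}' or '{}'
    return int(json.loads(raw_parameters or "{}").get("add", 0))


assert add_value_from('{"add": "3"}') == 3  # per-request override
assert add_value_from("{}") == 0            # no parameter sent: fall back to 0

The parameters dict read in initialize() holds the config.pbtxt defaults (see the sketch after the config.pbtxt diff below) and could serve as the fallback instead of the hard-coded 0.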
8 changes: 8 additions & 0 deletions model_repository/sample/config.pbtxt
@@ -1,6 +1,14 @@
 name: "sample"
 backend: "python"
 max_batch_size: 0
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
 input [
   {
     name: "model_in"
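For reference, a parameters block like the one above surfaces in the Python backend as a nested dict keyed by string_value; a rough, self-contained sketch (not from the repository) of reading the default:

import json

# initialize(self, args) receives the model config as JSON; a trimmed example:
model_config = json.loads('{"name": "sample", "parameters": {"add": {"string_value": "0"}}}')

default_add = int(model_config["parameters"]["add"]["string_value"])
assert default_add == 0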
8 changes: 8 additions & 0 deletions model_repository/sample_autobatching/config.pbtxt
@@ -1,6 +1,14 @@
 name: "sample_autobatching"
 backend: "python"
 max_batch_size: 2
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
 input [
   {
     name: "model_in"
8 changes: 8 additions & 0 deletions model_repository/sample_multiple/config.pbtxt
@@ -1,6 +1,14 @@
 name: "sample_multiple"
 backend: "python"
 max_batch_size: 2
+
+parameters [
+  {
+    key: "add",
+    value: { string_value: "0" }
+  }
+]
 input [
   {
     name: "model_in0"
30 changes: 19 additions & 11 deletions tests/test_model_call.py
@@ -16,28 +16,36 @@ def protocol_and_port(request):
     return request.param
 
 
-def test_swithcing(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
+def get_client(protocol, port):
+    print(f"Testing {protocol}", flush=True)
+    return InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+
 
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+def test_swithcing(protocol_and_port):
+    client = get_client(*protocol_and_port)
 
     sample = np.random.rand(1, 100).astype(np.float32)
     result = client(sample)
     print(f"Result: {np.isclose(result, sample).all()}")
     assert np.isclose(result, sample).all()
 
     sample_batched = np.random.rand(100, 100).astype(np.float32)
     result_batched = client(sample_batched, model_name="sample_autobatching")
     print(f"Result: {np.isclose(result_batched, sample_batched).all()}")
     assert np.isclose(result_batched, sample_batched).all()
 
 
 def test_with_input_name(protocol_and_port):
-    protocol, port = protocol_and_port
-    print(f"Testing {protocol}")
-
-    client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
+    client = get_client(*protocol_and_port)
 
     sample = np.random.rand(100, 100).astype(np.float32)
     result = client({client.default_model_spec.model_input[0].name: sample})
     assert np.isclose(result, sample).all()
+
+
+def test_with_parameters(protocol_and_port):
+    client = get_client(*protocol_and_port)
+
+    sample = np.random.rand(1, 100).astype(np.float32)
+    ADD_VALUE = 1
+    result = client({client.default_model_spec.model_input[0].name: sample}, parameters={"add": f"{ADD_VALUE}"})
+
+    print(f"Result: {np.isclose(result[0], sample[0] + ADD_VALUE).all()}")
+    assert np.isclose(result[0], sample[0] + ADD_VALUE).all()
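Outside pytest, the same round trip might look like the sketch below; the host, port, and protocol string are placeholders, and it assumes a running Triton server that serves the sample model and that InferenceClient is importable from the tritony package.

import numpy as np

from tritony import InferenceClient

# Placeholder endpoint: adjust host/port/protocol to the running server.
client = InferenceClient.create_with("sample", "localhost:8001", protocol="grpc")

x = np.random.rand(1, 100).astype(np.float32)
# Parameter values are strings, matching the string_value default in config.pbtxt.
y = client({client.default_model_spec.model_input[0].name: x}, parameters={"add": "3"})
assert np.isclose(y[0], x[0] + 3).all()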
34 changes: 27 additions & 7 deletions tritony/tools.py
@@ -72,6 +72,7 @@ async def send_request_async(
     done_event,
     triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient],
     model_spec: TritonModelSpec,
+    parameters: dict | None = None,
 ):
     ret = []
     while True:
@@ -86,7 +87,7 @@
         try:
             a_pred = await request_async(
                 inference_client.flag.protocol,
-                inference_client.build_triton_input(batch_data, model_spec),
+                inference_client.build_triton_input(batch_data, model_spec, parameters=parameters),
                 triton_client,
                 timeout=inference_client.client_timeout,
                 compression=inference_client.flag.compression_algorithm,
@@ -232,6 +233,7 @@ def _get_request_id(self):
     def __call__(
         self,
         sequences_or_dict: Union[List[Any], Dict[str, List[Any]]],
+        parameters: dict | None = None,
         model_name: str | None = None,
         model_version: str | None = None,
     ):
@@ -254,9 +256,14 @@ def __call__(
             or (model_input.optional is True and model_input.name in sequences_or_dict)  # check optional
         ]
 
-        return self._call_async(sequences_list, model_spec=model_spec)
+        return self._call_async(sequences_list, model_spec=model_spec, parameters=parameters)
 
-    def build_triton_input(self, _input_list: List[np.array], model_spec: TritonModelSpec):
+    def build_triton_input(
+        self,
+        _input_list: List[np.array],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         if self.flag.protocol is TritonProtocol.grpc:
             client = grpcclient
         else:
@@ -278,19 +285,30 @@ def build_triton_input(self, _input_list: List[np.array], model_spec: TritonMode
             request_id=str(request_id),
             model_version=model_spec.model_version,
             outputs=infer_requested_output,
+            parameters=parameters,
         )
 
         return request_input
 
-    def _call_async(self, data: List[np.ndarray], model_spec: TritonModelSpec) -> Optional[np.ndarray]:
-        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec))
+    def _call_async(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ) -> Optional[np.ndarray]:
+        async_result = asyncio.run(self._call_async_item(data=data, model_spec=model_spec, parameters=parameters))
 
         if isinstance(async_result, Exception):
             raise async_result
 
         return async_result
 
-    async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModelSpec):
+    async def _call_async_item(
+        self,
+        data: List[np.ndarray],
+        model_spec: TritonModelSpec,
+        parameters: dict | None = None,
+    ):
         current_grpc_async_tasks = []
 
         try:
@@ -301,7 +319,9 @@ async def _call_async_item(self, data: List[np.ndarray], model_spec: TritonModel
             current_grpc_async_tasks.append(generator)
 
             predict_tasks = [
-                asyncio.create_task(send_request_async(self, data_queue, done_event, self.triton_client, model_spec))
+                asyncio.create_task(
+                    send_request_async(self, data_queue, done_event, self.triton_client, model_spec, parameters)
+                )
                 for idx in range(ASYNC_TASKS)
             ]
             current_grpc_async_tasks.extend(predict_tasks)
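A simplified, generic sketch of the fan-out pattern above (not tritony's real internals; the queue items and worker body are made up): each of the ASYNC_TASKS workers receives the same parameters dict, so every batch pulled from the queue is sent with identical request parameters.

from __future__ import annotations

import asyncio

ASYNC_TASKS = 4


async def worker(queue: asyncio.Queue, parameters: dict | None = None) -> list:
    sent = []
    while True:
        batch = await queue.get()
        if batch is None:  # sentinel: no more batches
            break
        # A real implementation would build and send the inference request here.
        sent.append((batch, parameters))
    return sent


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for batch in range(3):
        queue.put_nowait(batch)
    for _ in range(ASYNC_TASKS):
        queue.put_nowait(None)  # one sentinel per worker
    tasks = [asyncio.create_task(worker(queue, {"add": "1"})) for _ in range(ASYNC_TASKS)]
    print(await asyncio.gather(*tasks))


asyncio.run(main())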
2 changes: 1 addition & 1 deletion tritony/version.py
@@ -1 +1 @@
__version__ = "0.0.11"
__version__ = "0.0.12rc0"
