Check output batch size in @batch decorator
pziecina-nv committed Sep 27, 2023
1 parent 7e93b40 commit 98d2fdd
Showing 7 changed files with 212 additions and 105 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,14 @@ limitations under the License.

# Changelog

## Unreleased

- Change: `@batch` decorator raises a `ValueError` if any of the outputs have a different batch size than expected.

[//]: <> (put here on external component update with short summary what change or link to changelog)

- Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)

## 0.3.1 (2023-09-26)

- Change: `KeyboardInterrupt` is now handled in `triton.serve()`. PyTriton hosting scripts return an exit code of 0 instead of 130 when they receive a SIGINT signal.
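
To illustrate the unreleased entry above: a minimal sketch (the function and tensor names are hypothetical) of an `@batch`-decorated function that now triggers the check, because its output batch size no longer equals the summed batch size of the incoming requests.

```python
import numpy as np

from pytriton.decorators import batch


@batch
def _broken_infer_fn(**inputs):
    # @batch stacks all queued requests, so `data` arrives with shape (total_batch_size, ...).
    data = inputs["data"]
    # Deliberate mistake: reducing over the batch axis produces an output with batch size 1.
    # With this change the decorator raises ValueError instead of slicing the undersized
    # tensor back across the requests.
    return {"sum": np.sum(data, axis=0, keepdims=True)}
```

Previously the decorator simply sliced whatever tensors it received, so an undersized output could be split into truncated or empty per-request results without any error.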
15 changes: 9 additions & 6 deletions examples/dali_resnet101_pytorch/client.py
@@ -15,6 +15,7 @@

import argparse
import logging
import pathlib

import numpy as np # pytype: disable=import-error

@@ -33,15 +34,17 @@ def infer_model(input, args):
with ModelClient(args.url, "ResNet101", init_timeout_s=args.init_timeout_s) as client:
result_data = client.infer_batch(input)

original = result_data["original"]
segmented = result_data["segmented"]
original_batch = result_data["original"]
segmented_batch = result_data["segmented"]

if args.dump_images:
for i, (orig, segm) in enumerate(zip(original, segmented)):
import cv2 # pytype: disable=import-error
pathlib.Path("test_video").mkdir(parents=True, exist_ok=True)
for batch_idx, (original, segmented) in enumerate(zip(original_batch, segmented_batch)):
for frame_idx, (orig, segm) in enumerate(zip(original, segmented)):
import cv2 # pytype: disable=import-error

cv2.imwrite(f"test_video/orig{i}.jpg", orig)
cv2.imwrite(f"test_video/segm{i}.jpg", segm)
cv2.imwrite(f"test_video/orig_{batch_idx:03d}_{frame_idx:04d}.jpg", orig)
cv2.imwrite(f"test_video/segm_{batch_idx:03d}_{frame_idx:04d}.jpg", segm)

logger.info("Processing finished.")

15 changes: 8 additions & 7 deletions examples/dali_resnet101_pytorch/server.py
@@ -111,10 +111,11 @@ def postprocess(images, probabilities):


@batch
def _infer_fn(**enc):
enc = enc["video"]
def _infer_fn(**inputs):
encoded_video = inputs["video"]

image, input = preprocess(enc)
image, input = preprocess(encoded_video)
batch_size, frames_num = image.shape[:2]

input = input.reshape(-1, *input.shape[-3:]) # NFCHW to NCHW (flattening first two dimensions)
image = image.reshape(-1, *image.shape[-3:]) # NFHWC to NHWC (flattening first two dimensions)
@@ -123,8 +124,8 @@ def _infer_fn(**enc):
out = postprocess(image, prob)

return {
"original": image.cpu().numpy(),
"segmented": out.as_cpu().as_array(),
"original": image.cpu().numpy().reshape(batch_size, frames_num, *image.shape[-3:]),
"segmented": out.as_cpu().as_array().reshape(batch_size, frames_num, *image.shape[-3:]),
}


@@ -148,8 +149,8 @@ def main():
Tensor(name="video", dtype=np.uint8, shape=(-1,)), # Encoded video
],
outputs=[
Tensor(name="original", dtype=np.uint8, shape=(-1, -1, -1)),
Tensor(name="segmented", dtype=np.uint8, shape=(-1, -1, -1)),
Tensor(name="original", dtype=np.uint8, shape=(-1, -1, -1, -1)), # FHWC
Tensor(name="segmented", dtype=np.uint8, shape=(-1, -1, -1, -1)), # FHWC
],
config=ModelConfig(
max_batch_size=MAX_BATCH_SIZE,
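
The reshapes added to `_infer_fn` above restore the request-level batch dimension that the NFCHW/NFHWC flattening removed, which keeps the outputs' leading axis equal to the request batch size as the `@batch` decorator now requires. A stand-alone sketch of the same shape round trip (the dimension values below are arbitrary):

```python
import numpy as np

batch_size, frames_num, height, width, channels = 2, 4, 8, 8, 3  # illustrative sizes only

image = np.zeros((batch_size, frames_num, height, width, channels), dtype=np.uint8)  # NFHWC

# Flatten the batch and frame axes so the model processes one frame per sample ...
flat = image.reshape(-1, *image.shape[-3:])  # (N*F, H, W, C)

# ... then restore them before returning, so the leading axis again matches the batch size.
restored = flat.reshape(batch_size, frames_num, *image.shape[-3:])

assert restored.shape[0] == batch_size
```

This is also why the output `Tensor` shapes gain an extra dimension: each sample is now a whole FHWC clip rather than a single frame.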
10 changes: 6 additions & 4 deletions examples/online_learning_mnist/server.py
@@ -114,11 +114,13 @@ def replace_inference_model(self):
with self.lock:
self.infer_model.load_state_dict(self.trained_model.state_dict())

@batch
def train(self, image, target):
def train(self, requests):
"""Train function is used in training endpoint."""
self.train_data_queue.put((image.copy(), target.copy()))
return {"last_loss": np.array([[self.last_loss]]).astype(np.float32)}
# concatenate all requests into one batch. No need for padding due to fixed image dimensions
images = np.concatenate([request["image"] for request in requests], axis=0)
targets = np.concatenate([request["target"] for request in requests], axis=0)
self.train_data_queue.put((images, targets))
return [{"last_loss": np.array([[self.last_loss]]).astype(np.float32)} for _ in requests]

@batch
def infer(self, image):
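
The switch above from `@batch` to handling the raw request list follows from the new check: `train` reports a single `last_loss` row no matter how many samples arrive, so under `@batch` its output batch size (1) would differ from the summed request batch size and raise `ValueError`. A rough sketch of the mismatch (shapes only, values illustrative):

```python
import numpy as np

# Two hypothetical training requests with 2 and 3 MNIST images each.
request_batch_sizes = [2, 3]
expected_batch_size = sum(request_batch_sizes)  # 5

# The old @batch-decorated train() returned a single-row loss tensor ...
last_loss = np.array([[0.25]], dtype=np.float32)  # shape (1, 1)

# ... which the decorator would now reject, because 1 != 5.
assert last_loss.shape[0] != expected_batch_size
```

Handling requests directly lets the endpoint return exactly one response dict per request, which is what the list comprehension above does.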
20 changes: 18 additions & 2 deletions pytriton/decorators.py
@@ -182,6 +182,11 @@ def batch(wrapped, instance, args, kwargs):
received by Triton server.
We assume that each request has the same set of keys (you can use group_by_keys decorator before
using @batch decorator if your requests may have different set of keys).
Raises:
PyTritonValidationError: If the requests have different sets of keys.
ValueError: If any output tensor has a batch size different from the expected one. The expected
batch size is calculated as the sum of the batch sizes of all requests.
"""
req_list = args[0]
input_names = req_list[0].keys()
@@ -204,12 +209,23 @@ def _split_result(_result):
outputs = convert_output(_result, wrapped, instance)
output_names = outputs.keys()

requests_total_batch_size = sum(get_inference_request_batch_size(req) for req in req_list)
not_matching_tensors_shapes = {
output_name: output_tensor.shape
for output_name, output_tensor in outputs.items()
if output_tensor.shape[0] != requests_total_batch_size
}
if not_matching_tensors_shapes:
raise ValueError(
f"Received output tensors with different batch sizes: {', '.join(': '.join(map(str, item)) for item in not_matching_tensors_shapes.items())}. "
f"Expected batch size: {requests_total_batch_size}. "
)

out_list = []
start_idx = 0
for request in req_list:
# get batch_size of first input for each request - assume that all inputs have same batch_size
first_input = next(iter(request.values()))
request_batch_size = first_input.shape[0]
request_batch_size = get_inference_request_batch_size(request)
req_output_dict = {}
for _output_ind, output_name in enumerate(output_names):
req_output = outputs[output_name][start_idx : start_idx + request_batch_size, ...]
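
Taken together, the new validation and the splitting loop above implement a simple contract: every output tensor's leading dimension must equal the sum of the request batch sizes so the stacked result can be sliced back per request. A self-contained sketch of that contract (the helper below is illustrative, not the library's API):

```python
import numpy as np


def split_outputs(outputs: dict, request_batch_sizes: list) -> list:
    """Validate and slice stacked outputs back into one dict per request (illustrative only)."""
    expected = sum(request_batch_sizes)
    mismatched = {name: tensor.shape for name, tensor in outputs.items() if tensor.shape[0] != expected}
    if mismatched:
        raise ValueError(f"Received output tensors with unexpected batch sizes: {mismatched}; expected {expected}.")
    result, start = [], 0
    for batch_size in request_batch_sizes:
        result.append({name: tensor[start : start + batch_size, ...] for name, tensor in outputs.items()})
        start += batch_size
    return result


# Two requests with batch sizes 2 and 3 -> the stacked output must have batch size 5.
stacked = {"y": np.arange(10, dtype=np.float32).reshape(5, 2)}
per_request = split_outputs(stacked, [2, 3])
assert per_request[0]["y"].shape == (2, 2) and per_request[1]["y"].shape == (3, 2)
```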
114 changes: 28 additions & 86 deletions tests/unit/test_decorators.py
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference decorators tests."""
import inspect
import typing

import numpy as np
@@ -62,8 +61,6 @@
Request({"b": np.array([[1, 2]]), "a": np.array([[1]])}, {}),
]

input_batch_with_params = {"b": np.array([[1, 2], [1, 2], [9, 9]]), "a": np.array([[1], [1], [1]])}


def _prepare_and_inject_context_with_config(config, fun):
context = TritonContext()
@@ -126,88 +123,7 @@ def _prepare_context_for_input(inputs, fun):
return context


def test_batch():
@batch
def batched_fun(**inputs):
assert isinstance(inputs, dict) and "a" in inputs and "b" in inputs
assert inputs["a"].shape == (6, 1)
assert inputs["b"].shape == (6, 2)

return {"a": inputs["a"] * 2, "b": inputs["b"] * 3}

results = batched_fun(three_request_for_batching)
assert not inspect.isgenerator(results)

for input, output in zip(three_request_for_batching, results):
assert np.all(input["a"] * 2 == output["a"]) and np.all(input["b"] * 3 == output["b"])


def test_batch_output_list():
@batch
def batched_fun(**inputs):
assert isinstance(inputs, dict) and "a" in inputs and "b" in inputs
assert inputs["a"].shape == (6, 1)
assert inputs["b"].shape == (6, 2)

return [inputs["a"] * 2, inputs["b"] * 3]

context = _prepare_context_for_input(three_request_for_batching, batched_fun)

batched_fun.__triton_context__ = context
results = batched_fun(three_request_for_batching)
assert not inspect.isgenerator(results)

for input, output in zip(three_request_for_batching, results):
assert np.all(input["a"] * 2 == output["a"]) and np.all(input["b"] * 3 == output["b"])


def test_batch_with_generator_fn():
@batch
def _infer_gen_fn(**inputs):
yield {"a": inputs["a"] * 2, "b": inputs["b"] * 3}
yield {"a": inputs["a"] * 2, "b": inputs["b"] * 3}

results_gen = _infer_gen_fn(three_request_for_batching)
assert inspect.isgenerator(results_gen)

results = next(results_gen)
assert len(three_request_for_batching) == len(results)
for request, result in zip(three_request_for_batching, results):
assert np.all(request["a"] * 2 == result["a"]) and np.all(request["b"] * 3 == result["b"])

results = next(results_gen)
assert len(three_request_for_batching) == len(results)
for request, result in zip(three_request_for_batching, results):
assert np.all(request["a"] * 2 == result["a"]) and np.all(request["b"] * 3 == result["b"])

with pytest.raises(StopIteration):
next(results_gen)


def test_sample():
@sample
def sample_fun(**inputs):
assert isinstance(inputs, dict) and "a" in inputs and "b" in inputs
return {"a": inputs["a"] * 2, "b": inputs["b"] * 3}

results = sample_fun(input_requests_for_sample)

for input, output in zip(three_request_for_batching, results):
assert np.all(input["a"] * 2 == output["a"]) and np.all(input["b"] * 3 == output["b"])


def test_sample_output_list():
@sample
def sample1(**inputs):
assert isinstance(inputs, dict) and "a" in inputs and "b" in inputs
return [inputs["a"] * 2, inputs["b"] * 3]

context = _prepare_context_for_input(input_requests_for_sample, sample1)
sample1.__triton_context__ = context
results = sample1(input_requests_for_sample)

for input, output in zip(input_requests_for_sample, results):
assert np.all(input["a"] * 2 == output["a"]) and np.all(input["b"] * 3 == output["b"])
input_batch_with_params = {"b": np.array([[1, 2], [1, 2], [9, 9]]), "a": np.array([[1], [1], [1]])}


def test_pad_batch():
Expand Down Expand Up @@ -244,6 +160,32 @@ def padded_fun(**inputs):
assert results["a"].shape[0] == config.max_batch_size and results["b"].shape[0] == config.max_batch_size


def test_sample():
@sample
def sample_fun(**inputs):
assert isinstance(inputs, dict) and "a" in inputs and "b" in inputs
return {"a": inputs["a"] * 2, "b": inputs["b"] * 3}

results = sample_fun(input_requests_for_sample)

for input, output in zip(three_request_for_batching, results):
assert np.all(input["a"] * 2 == output["a"]) and np.all(input["b"] * 3 == output["b"])


def test_sample_output_list():
@sample
def sample1(**inputs):
assert isinstance(inputs, dict) and "a" in inputs and "b" in inputs
return [inputs["a"] * 2, inputs["b"] * 3]

context = _prepare_context_for_input(input_requests_for_sample, sample1)
sample1.__triton_context__ = context
results = sample1(input_requests_for_sample)

for input, output in zip(input_requests_for_sample, results):
assert np.all(input["a"] * 2 == output["a"]) and np.all(input["b"] * 3 == output["b"])


_FIRST_VALUE_MODEL_CONFIG = TritonModelConfig(model_name="foo", inputs=[], outputs=[])

