refactor: Refactor core input size checks #382

Merged
merged 9 commits on Sep 4, 2024
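This PR threads the model name into the core input size checks and expands the string (BYTES) input validation in InferenceRequest::ValidateBytesInputs. For context, a Triton string tensor is serialized element by element as a 4-byte length indicator followed by the raw bytes, which is the layout these checks walk. The sketch below only illustrates that wire format; SerializeStringTensor is a hypothetical helper, not part of this PR or of Triton:

// Illustrative only: builds the serialized form of a BYTES/string input.
// Each element is a 4-byte length indicator followed by the payload bytes.
#include <cstdint>
#include <string>
#include <vector>

std::vector<char> SerializeStringTensor(const std::vector<std::string>& elements)
{
  std::vector<char> buffer;
  for (const std::string& s : elements) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    // Append the 4-byte length indicator, then the string payload.
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(uint32_t));
    buffer.insert(buffer.end(), s.begin(), s.end());
  }
  return buffer;
}

Given this layout, validation only needs the element count implied by the input shape and the byte sizes of the appended buffers to detect truncated length indicators, too many or too few elements, and leftover bytes.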
55 changes: 36 additions & 19 deletions src/infer_request.cc
@@ -910,6 +910,7 @@ Status
InferenceRequest::Normalize()
{
const inference::ModelConfig& model_config = model_raw_->Config();
const std::string& model_name = ModelName();

// Fill metadata for raw input
if (!raw_input_name_.empty()) {
@@ -922,7 +923,7 @@ InferenceRequest::Normalize()
std::to_string(original_inputs_.size()) +
") to be deduced but got " +
std::to_string(model_config.input_size()) + " inputs in '" +
ModelName() + "' model configuration");
model_name + "' model configuration");
}
auto it = original_inputs_.begin();
if (raw_input_name_ != it->first) {
@@ -1055,7 +1056,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' has no shape but model requires batch dimension for '" +
ModelName() + "'");
model_name + "'");
}

if (batch_size_ == 0) {
@@ -1064,7 +1065,7 @@ InferenceRequest::Normalize()
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' batch size does not match other inputs for '" + ModelName() +
"' batch size does not match other inputs for '" + model_name +
"'");
}

@@ -1080,7 +1081,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "inference request batch-size must be <= " +
std::to_string(model_config.max_batch_size()) + " for '" +
ModelName() + "'");
model_name + "'");
}

// Verify that each input shape is valid for the model, make
@@ -1089,17 +1090,17 @@ InferenceRequest::Normalize()
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));

auto& input_id = pr.first;
auto& input_name = pr.first;
auto& input = pr.second;
auto shape = input.MutableShape();

if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input '" + input_id + "' data-type is '" +
LogRequest() + "inference input '" + input_name + "' data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', but model '" + ModelName() + "' expects '" +
"', but model '" + model_name + "' expects '" +
std::string(triton::common::DataTypeToProtocolString(
input_config->data_type())) +
"'");
@@ -1119,7 +1120,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
input_id + "' for model '" + ModelName() + "', got " +
input_name + "' for model '" + model_name + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
@@ -1148,8 +1149,8 @@ }
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
LogRequest() + "unexpected shape for input '" + input_name +
"' for model '" + model_name + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()) + ". " +
implicit_batch_note);
@@ -1205,9 +1206,8 @@ InferenceRequest::Normalize()
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &input_memory_type));

RETURN_IF_ERROR(ValidateBytesInputs(
input_name, input, model_name, &input_memory_type));
// FIXME: Temporarily skips byte size checks for GPU tensors. See
// DLIS-6820.
} else {
@@ -1226,7 +1226,7 @@ InferenceRequest::Normalize()
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" +
input_id + "' for model '" + ModelName() + "'. Expected " +
input_name + "' for model '" + model_name + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
@@ -1300,7 +1300,8 @@ InferenceRequest::ValidateRequestInputs()

Status
InferenceRequest::ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& input_name, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const
{
const auto& input_dims = input.ShapeWithBatchDim();
@@ -1339,13 +1340,28 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"element byte size indicator exceeds the end of the buffer.");
"incomplete string length indicator for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(sizeof(uint32_t)) + " bytes but only " +
std::to_string(remaining_buffer_size) +
" bytes available. Please make sure the string length "
"indicator is in one buffer.");
}

// Start the next element and reset the remaining element size.
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
element_checked++;

// Early stop
if (element_checked > element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected number of string elements " +
std::to_string(element_checked) + " for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(element_count));
}

// Advance pointer and remainder by the indicator size.
buffer += kElementSizeIndicator;
remaining_buffer_size -= kElementSizeIndicator;
@@ -1371,16 +1387,17 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(buffer_count) +
" buffers for inference input '" + input_id + "', got " +
std::to_string(buffer_next_idx));
" buffers for inference input '" + input_name + "' for model '" +
model_name + "', got " + std::to_string(buffer_next_idx));
}

// Validate the number of processed elements exactly match expectations.
if (element_checked != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" string elements for inference input '" + input_id + "', got " +
" string elements for inference input '" + input_name +
"' for model '" + model_name + "', got " +
std::to_string(element_checked));
}

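Taken together, the string-input path in ValidateBytesInputs now reports which model and input triggered the failure, stops early once the element count implied by the shape is exceeded, and requires each 4-byte length indicator to be fully contained in the current buffer. A condensed, single-buffer approximation of that accounting is sketched below; the function name and error strings are illustrative, not the PR's code:

// Simplified, single-buffer approximation of the checks performed by
// InferenceRequest::ValidateBytesInputs(); names and errors are illustrative.
#include <cstdint>
#include <cstring>
#include <stdexcept>

void CheckSerializedStrings(
    const char* buffer, size_t buffer_size, size_t expected_elements)
{
  constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
  size_t offset = 0;
  size_t elements_checked = 0;

  while (offset < buffer_size) {
    // The 4-byte length indicator must be fully present.
    if (buffer_size - offset < kElementSizeIndicator) {
      throw std::invalid_argument("incomplete string length indicator");
    }
    uint32_t element_size = 0;
    std::memcpy(&element_size, buffer + offset, kElementSizeIndicator);
    offset += kElementSizeIndicator;

    // Early stop: more elements than the shape allows.
    if (++elements_checked > expected_elements) {
      throw std::invalid_argument("unexpected number of string elements");
    }

    // The payload itself must not run past the end of the buffer.
    if (buffer_size - offset < element_size) {
      throw std::invalid_argument("string payload exceeds the buffer");
    }
    offset += element_size;
  }

  if (elements_checked != expected_elements) {
    throw std::invalid_argument("fewer string elements than expected");
  }
}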
1 change: 1 addition & 0 deletions src/infer_request.h
@@ -775,6 +775,7 @@ class InferenceRequest {

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helpers for pending request metrics
53 changes: 34 additions & 19 deletions src/test/input_byte_size_test.cc
@@ -258,19 +258,20 @@ char InputByteSizeTest::input_data_string_

TEST_F(InputByteSizeTest, ValidInputByteSize)
{
const char* model_name = "savedmodel_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "pt_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest_, InferRequestComplete, nullptr /* request_release_userp */),
"setting request release callback");

// Define input shape and data
std::vector<int64_t> shape{1, 8};
std::vector<float> input_data(8, 1);
std::vector<int64_t> shape{1, 16};
std::vector<float> input_data(16, 1);
const auto input0_byte_size = sizeof(input_data[0]) * input_data.size();

// Set input for the request
@@ -312,19 +313,20 @@ TEST_F(InputByteSizeTest, ValidInputByteSize)

TEST_F(InputByteSizeTest, InputByteSizeMismatch)
{
const char* model_name = "savedmodel_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "pt_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest_, InferRequestComplete, nullptr /* request_release_userp */),
"setting request release callback");

// Define input shape and data
std::vector<int64_t> shape{1, 8};
std::vector<float> input_data(10, 1);
std::vector<int64_t> shape{1, 16};
std::vector<float> input_data(17, 1);
const auto input0_byte_size = sizeof(input_data[0]) * input_data.size();

// Set input for the request
@@ -353,8 +355,8 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"input byte size mismatch for input 'INPUT0' for model 'pt_identity'. "
"Expected 32, got 40");
"input byte size mismatch for input 'INPUT0' for model '" +
std::string{model_name} + "'. Expected 64, got 68");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -364,10 +366,11 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)

TEST_F(InputByteSizeTest, ValidStringInputByteSize)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -424,10 +427,11 @@ TEST_F(InputByteSizeTest, ValidStringInputByteSize)

TEST_F(InputByteSizeTest, StringCountMismatch)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -457,7 +461,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"expected 3 string elements for inference input 'INPUT0', got 2");
"expected 3 string elements for inference input 'INPUT0' for model '" +
std::string{model_name} + "', got 2");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -467,7 +472,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -495,7 +501,9 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"expected 1 string elements for inference input 'INPUT0', got 2");
"unexpected number of string elements 2 for inference input 'INPUT0' for "
"model '" +
std::string{model_name} + "', expecting 1");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -505,10 +513,11 @@ TEST_F(InputByteSizeTest, StringCountMismatch)

TEST_F(InputByteSizeTest, StringSizeMisalign)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -542,9 +551,13 @@ TEST_F(InputByteSizeTest, StringSizeMisalign)

// Run inference
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace
*/), "expect error with inference request",
"element byte size indicator exceeds the end of the buffer");
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace*/),
"expect error with inference request",
"incomplete string length indicator for inference input 'INPUT0' for "
"model '" +
std::string{model_name} +
"', expecting 4 bytes but only 2 bytes available. Please make sure "
"the string length indicator is in one buffer.");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -573,7 +586,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -629,7 +643,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(