From 94eb61c6915b49ad43680243d4eafd06e59a2639 Mon Sep 17 00:00:00 2001 From: Yingge He <157551214+yinggeh@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:54:02 -0700 Subject: [PATCH] refactor: Refactor core input size checks (#382) --- src/infer_request.cc | 73 +++++++++++++-------- src/infer_request.h | 1 + src/test/input_byte_size_test.cc | 109 +++++++++++++++++++++++++------ 3 files changed, 136 insertions(+), 47 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index e31cd3e5a..0d0c80a0d 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -910,6 +910,7 @@ Status InferenceRequest::Normalize() { const inference::ModelConfig& model_config = model_raw_->Config(); + const std::string& model_name = ModelName(); // Fill metadata for raw input if (!raw_input_name_.empty()) { @@ -922,7 +923,7 @@ InferenceRequest::Normalize() std::to_string(original_inputs_.size()) + ") to be deduced but got " + std::to_string(model_config.input_size()) + " inputs in '" + - ModelName() + "' model configuration"); + model_name + "' model configuration"); } auto it = original_inputs_.begin(); if (raw_input_name_ != it->first) { @@ -1055,7 +1056,7 @@ InferenceRequest::Normalize() Status::Code::INVALID_ARG, LogRequest() + "input '" + input.Name() + "' has no shape but model requires batch dimension for '" + - ModelName() + "'"); + model_name + "'"); } if (batch_size_ == 0) { @@ -1064,7 +1065,7 @@ InferenceRequest::Normalize() return Status( Status::Code::INVALID_ARG, LogRequest() + "input '" + input.Name() + - "' batch size does not match other inputs for '" + ModelName() + + "' batch size does not match other inputs for '" + model_name + "'"); } @@ -1080,7 +1081,7 @@ InferenceRequest::Normalize() Status::Code::INVALID_ARG, LogRequest() + "inference request batch-size must be <= " + std::to_string(model_config.max_batch_size()) + " for '" + - ModelName() + "'"); + model_name + "'"); } // Verify that each input shape is valid for the model, make @@ -1089,17 +1090,17 @@ InferenceRequest::Normalize() const inference::ModelInput* input_config; RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config)); - auto& input_id = pr.first; + auto& input_name = pr.first; auto& input = pr.second; auto shape = input.MutableShape(); if (input.DType() != input_config->data_type()) { return Status( Status::Code::INVALID_ARG, - LogRequest() + "inference input '" + input_id + "' data-type is '" + + LogRequest() + "inference input '" + input_name + "' data-type is '" + std::string( triton::common::DataTypeToProtocolString(input.DType())) + - "', but model '" + ModelName() + "' expects '" + + "', but model '" + model_name + "' expects '" + std::string(triton::common::DataTypeToProtocolString( input_config->data_type())) + "'"); @@ -1119,7 +1120,7 @@ InferenceRequest::Normalize() Status::Code::INVALID_ARG, LogRequest() + "All input dimensions should be specified for input '" + - input_id + "' for model '" + ModelName() + "', got " + + input_name + "' for model '" + model_name + "', got " + triton::common::DimsListToString(input.OriginalShape())); } else if ( (config_dims[i] != triton::common::WILDCARD_DIM) && @@ -1148,8 +1149,8 @@ InferenceRequest::Normalize() } return Status( Status::Code::INVALID_ARG, - LogRequest() + "unexpected shape for input '" + input_id + - "' for model '" + ModelName() + "'. Expected " + + LogRequest() + "unexpected shape for input '" + input_name + + "' for model '" + model_name + "'. 
Expected " + triton::common::DimsListToString(full_dims) + ", got " + triton::common::DimsListToString(input.OriginalShape()) + ". " + implicit_batch_note); @@ -1201,15 +1202,9 @@ InferenceRequest::Normalize() // TensorRT backend. if (!input.IsNonLinearFormatIo()) { TRITONSERVER_MemoryType input_memory_type; - // Because Triton expects STRING type to be in special format - // (prepend 4 bytes to specify string length), so need to add all the - // first 4 bytes for each element to find expected byte size if (data_type == inference::DataType::TYPE_STRING) { - RETURN_IF_ERROR( - ValidateBytesInputs(input_id, input, &input_memory_type)); - - // FIXME: Temporarily skips byte size checks for GPU tensors. See - // DLIS-6820. + RETURN_IF_ERROR(ValidateBytesInputs( + input_name, input, model_name, &input_memory_type)); } else { // Shape tensor with dynamic batching does not introduce a new // dimension to the tensor but adds an additional value to the 1-D @@ -1217,16 +1212,15 @@ InferenceRequest::Normalize() const std::vector& input_dims = input.IsShapeTensor() ? input.OriginalShape() : input.ShapeWithBatchDim(); - int64_t expected_byte_size = INT_MAX; - expected_byte_size = + int64_t expected_byte_size = triton::common::GetByteSize(data_type, input_dims); const size_t& byte_size = input.Data()->TotalByteSize(); - if ((byte_size > INT_MAX) || + if ((byte_size > LLONG_MAX) || (static_cast(byte_size) != expected_byte_size)) { return Status( Status::Code::INVALID_ARG, LogRequest() + "input byte size mismatch for input '" + - input_id + "' for model '" + ModelName() + "'. Expected " + + input_name + "' for model '" + model_name + "'. Expected " + std::to_string(expected_byte_size) + ", got " + std::to_string(byte_size)); } @@ -1300,7 +1294,8 @@ InferenceRequest::ValidateRequestInputs() Status InferenceRequest::ValidateBytesInputs( - const std::string& input_id, const Input& input, + const std::string& input_name, const Input& input, + const std::string& model_name, TRITONSERVER_MemoryType* buffer_memory_type) const { const auto& input_dims = input.ShapeWithBatchDim(); @@ -1325,27 +1320,48 @@ InferenceRequest::ValidateBytesInputs( buffer_next_idx++, (const void**)(&buffer), &remaining_buffer_size, buffer_memory_type, &buffer_memory_id)); + // GPU tensors are validated at platform backends to avoid additional + // data copying. Check "ValidateStringBuffer" in backend_common.cc. if (*buffer_memory_type == TRITONSERVER_MEMORY_GPU) { return Status::Success; } } - constexpr size_t kElementSizeIndicator = sizeof(uint32_t); // Get the next element if not currently processing one. if (!remaining_element_size) { + // Triton expects STRING type to be in special format + // (prepend 4 bytes to specify string length), so need to add the + // first 4 bytes for each element to find expected byte size. + constexpr size_t kElementSizeIndicator = sizeof(uint32_t); + // FIXME: Assume the string element's byte size indicator is not spread // across buffer boundaries for simplicity. if (remaining_buffer_size < kElementSizeIndicator) { return Status( Status::Code::INVALID_ARG, LogRequest() + - "element byte size indicator exceeds the end of the buffer."); + "incomplete string length indicator for inference input '" + + input_name + "' for model '" + model_name + "', expecting " + + std::to_string(sizeof(uint32_t)) + " bytes but only " + + std::to_string(remaining_buffer_size) + + " bytes available. 
Please make sure the string length " + "indicator is in one buffer."); } // Start the next element and reset the remaining element size. remaining_element_size = *(reinterpret_cast(buffer)); element_checked++; + // Early stop + if (element_checked > element_count) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "unexpected number of string elements " + + std::to_string(element_checked) + " for inference input '" + + input_name + "' for model '" + model_name + "', expecting " + + std::to_string(element_count)); + } + // Advance pointer and remainder by the indicator size. buffer += kElementSizeIndicator; remaining_buffer_size -= kElementSizeIndicator; @@ -1371,8 +1387,8 @@ InferenceRequest::ValidateBytesInputs( return Status( Status::Code::INVALID_ARG, LogRequest() + "expected " + std::to_string(buffer_count) + - " buffers for inference input '" + input_id + "', got " + - std::to_string(buffer_next_idx)); + " buffers for inference input '" + input_name + "' for model '" + + model_name + "', got " + std::to_string(buffer_next_idx)); } // Validate the number of processed elements exactly match expectations. @@ -1380,7 +1396,8 @@ InferenceRequest::ValidateBytesInputs( return Status( Status::Code::INVALID_ARG, LogRequest() + "expected " + std::to_string(element_count) + - " string elements for inference input '" + input_id + "', got " + + " string elements for inference input '" + input_name + + "' for model '" + model_name + "', got " + std::to_string(element_checked)); } diff --git a/src/infer_request.h b/src/infer_request.h index c180d438b..38c89ed63 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -775,6 +775,7 @@ class InferenceRequest { Status ValidateBytesInputs( const std::string& input_id, const Input& input, + const std::string& model_name, TRITONSERVER_MemoryType* buffer_memory_type) const; // Helpers for pending request metrics diff --git a/src/test/input_byte_size_test.cc b/src/test/input_byte_size_test.cc index 066988068..cf3e3bd58 100644 --- a/src/test/input_byte_size_test.cc +++ b/src/test/input_byte_size_test.cc @@ -258,10 +258,11 @@ char InputByteSizeTest::input_data_string_ TEST_F(InputByteSizeTest, ValidInputByteSize) { + const char* model_name = "savedmodel_zero_1_float32"; // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "pt_identity", -1 /* model_version */), + &irequest_, server_, model_name, -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -269,8 +270,8 @@ TEST_F(InputByteSizeTest, ValidInputByteSize) "setting request release callback"); // Define input shape and data - std::vector shape{1, 8}; - std::vector input_data(8, 1); + std::vector shape{1, 16}; + std::vector input_data(16, 1); const auto input0_byte_size = sizeof(input_data[0]) * input_data.size(); // Set input for the request @@ -312,10 +313,11 @@ TEST_F(InputByteSizeTest, ValidInputByteSize) TEST_F(InputByteSizeTest, InputByteSizeMismatch) { + const char* model_name = "savedmodel_zero_1_float32"; // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "pt_identity", -1 /* model_version */), + &irequest_, server_, model_name, -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -323,8 +325,8 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch) "setting request release callback"); // Define input shape and data - std::vector 
shape{1, 8}; - std::vector input_data(10, 1); + std::vector shape{1, 16}; + std::vector input_data(17, 1); const auto input0_byte_size = sizeof(input_data[0]) * input_data.size(); // Set input for the request @@ -353,8 +355,8 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch) FAIL_TEST_IF_SUCCESS( TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), "expect error with inference request", - "input byte size mismatch for input 'INPUT0' for model 'pt_identity'. " - "Expected 32, got 40"); + "input byte size mismatch for input 'INPUT0' for model '" + + std::string{model_name} + "'. Expected 64, got 68"); // Need to manually delete request, otherwise server will not shut down. FAIL_TEST_IF_ERR( @@ -362,12 +364,69 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch) "deleting inference request"); } +TEST_F(InputByteSizeTest, InputByteSizeLarge) +{ + const char* model_name = "savedmodel_zero_1_float32"; + // Create an inference request + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestNew( + &irequest_, server_, model_name, -1 /* model_version */), + "creating inference request"); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest_, InferRequestComplete, nullptr /* request_release_userp */), + "setting request release callback"); + + // Define input shape and data + size_t element_cnt = (1LL << 31) / sizeof(float); + std::vector shape{1, element_cnt}; + std::vector input_data(element_cnt, 1); + const auto input0_byte_size = sizeof(input_data[0]) * input_data.size(); + + // Set input for the request + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestAddInput( + irequest_, "INPUT0", TRITONSERVER_TYPE_FP32, shape.data(), + shape.size()), + "setting input for the request"); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestAppendInputData( + irequest_, "INPUT0", input_data.data(), input0_byte_size, + TRITONSERVER_MEMORY_CPU, 0), + "assigning INPUT data"); + + std::promise p; + std::future future = p.get_future(); + + // Set response callback + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestSetResponseCallback( + irequest_, allocator_, nullptr /* response_allocator_userp */, + InferResponseComplete, reinterpret_cast(&p)), + "setting response callback"); + + // Run inference + FAIL_TEST_IF_ERR( + TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), + "running inference"); + + // Get the inference response + response_ = future.get(); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceResponseError(response_), "response status"); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceResponseDelete(response_), + "deleting inference response"); + ASSERT_TRUE(response_ != nullptr) << "Expect successful inference"; +} + TEST_F(InputByteSizeTest, ValidStringInputByteSize) { + const char* model_name = "savedmodel_zero_1_object"; // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "simple_identity", -1 /* model_version */), + &irequest_, server_, model_name, -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -424,10 +483,11 @@ TEST_F(InputByteSizeTest, ValidStringInputByteSize) TEST_F(InputByteSizeTest, StringCountMismatch) { + const char* model_name = "savedmodel_zero_1_object"; // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "simple_identity", -1 /* model_version */), + &irequest_, server_, model_name, -1 /* model_version */), "creating inference request"); 
FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -457,7 +517,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch) FAIL_TEST_IF_SUCCESS( TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), "expect error with inference request", - "expected 3 string elements for inference input 'INPUT0', got 2"); + "expected 3 string elements for inference input 'INPUT0' for model '" + + std::string{model_name} + "', got 2"); // Need to manually delete request, otherwise server will not shut down. FAIL_TEST_IF_ERR( @@ -467,7 +528,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch) // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "simple_identity", -1 /* model_version */), + &irequest_, server_, "savedmodel_zero_1_object", + -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -495,7 +557,9 @@ TEST_F(InputByteSizeTest, StringCountMismatch) FAIL_TEST_IF_SUCCESS( TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), "expect error with inference request", - "expected 1 string elements for inference input 'INPUT0', got 2"); + "unexpected number of string elements 2 for inference input 'INPUT0' for " + "model '" + + std::string{model_name} + "', expecting 1"); // Need to manually delete request, otherwise server will not shut down. FAIL_TEST_IF_ERR( @@ -505,10 +569,11 @@ TEST_F(InputByteSizeTest, StringCountMismatch) TEST_F(InputByteSizeTest, StringSizeMisalign) { + const char* model_name = "savedmodel_zero_1_object"; // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "simple_identity", -1 /* model_version */), + &irequest_, server_, model_name, -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -542,9 +607,13 @@ TEST_F(InputByteSizeTest, StringSizeMisalign) // Run inference FAIL_TEST_IF_SUCCESS( - TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace - */), "expect error with inference request", - "element byte size indicator exceeds the end of the buffer"); + TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace*/), + "expect error with inference request", + "incomplete string length indicator for inference input 'INPUT0' for " + "model '" + + std::string{model_name} + + "', expecting 4 bytes but only 2 bytes available. Please make sure " + "the string length indicator is in one buffer."); // Need to manually delete request, otherwise server will not shut down. FAIL_TEST_IF_ERR( @@ -573,7 +642,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU) // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "simple_identity", -1 /* model_version */), + &irequest_, server_, "savedmodel_zero_1_object", + -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( @@ -629,7 +699,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU) // Create an inference request FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestNew( - &irequest_, server_, "simple_identity", -1 /* model_version */), + &irequest_, server_, "savedmodel_zero_1_object", + -1 /* model_version */), "creating inference request"); FAIL_TEST_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback(
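
The string-input checks added in ValidateBytesInputs above walk Triton's length-prefixed layout for TYPE_STRING (BYTES) tensors: each element is a 4-byte length indicator followed by that many bytes of data, and the validator counts complete elements and rejects an indicator that is cut off by a buffer boundary. The following is a rough, standalone sketch of that layout and element walk, not part of the patch; the helper names SerializeStringTensor and CheckStringBuffer are invented for illustration, and host byte order is assumed for the indicator, matching the reinterpret_cast in the patched code.

// Illustrative sketch only -- not taken from the patch above.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Serialize string elements into the 4-byte-length-prefixed wire format
// that Triton expects for TYPE_STRING (BYTES) inputs.
std::vector<char>
SerializeStringTensor(const std::vector<std::string>& elements)
{
  std::vector<char> buffer;
  for (const auto& s : elements) {
    const uint32_t len = static_cast<uint32_t>(s.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(uint32_t));
    buffer.insert(buffer.end(), s.begin(), s.end());
  }
  return buffer;
}

// Walk a serialized buffer and confirm it holds exactly 'expected_count'
// complete elements, similar in spirit to the checks in ValidateBytesInputs:
// fail on a truncated length indicator, on element data running past the end
// of the buffer, or on an element-count mismatch.
bool
CheckStringBuffer(const std::vector<char>& buffer, size_t expected_count)
{
  size_t offset = 0, element_checked = 0;
  while (offset < buffer.size()) {
    if (buffer.size() - offset < sizeof(uint32_t)) {
      std::cerr << "incomplete string length indicator\n";
      return false;
    }
    uint32_t len;
    std::memcpy(&len, buffer.data() + offset, sizeof(uint32_t));
    offset += sizeof(uint32_t);
    if (buffer.size() - offset < len) {
      std::cerr << "string data exceeds the end of the buffer\n";
      return false;
    }
    offset += len;
    if (++element_checked > expected_count) {
      std::cerr << "unexpected number of string elements\n";
      return false;
    }
  }
  return element_checked == expected_count;
}

int
main()
{
  const std::vector<std::string> elements{"hello", "triton", ""};
  const auto buffer = SerializeStringTensor(elements);
  std::cout << (CheckStringBuffer(buffer, elements.size()) ? "ok" : "mismatch")
            << "\n";
  return 0;
}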