From 794a5f02b0df4ef356a449a929f106950e69c3fe Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 29 Aug 2024 15:38:48 -0700 Subject: [PATCH] Fix bug when INT_MAX < byte_size <= LLONG_MAX --- src/infer_request.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 823d5fd29..550023323 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1202,14 +1202,9 @@ InferenceRequest::Normalize() // TensorRT backend. if (!input.IsNonLinearFormatIo()) { TRITONSERVER_MemoryType input_memory_type; - // Because Triton expects STRING type to be in special format - // (prepend 4 bytes to specify string length), so need to add all the - // first 4 bytes for each element to find expected byte size if (data_type == inference::DataType::TYPE_STRING) { RETURN_IF_ERROR(ValidateBytesInputs( input_name, input, model_name, &input_memory_type)); - // FIXME: Temporarily skips byte size checks for GPU tensors. See - // DLIS-6820. } else { // Shape tensor with dynamic batching does not introduce a new // dimension to the tensor but adds an additional value to the 1-D @@ -1217,11 +1212,10 @@ InferenceRequest::Normalize() const std::vector& input_dims = input.IsShapeTensor() ? input.OriginalShape() : input.ShapeWithBatchDim(); - int64_t expected_byte_size = INT_MAX; - expected_byte_size = + int64_t expected_byte_size = triton::common::GetByteSize(data_type, input_dims); const size_t& byte_size = input.Data()->TotalByteSize(); - if ((byte_size > INT_MAX) || + if ((byte_size > LLONG_MAX) || (static_cast(byte_size) != expected_byte_size)) { return Status( Status::Code::INVALID_ARG, @@ -1331,9 +1325,13 @@ InferenceRequest::ValidateBytesInputs( } } - constexpr size_t kElementSizeIndicator = sizeof(uint32_t); // Get the next element if not currently processing one. if (!remaining_element_size) { + // Triton expects STRING type to be in special format + // (prepend 4 bytes to specify string length), so need to add the + // first 4 bytes for each element to find expected byte size. + constexpr size_t kElementSizeIndicator = sizeof(uint32_t); + // FIXME: Assume the string element's byte size indicator is not spread // across buffer boundaries for simplicity. if (remaining_buffer_size < kElementSizeIndicator) {