Skip to content

Commit

Permalink
refactor: Refactor core input size checks (#382)
Browse files Browse the repository at this point in the history
  • Loading branch information
yinggeh authored Sep 4, 2024
1 parent 623d0a5 commit 94eb61c
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 47 deletions.
73 changes: 45 additions & 28 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,7 @@ Status
InferenceRequest::Normalize()
{
const inference::ModelConfig& model_config = model_raw_->Config();
const std::string& model_name = ModelName();

// Fill metadata for raw input
if (!raw_input_name_.empty()) {
Expand All @@ -922,7 +923,7 @@ InferenceRequest::Normalize()
std::to_string(original_inputs_.size()) +
") to be deduced but got " +
std::to_string(model_config.input_size()) + " inputs in '" +
ModelName() + "' model configuration");
model_name + "' model configuration");
}
auto it = original_inputs_.begin();
if (raw_input_name_ != it->first) {
Expand Down Expand Up @@ -1055,7 +1056,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' has no shape but model requires batch dimension for '" +
ModelName() + "'");
model_name + "'");
}

if (batch_size_ == 0) {
Expand All @@ -1064,7 +1065,7 @@ InferenceRequest::Normalize()
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' batch size does not match other inputs for '" + ModelName() +
"' batch size does not match other inputs for '" + model_name +
"'");
}

Expand All @@ -1080,7 +1081,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "inference request batch-size must be <= " +
std::to_string(model_config.max_batch_size()) + " for '" +
ModelName() + "'");
model_name + "'");
}

// Verify that each input shape is valid for the model, make
Expand All @@ -1089,17 +1090,17 @@ InferenceRequest::Normalize()
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));

auto& input_id = pr.first;
auto& input_name = pr.first;
auto& input = pr.second;
auto shape = input.MutableShape();

if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input '" + input_id + "' data-type is '" +
LogRequest() + "inference input '" + input_name + "' data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', but model '" + ModelName() + "' expects '" +
"', but model '" + model_name + "' expects '" +
std::string(triton::common::DataTypeToProtocolString(
input_config->data_type())) +
"'");
Expand All @@ -1119,7 +1120,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
input_id + "' for model '" + ModelName() + "', got " +
input_name + "' for model '" + model_name + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
Expand Down Expand Up @@ -1148,8 +1149,8 @@ InferenceRequest::Normalize()
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
LogRequest() + "unexpected shape for input '" + input_name +
"' for model '" + model_name + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()) + ". " +
implicit_batch_note);
Expand Down Expand Up @@ -1201,32 +1202,25 @@ InferenceRequest::Normalize()
// TensorRT backend.
if (!input.IsNonLinearFormatIo()) {
TRITONSERVER_MemoryType input_memory_type;
// Because Triton expects STRING type to be in special format
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &input_memory_type));

// FIXME: Temporarily skips byte size checks for GPU tensors. See
// DLIS-6820.
RETURN_IF_ERROR(ValidateBytesInputs(
input_name, input, model_name, &input_memory_type));
} else {
// Shape tensor with dynamic batching does not introduce a new
// dimension to the tensor but adds an additional value to the 1-D
// array.
const std::vector<int64_t>& input_dims =
input.IsShapeTensor() ? input.OriginalShape()
: input.ShapeWithBatchDim();
int64_t expected_byte_size = INT_MAX;
expected_byte_size =
int64_t expected_byte_size =
triton::common::GetByteSize(data_type, input_dims);
const size_t& byte_size = input.Data()->TotalByteSize();
if ((byte_size > INT_MAX) ||
if ((byte_size > LLONG_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" +
input_id + "' for model '" + ModelName() + "'. Expected " +
input_name + "' for model '" + model_name + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
Expand Down Expand Up @@ -1300,7 +1294,8 @@ InferenceRequest::ValidateRequestInputs()

Status
InferenceRequest::ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& input_name, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const
{
const auto& input_dims = input.ShapeWithBatchDim();
Expand All @@ -1325,27 +1320,48 @@ InferenceRequest::ValidateBytesInputs(
buffer_next_idx++, (const void**)(&buffer), &remaining_buffer_size,
buffer_memory_type, &buffer_memory_id));

// GPU tensors are validated at platform backends to avoid additional
// data copying. Check "ValidateStringBuffer" in backend_common.cc.
if (*buffer_memory_type == TRITONSERVER_MEMORY_GPU) {
return Status::Success;
}
}

constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
// Get the next element if not currently processing one.
if (!remaining_element_size) {
// Triton expects STRING type to be in special format
// (prepend 4 bytes to specify string length), so need to add the
// first 4 bytes for each element to find expected byte size.
constexpr size_t kElementSizeIndicator = sizeof(uint32_t);

// FIXME: Assume the string element's byte size indicator is not spread
// across buffer boundaries for simplicity.
if (remaining_buffer_size < kElementSizeIndicator) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"element byte size indicator exceeds the end of the buffer.");
"incomplete string length indicator for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(sizeof(uint32_t)) + " bytes but only " +
std::to_string(remaining_buffer_size) +
" bytes available. Please make sure the string length "
"indicator is in one buffer.");
}

// Start the next element and reset the remaining element size.
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
element_checked++;

// Early stop
if (element_checked > element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected number of string elements " +
std::to_string(element_checked) + " for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(element_count));
}

// Advance pointer and remainder by the indicator size.
buffer += kElementSizeIndicator;
remaining_buffer_size -= kElementSizeIndicator;
Expand All @@ -1371,16 +1387,17 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(buffer_count) +
" buffers for inference input '" + input_id + "', got " +
std::to_string(buffer_next_idx));
" buffers for inference input '" + input_name + "' for model '" +
model_name + "', got " + std::to_string(buffer_next_idx));
}

// Validate the number of processed elements exactly match expectations.
if (element_checked != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" string elements for inference input '" + input_id + "', got " +
" string elements for inference input '" + input_name +
"' for model '" + model_name + "', got " +
std::to_string(element_checked));
}

Expand Down
1 change: 1 addition & 0 deletions src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ class InferenceRequest {

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helpers for pending request metrics
Expand Down
Loading

0 comments on commit 94eb61c

Please sign in to comment.