Skip to content

Commit

Permalink
fix: Models should filter outputs based on requested outputs (#366) (#…
Browse files Browse the repository at this point in the history
…367)

* Prune non requested outputs from non-decoupled models

* Prune non requested outputs from decoupled models

* [chore] Remove redundant copy
  • Loading branch information
kthui authored Jun 17, 2024
1 parent 5b6389d commit 3787765
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,13 @@ InferRequest::InferRequest(
}
}

inputs_ = inputs;
requested_output_names_ = requested_output_names;
#ifdef TRITON_PB_STUB
pb_cancel_ =
std::make_shared<PbCancel>(response_factory_address_, request_address_);
response_sender_ = std::make_shared<ResponseSender>(
request_address_, response_factory_address_, nullptr /* is_decoupled */,
Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(),
pb_cancel_);
#endif
}

Expand Down Expand Up @@ -390,7 +389,8 @@ InferRequest::InferRequest(
std::make_shared<PbCancel>(response_factory_address_, request_address_);
response_sender_ = std::make_shared<ResponseSender>(
request_address_, response_factory_address_, is_model_decoupled,
Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(),
pb_cancel_);
#endif
}

Expand Down
12 changes: 9 additions & 3 deletions src/response_sender.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,15 @@ CheckResponseSenderArguments(

ResponseSender::ResponseSender(
intptr_t request_address, intptr_t response_factory_address,
bool const* is_decoupled, std::unique_ptr<SharedMemoryManager>& shm_pool,
bool const* is_decoupled,
const std::set<std::string>& requested_output_names,
std::unique_ptr<SharedMemoryManager>& shm_pool,
const std::shared_ptr<PbCancel>& pb_cancel)
: request_address_(request_address),
response_factory_address_(response_factory_address),
is_decoupled_(is_decoupled), shm_pool_(shm_pool), pb_cancel_(pb_cancel),
closed_(false), number_of_response_sent_(0)
is_decoupled_(is_decoupled),
requested_output_names_(requested_output_names), shm_pool_(shm_pool),
pb_cancel_(pb_cancel), closed_(false), number_of_response_sent_(0)
{
}

Expand Down Expand Up @@ -123,6 +126,9 @@ ResponseSender::Send(

CheckResponseSenderArguments(infer_response, flags);
UpdateStateAndCounters(infer_response, flags);
if (infer_response) {
infer_response->PruneOutputTensors(requested_output_names_);
}

std::unique_ptr<Stub>& stub = Stub::GetOrCreateInstance();

Expand Down
5 changes: 4 additions & 1 deletion src/response_sender.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ class ResponseSender {
public:
ResponseSender(
intptr_t request_address, intptr_t response_factory_address,
bool const* is_decoupled, std::unique_ptr<SharedMemoryManager>& shm_pool,
bool const* is_decoupled,
const std::set<std::string>& requested_output_names,
std::unique_ptr<SharedMemoryManager>& shm_pool,
const std::shared_ptr<PbCancel>& pb_cancel);
~ResponseSender();
void Send(std::shared_ptr<InferResponse> response, const uint32_t flags);
Expand All @@ -54,6 +56,7 @@ class ResponseSender {
intptr_t request_address_;
intptr_t response_factory_address_;
bool const* is_decoupled_;
std::set<std::string> requested_output_names_;
std::unique_ptr<SharedMemoryManager>& shm_pool_;
std::shared_ptr<PbCancel> pb_cancel_;

Expand Down

0 comments on commit 3787765

Please sign in to comment.