diff --git a/src/infer_request.cc b/src/infer_request.cc index 57ea6cf1..8a95b524 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -68,14 +68,13 @@ InferRequest::InferRequest( } } - inputs_ = inputs; - requested_output_names_ = requested_output_names; #ifdef TRITON_PB_STUB pb_cancel_ = std::make_shared(response_factory_address_, request_address_); response_sender_ = std::make_shared( request_address_, response_factory_address_, nullptr /* is_decoupled */, - Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_); + RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(), + pb_cancel_); #endif } @@ -390,7 +389,8 @@ InferRequest::InferRequest( std::make_shared(response_factory_address_, request_address_); response_sender_ = std::make_shared( request_address_, response_factory_address_, is_model_decoupled, - Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_); + RequestedOutputNames(), Stub::GetOrCreateInstance()->SharedMemory(), + pb_cancel_); #endif } diff --git a/src/response_sender.cc b/src/response_sender.cc index 74914ab4..1831601f 100644 --- a/src/response_sender.cc +++ b/src/response_sender.cc @@ -54,12 +54,15 @@ CheckResponseSenderArguments( ResponseSender::ResponseSender( intptr_t request_address, intptr_t response_factory_address, - bool const* is_decoupled, std::unique_ptr& shm_pool, + bool const* is_decoupled, + const std::set& requested_output_names, + std::unique_ptr& shm_pool, const std::shared_ptr& pb_cancel) : request_address_(request_address), response_factory_address_(response_factory_address), - is_decoupled_(is_decoupled), shm_pool_(shm_pool), pb_cancel_(pb_cancel), - closed_(false), number_of_response_sent_(0) + is_decoupled_(is_decoupled), + requested_output_names_(requested_output_names), shm_pool_(shm_pool), + pb_cancel_(pb_cancel), closed_(false), number_of_response_sent_(0) { } @@ -123,6 +126,9 @@ ResponseSender::Send( CheckResponseSenderArguments(infer_response, flags); UpdateStateAndCounters(infer_response, flags); + if (infer_response) { + infer_response->PruneOutputTensors(requested_output_names_); + } std::unique_ptr& stub = Stub::GetOrCreateInstance(); diff --git a/src/response_sender.h b/src/response_sender.h index 1b57508e..f274f5b4 100644 --- a/src/response_sender.h +++ b/src/response_sender.h @@ -38,7 +38,9 @@ class ResponseSender { public: ResponseSender( intptr_t request_address, intptr_t response_factory_address, - bool const* is_decoupled, std::unique_ptr& shm_pool, + bool const* is_decoupled, + const std::set& requested_output_names, + std::unique_ptr& shm_pool, const std::shared_ptr& pb_cancel); ~ResponseSender(); void Send(std::shared_ptr response, const uint32_t flags); @@ -54,6 +56,7 @@ class ResponseSender { intptr_t request_address_; intptr_t response_factory_address_; bool const* is_decoupled_; + std::set requested_output_names_; std::unique_ptr& shm_pool_; std::shared_ptr pb_cancel_;