Skip to content

Commit

Permalink
Fix response factory cleanup
Browse files: browse the repository at this point in the history
  • Loading branch information
Tabrizian committed Sep 24, 2024
1 parent c42afe1 commit 47adab9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class InferRequest {
InferenceTrace& GetTrace();
uint32_t ReleaseFlags();
void SetReleaseFlags(const uint32_t& flags);
// Returns the stored response factory address as a plain integer. The
// python_be.cc hunks of this commit reinterpret the value as a
// TRITONBACKEND_ResponseFactory* in order to release the factory.
intptr_t GetResponseFactoryAddress()
{
  return response_factory_address_;
}

#ifdef TRITON_PB_STUB
std::shared_ptr<InferResponse> Exec(const bool is_decoupled);
Expand Down
37 changes: 30 additions & 7 deletions src/python_be.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,17 @@ ModelInstanceState::ResponseSendDecoupled(
ResponseSendMessage* send_message_payload =
reinterpret_cast<ResponseSendMessage*>(send_message.data_.get());
std::unique_ptr<PbString> error_message;
// Deferred release of the response factory. This commit moved the cleanup
// from the tail of the function into a ScopedDefer (presumably run when the
// enclosing scope exits -- confirm against ScopedDefer's definition).
ScopedDefer response_factory_deleter([send_message_payload] {
  // Only a message flagged as the final response releases the factory.
  if (send_message_payload->flags != TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
    return;
  }

  auto* factory = reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
      send_message_payload->response_factory_address);
  // Taking ownership here hands the pointer to
  // backend::ResponseFactoryDeleter as soon as the lambda returns.
  std::unique_ptr<
      TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
      owned_factory(factory);
});
ScopedDefer _([send_message_payload] {
{
bi::scoped_lock<bi::interprocess_mutex> guard{send_message_payload->mu};
Expand Down Expand Up @@ -1214,13 +1225,6 @@ ModelInstanceState::ResponseSendDecoupled(
SetErrorForResponseSendMessage(
send_message_payload, WrapTritonErrorInSharedPtr(error), error_message);
}

if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
std::unique_ptr<
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
lresponse_factory(
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
}
}

TRITONSERVER_Error*
Expand Down Expand Up @@ -1291,6 +1295,15 @@ ModelInstanceState::ProcessRequests(

if (response_batch_shm_ptr->has_error) {
if (response_batch_shm_ptr->is_error_set) {
for (uint32_t r = 0; r < request_count; r++) {
TRITONBACKEND_ResponseFactory* response_factory =
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
pb_infer_requests[r]->GetResponseFactoryAddress());
std::unique_ptr<
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
response_factory));
}
auto error = PbString::LoadFromSharedMemory(
Stub()->ShmPool(), response_batch_shm_ptr->error);
return TRITONSERVER_ErrorNew(
Expand Down Expand Up @@ -1357,6 +1370,16 @@ ModelInstanceState::ProcessRequests(
(*responses)[r] = nullptr;
continue;
}
{
  // Release the response factory that belongs to request `r`. The braces
  // limit the owner's lifetime so backend::ResponseFactoryDeleter is invoked
  // before the response is loaded from shared memory below.
  auto* factory = reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
      pb_infer_requests[r]->GetResponseFactoryAddress());
  std::unique_ptr<
      TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
      owned_factory(factory);
}
infer_response = InferResponse::LoadFromSharedMemory(
Stub()->ShmPool(), response_shm_handle[r],
false /* open_cuda_handle */);
Expand Down
1 change: 1 addition & 0 deletions src/response_sender.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ResponseSender {
const std::set<std::string>& requested_output_names,
std::unique_ptr<SharedMemoryManager>& shm_pool,
const std::shared_ptr<PbCancel>& pb_cancel);
// Returns the stored response factory address as a plain integer.
// NOTE(review): presumably the address of a TRITONBACKEND_ResponseFactory,
// as with InferRequest::GetResponseFactoryAddress() -- confirm at call sites.
intptr_t ResponseFactory()
{
  return response_factory_address_;
}
~ResponseSender();
void Send(std::shared_ptr<InferResponse> response, const uint32_t flags);
bool IsCancelled();
Expand Down

0 comments on commit 47adab9

Please sign in to comment.