Some optimizations and fixes
* Fix py_future object lifecycle

* Fix request released after complete final
kthui committed Jan 12, 2025
1 parent 131078a commit 83f78f3
Showing 3 changed files with 22 additions and 24 deletions.
2 changes: 0 additions & 2 deletions python/tritonserver/_api/_model.py
@@ -211,7 +211,6 @@ async def __anext__(self):

response = InferenceResponse._from_tritonserver_inference_response(
self._model,
self._request,
response,
flags,
self._inference_request.output_memory_type,
@@ -319,7 +318,6 @@ def __next__(self):

response = InferenceResponse._from_tritonserver_inference_response(
self._model,
self._request,
response,
flags,
self._inference_request.output_memory_type,
8 changes: 1 addition & 7 deletions python/tritonserver/_api/_response.py
@@ -44,11 +44,7 @@
from tritonserver._api._dlpack import DLDeviceType as DLDeviceType
from tritonserver._api._logging import LogMessage
from tritonserver._api._tensor import Tensor
from tritonserver._c.triton_bindings import (
InternalError,
TritonError,
TRITONSERVER_InferenceRequest,
)
from tritonserver._c.triton_bindings import InternalError, TritonError
from tritonserver._c.triton_bindings import TRITONSERVER_LogLevel as LogLevel
from tritonserver._c.triton_bindings import TRITONSERVER_MemoryType as MemoryType
from tritonserver._c.triton_bindings import (
@@ -103,14 +99,12 @@ class InferenceResponse:
@staticmethod
def _from_tritonserver_inference_response(
model: _model.Model,
request: TRITONSERVER_InferenceRequest,
response,
flags: TRITONSERVER_ResponseCompleteFlag,
output_memory_type: Optional[DeviceOrMemoryType] = None,
):
result = InferenceResponse(
model,
request.id,
final=(flags == TRITONSERVER_ResponseCompleteFlag.FINAL),
)

36 changes: 21 additions & 15 deletions python/tritonserver/_c/tritonserver_pybind.cc
@@ -893,6 +893,8 @@ class PyInferenceRequest
}
}

void Release() { owned_ = false; }

void SetReleaseCallback()
{
ThrowIfError(TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -902,10 +904,9 @@
struct TRITONSERVER_InferenceRequest* request, const uint32_t flags,
void* userp)
{
// This wrapper object will be kept alive for the entire inference process,
// so this wrapper will not attempt to destruct the Triton request object
// during inference. Thus, it is ok for now to not update the 'owned' flag
// after passing the Triton request to the core and before it is released.
// This function may be called after the request is deallocated, so it
// must not attempt to access the request object.
ThrowIfError(TRITONSERVER_InferenceRequestDelete(request));
}

void SetResponseAllocator()
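
The reworded comment above captures the constraint behind the release-callback change: Triton may invoke the callback after the Python-side wrapper has been destroyed, so the callback can only free the raw request handle it receives and must never reach back into wrapper state through `userp`. A minimal sketch of that shape, using hypothetical `CRequest` and `c_request_delete` stand-ins rather than the real Triton C API:

```cpp
#include <cstdint>

// Hypothetical stand-ins for the Triton request handle and its delete call;
// not the real TRITONSERVER_* API.
struct CRequest {};
void c_request_delete(CRequest* request) { delete request; }

// The core may fire this callback after the Python-side wrapper that
// registered it is already gone, so the callback only frees the raw handle
// it is handed and never reaches back into wrapper state via 'userp'.
void RequestReleaseCallback(
    CRequest* request, const uint32_t /* flags */, void* /* userp */)
{
  c_request_delete(request);
}

int main()
{
  RequestReleaseCallback(new CRequest{}, 0, nullptr);
  return 0;
}
```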
@@ -1159,16 +1160,15 @@
}

// Response management
void GetNextResponse(py::object& py_future)
void GetNextResponse(const py::object& py_future)
{
std::lock_guard lock(response_mu_);

if (responses_.empty()) {
if (response_future_.get() != nullptr) {
throw AlreadyExistsError("cannot call GetNextResponse concurrently");
}
response_future_.reset(new py::object());
*response_future_ = py_future;
response_future_.reset(new py::object(py_future));
} else {
std::pair<std::shared_ptr<PyInferenceResponse>, const uint32_t>&
py_response = responses_.front();
@@ -1186,7 +1186,7 @@
managed_ptr.reset(new PyInferenceResponse(response, true /* owned */));
}
std::pair<std::shared_ptr<PyInferenceResponse>, const uint32_t> py_response(
managed_ptr, flags);
std::move(managed_ptr), std::move(flags));
{
std::lock_guard lock(response_mu_);
if (response_future_.get() == nullptr) {
Expand All @@ -1196,13 +1196,14 @@ class PyInferenceRequest
}
{
py::gil_scoped_acquire gil;
PyFutureSetResult(*response_future_, py_response);
response_future_.reset(nullptr);
std::unique_ptr<py::object> response_future_local(nullptr);
response_future_.swap(response_future_local);
PyFutureSetResult(*response_future_local, py_response);
}
}

void PyFutureSetResult(
py::object& py_future,
const py::object& py_future,
std::pair<std::shared_ptr<PyInferenceResponse>, const uint32_t>&
py_response)
{
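
The future handling above is the core of the `py_future` lifecycle fix: the pending future is now copy-constructed into `response_future_` rather than assigned into a placeholder, and when a response arrives the member is swapped into a local `unique_ptr` before the result is set, so `response_future_` is already cleared by the time Python code runs and the copy is destroyed while the GIL is still held. One plausible reading is that completing the future can re-enter `GetNextResponse`, which previously still saw a non-null `response_future_`. A rough sketch of the swap-then-complete shape with plain standard-library types in place of pybind11 objects (illustrative only, not the actual implementation):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

// Stand-in for the Python future: a callable completed exactly once.
using Future = std::function<void(const std::string&)>;

class ResponseQueue {
 public:
  // Mirrors GetNextResponse: store an owned copy of the caller's future.
  void SetWaiter(const Future& future)
  {
    std::lock_guard<std::mutex> lock(mu_);
    if (waiter_) {
      throw std::runtime_error("cannot wait concurrently");
    }
    waiter_ = std::make_unique<Future>(future);
  }

  // Mirrors the response path: detach the stored future into a local owner
  // first, then complete it, so the member is already cleared if completing
  // the future re-enters SetWaiter, and the copy is destroyed here rather
  // than lingering in the member.
  void Deliver(const std::string& response)
  {
    std::unique_ptr<Future> waiter;
    {
      std::lock_guard<std::mutex> lock(mu_);
      waiter_.swap(waiter);
    }
    if (waiter) {
      (*waiter)(response);
    }
  }

 private:
  std::mutex mu_;
  std::unique_ptr<Future> waiter_;
};

int main()
{
  ResponseQueue queue;
  queue.SetWaiter([](const std::string& r) { std::cout << r << "\n"; });
  queue.Deliver("response 0");
  return 0;
}
```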
@@ -1634,16 +1635,25 @@ class PyServer : public PyWrapper<struct TRITONSERVER_Server> {
void InferAsync(
const std::shared_ptr<PyInferenceRequest>& request, PyTrace& trace)
{
request->SetReleaseCallback();
request->SetResponseAllocator();
request->SetResponseCallback();
ThrowIfError(TRITONSERVER_ServerInferAsync(
triton_object_, request->Ptr(), trace.Ptr()));
// Ownership of the internal C object is transferred.
request->Release();
trace.Release();
}

void InferAsync(const std::shared_ptr<PyInferenceRequest>& request)
{
request->SetReleaseCallback();
request->SetResponseAllocator();
request->SetResponseCallback();
ThrowIfError(
TRITONSERVER_ServerInferAsync(triton_object_, request->Ptr(), nullptr));
// Ownership of the internal C object is transferred.
request->Release();
}
};

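On the submit path, the callbacks are now registered just before `TRITONSERVER_ServerInferAsync` (rather than in the constructor, per the hunk below), and once the call succeeds the wrapper calls `Release()` so its destructor will not delete a request the core now owns. A compact sketch of that owned-flag handoff, again with hypothetical `c_request_*` stand-ins in place of the real Triton calls:

```cpp
#include <stdexcept>

// Hypothetical C-style handle and calls standing in for the Triton C API.
struct CRequest {};
CRequest* c_request_new() { return new CRequest{}; }
void c_request_delete(CRequest* request) { delete request; }

// In this sketch the "core" accepts the request and frees it right away; the
// real server frees it later through the registered release callback.
bool c_server_infer_async(CRequest* request)
{
  c_request_delete(request);
  return true;
}

class RequestWrapper {
 public:
  RequestWrapper() : handle_(c_request_new()) {}

  // Delete the handle only while this wrapper still owns it.
  ~RequestWrapper()
  {
    if (owned_) {
      c_request_delete(handle_);
    }
  }

  CRequest* Ptr() { return handle_; }
  void Release() { owned_ = false; }

 private:
  CRequest* handle_;
  bool owned_{true};
};

void InferAsync(RequestWrapper& request)
{
  // Callbacks would be registered here, immediately before submission.
  if (!c_server_infer_async(request.Ptr())) {
    throw std::runtime_error("inference submission failed");
  }
  // Ownership of the C object has transferred; make sure the wrapper's
  // destructor does not delete it a second time.
  request.Release();
}

int main()
{
  RequestWrapper request;
  InferAsync(request);
  return 0;
}
```

The design point is that the `owned_` flag answers exactly one question: whether the destructor is still responsible for the handle, which is why the constructor sets it and `Release()` clears it at the moment ownership moves to the core.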
@@ -1710,10 +1720,6 @@ PyInferenceRequest::PyInferenceRequest(
ThrowIfError(TRITONSERVER_InferenceRequestNew(
&triton_object_, server.Ptr(), model_name.c_str(), model_version));
owned_ = true;

SetReleaseCallback();
SetResponseAllocator();
SetResponseCallback();
}

// [FIXME] module name?