From b9cefa5ba90f1846c4d64ca517d8b3fe18c6cc07 Mon Sep 17 00:00:00 2001 From: Greg Larmore Date: Fri, 4 Oct 2024 14:51:47 -0600 Subject: [PATCH] fix: Add enumeration for tensor data type --- src/pb_stub.cc | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 007e7f29..acd85b58 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -1774,6 +1774,24 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) .def("__dlpack__", &PbTensor::DLPack, py::arg("stream") = py::none()) .def("__dlpack_device__", &PbTensor::DLPackDevice); + py::enum_(module, "TritonDtype") + .value("TYPE_INVALID", TRITONSERVER_DataType::TRITONSERVER_TYPE_INVALID) + .value("TYPE_BOOL", TRITONSERVER_DataType::TRITONSERVER_TYPE_BOOL) + .value("TYPE_UINT8", TRITONSERVER_DataType::TRITONSERVER_TYPE_UINT8) + .value("TYPE_UINT16", TRITONSERVER_DataType::TRITONSERVER_TYPE_UINT16) + .value("TYPE_UINT32", TRITONSERVER_DataType::TRITONSERVER_TYPE_UINT32) + .value("TYPE_UINT64", TRITONSERVER_DataType::TRITONSERVER_TYPE_UINT64) + .value("TYPE_INT8", TRITONSERVER_DataType::TRITONSERVER_TYPE_INT8) + .value("TYPE_INT16", TRITONSERVER_DataType::TRITONSERVER_TYPE_INT16) + .value("TYPE_INT32", TRITONSERVER_DataType::TRITONSERVER_TYPE_INT32) + .value("TYPE_INT64", TRITONSERVER_DataType::TRITONSERVER_TYPE_INT64) + .value("TYPE_FP16", TRITONSERVER_DataType::TRITONSERVER_TYPE_FP16) + .value("TYPE_FP32", TRITONSERVER_DataType::TRITONSERVER_TYPE_FP32) + .value("TYPE_FP64", TRITONSERVER_DataType::TRITONSERVER_TYPE_FP64) + .value("TYPE_BYTES", TRITONSERVER_DataType::TRITONSERVER_TYPE_BYTES) + .value("TYPE_BF16", TRITONSERVER_DataType::TRITONSERVER_TYPE_BF16) + .export_values(); + py::class_>( module, "InferenceResponse") .def(