From 0f1f454867b57511e4ebdf5449483d0ce207b368 Mon Sep 17 00:00:00 2001 From: Chester Liu <4710575+skyline75489@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:28:27 +0800 Subject: [PATCH 1/2] Fix C4459 warning in custom_op_lite.h (#751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Internal workitem: https://task.ms/aii/29719 Co-authored-by: Xavier Dupré --- include/custom_op/custom_op_lite.h | 147 ++++++++++++++--------------- 1 file changed, 73 insertions(+), 74 deletions(-) diff --git a/include/custom_op/custom_op_lite.h b/include/custom_op/custom_op_lite.h index 960f64230..bcb746b91 100644 --- a/include/custom_op/custom_op_lite.h +++ b/include/custom_op/custom_op_lite.h @@ -16,19 +16,19 @@ namespace Custom { class OrtKernelContextStorage : public ITensorStorage { public: - OrtKernelContextStorage(const OrtW::CustomOpApi& api, + OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) { if (is_input) { - auto input_count = api.KernelContext_GetInputCount(&ctx); + auto input_count = api_.KernelContext_GetInputCount(&ctx); if (indice >= input_count) { ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); } - const_value_ = api.KernelContext_GetInput(&ctx, indice); - auto* info = api.GetTensorTypeAndShape(const_value_); - shape_ = api.GetTensorShape(info); - api.ReleaseTensorTypeAndShapeInfo(info); + const_value_ = api_.KernelContext_GetInput(&ctx, indice); + auto* info = api_.GetTensorTypeAndShape(const_value_); + shape_ = api_.GetTensorShape(info); + api_.ReleaseTensorTypeAndShapeInfo(info); } } @@ -66,18 +66,18 @@ class OrtKernelContextStorage : public ITensorStorage { std::optional> shape_; }; -static std::string get_mem_type(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input){ +static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) { std::string output = "Cpu"; if (is_input) { - const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice); + const OrtValue* const_value = custom_op_api.KernelContext_GetInput(&ctx, indice); const OrtMemoryInfo* mem_info = {}; - api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); + custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); if (mem_info) { const char* mem_type = nullptr; - api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type)); + custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type)); if (mem_type) { output = mem_type; } @@ -88,29 +88,29 @@ static std::string get_mem_type(const OrtW::CustomOpApi& api, template class OrtTensor : public Tensor { -public: - OrtTensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : Tensor(std::make_unique(api, ctx, indice, is_input)), - mem_type_(get_mem_type(api, ctx, indice, is_input)) { + public: + OrtTensor(const OrtW::CustomOpApi& custom_op_api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : Tensor(std::make_unique(custom_op_api, ctx, indice, is_input)), + mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) { } bool IsCpuTensor() const { return mem_type_ == "Cpu"; } -private: + private: std::string mem_type_ = "Cpu"; }; class OrtStringTensorStorage : public IStringTensorStorage { public: using strings = std::vector; - OrtStringTensorStorage(const OrtW::CustomOpApi& api, + OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) { if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -197,10 +197,10 @@ class OrtStringTensorStorage : public IStringTensorStorage { class OrtStringViewTensorStorage : public IStringTensorStorage { public: using strings = std::vector; - OrtStringViewTensorStorage(const OrtW::CustomOpApi& api, + OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), ctx_(ctx), indice_(indice) { + bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) { if (is_input) { auto input_count = api_.KernelContext_GetInputCount(&ctx_); if (indice >= input_count) { @@ -275,57 +275,56 @@ class OrtStringViewTensorStorage : public IStringTensorStorage // to make the metaprogramming magic happy. template <> -class OrtTensor : public Tensor{ -public: - OrtTensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : Tensor(std::make_unique(api, ctx, indice, is_input)), - mem_type_(get_mem_type(api, ctx, indice, is_input)) {} - +class OrtTensor : public Tensor { + public: + OrtTensor(const OrtW::CustomOpApi& custom_op_api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : Tensor(std::make_unique(custom_op_api, ctx, indice, is_input)), + mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {} + bool IsCpuTensor() const { return mem_type_ == "Cpu"; } -private: + private: std::string mem_type_ = "Cpu"; }; template <> -class OrtTensor : public Tensor{ -public: - OrtTensor(const OrtW::CustomOpApi& api, - OrtKernelContext& ctx, - size_t indice, - bool is_input) : Tensor(std::make_unique(api, ctx, indice, is_input)), - mem_type_(get_mem_type(api, ctx, indice, is_input)) {} +class OrtTensor : public Tensor { + public: + OrtTensor(const OrtW::CustomOpApi& custom_op_api, + OrtKernelContext& ctx, + size_t indice, + bool is_input) : Tensor(std::make_unique(custom_op_api, ctx, indice, is_input)), + mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {} bool IsCpuTensor() const { return mem_type_ == "Cpu"; } -private: + private: std::string mem_type_ = "Cpu"; }; using TensorPtr = std::unique_ptr; using TensorPtrs = std::vector; - using TensorBasePtr = std::unique_ptr; using TensorBasePtrs = std::vector; // Represent variadic input or output struct Variadic : public Arg { - Variadic(const OrtW::CustomOpApi& api, + Variadic(const OrtW::CustomOpApi& custom_op_api, OrtKernelContext& ctx, size_t indice, - bool is_input) : api_(api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api, ctx, indice, is_input)) { + bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) { #if ORT_API_VERSION < 14 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION); #endif if (is_input) { - auto input_count = api.KernelContext_GetInputCount(&ctx_); + auto input_count = api_.KernelContext_GetInputCount(&ctx_); for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input); auto* info = api_.GetTensorTypeAndShape(const_value); @@ -334,40 +333,40 @@ struct Variadic : public Arg { TensorBasePtr tensor; switch (type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: - tensor = std::make_unique>(api, ctx, ith_input, true); + tensor = std::make_unique>(api_, ctx, ith_input, true); break; default: ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); @@ -395,7 +394,7 @@ struct Variadic : public Arg { size_t Size() const { return tensors_.size(); } - + const TensorBasePtr& operator[](size_t indice) const { return tensors_.at(indice); } @@ -412,11 +411,11 @@ struct Variadic : public Arg { class OrtGraphKernelContext : public KernelContext { public: - OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) { OrtMemoryInfo* info; - OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); - OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_)); - api.ReleaseMemoryInfo(info); + OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); + OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &allocator_)); + api_.ReleaseMemoryInfo(info); } virtual ~OrtGraphKernelContext() { @@ -458,31 +457,31 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { public: static const int cuda_resource_ver = 1; - OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { - api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); + OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) { + api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); if (!cuda_stream_) { ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); } - api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_); + api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_); if (!cublas_) { ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION); } void* resource = nullptr; - OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource); + OrtStatusPtr result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource); if (result) { ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION); } memcpy(&device_id_, &resource, sizeof(int)); OrtMemoryInfo* info; - OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); - OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_)); - api.ReleaseMemoryInfo(info); + OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info)); + OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_)); + api_.ReleaseMemoryInfo(info); OrtMemoryInfo* cuda_mem_info; - OrtW::ThrowOnError(api, api.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info)); - OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_)); - api.ReleaseMemoryInfo(cuda_mem_info); + OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info)); + OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_)); + api_.ReleaseMemoryInfo(cuda_mem_info); } virtual ~OrtGraphCudaKernelContext() { @@ -944,7 +943,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { class OrtAttributeReader { public: - OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) { + OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info) : base_kernel_(ort_api, info) { } template From f1abea14e88422a16e8337f86820bf306b506c61 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Tue, 25 Jun 2024 11:12:21 -0700 Subject: [PATCH 2/2] Update CMakeLists.txt (#754) --- CMakeLists.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b9cb0ead..4f0e8440a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,9 +36,10 @@ if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0077 NEW) endif() -# Avoid warning of Calling FetchContent_Populate(GSL) is deprecated +# Avoid warning of Calling FetchContent_Populate(GSL) is deprecated temporarily +# TODO: find a better way to handle the header-only 3rd party deps if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.30.0") - cmake_policy(CMP0169 OLD) + cmake_policy(SET CMP0169 OLD) endif() # Needed for Java