Skip to content

Commit

Permalink
Merge branch 'main' into rotary
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre authored Jun 24, 2024
2 parents 1c535cb + 3b275b1 commit d70a588
Show file tree
Hide file tree
Showing 21 changed files with 1,070 additions and 328 deletions.
5 changes: 1 addition & 4 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
---
# Defaults for all languages.
BasedOnStyle: Google

# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained.
# Developers are responsible for adhering to the 120 character maximum.
ColumnLimit: 0
ColumnLimit: 120
SortIncludes: false
DerivePointerAlignment: false

Expand Down
2 changes: 1 addition & 1 deletion cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "80dc998efced8ceb2be59756668a7e90e8bef917",
"commitHash": "3e9dfa2866941655c56877882565e7577de6fc7b",
"repositoryUrl": "https://github.com/pybind/pybind11.git"
},
"comments": "v2.10.1"
Expand Down
4 changes: 2 additions & 2 deletions cmake/ext_tests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ if (OCOS_ENABLE_C_API)
"$<TARGET_PROPERTY:ortcustomops,INTERFACE_INCLUDE_DIRECTORIES>"
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")

if (ORTX_TEST_DATA2)
file(TO_NATIVE_PATH "${ORTX_TEST_DATA2}/tests/data2" _TEST_DATA2)
if (ORTX_DATA_PATH)
file(TO_NATIVE_PATH "${ORTX_DATA_PATH}/tests/data2" _TEST_DATA2)
add_custom_command(TARGET pp_api_test POST_BUILD
COMMAND ${CMAKE_COMMAND} -E create_symlink ${_TEST_DATA2} ${onnxruntime_extensions_BINARY_DIR}/data2)
endif()
Expand Down
2 changes: 2 additions & 0 deletions cmake/externals/json.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ FetchContent_GetProperties(nlohmann_json)
if(NOT nlohmann_json_POPULATED)
FetchContent_Populate(nlohmann_json)
endif()

add_compile_definitions(JSON_HAS_CPP_17=1)
4 changes: 2 additions & 2 deletions cmake/externals/pybind11.cmake
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
FetchContent_Declare(
pybind11
URL https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip
URL_HASH SHA1=769b6aa67a77f17a770960f604b727645b6f6a13
URL https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.zip
URL_HASH SHA1=8482f57ed55c7b100672815a311d5450858723fb
)

FetchContent_GetProperties(pybind11)
Expand Down
14 changes: 14 additions & 0 deletions include/custom_op/tensor_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ class TensorBase : public Arg {
virtual int64_t NumberOfElement() const = 0;
virtual const void* DataRaw() const = 0;
virtual size_t SizeInBytes() const = 0;

virtual std::byte* AllocateRaw(const std::vector<int64_t>& shape) = 0;
};

template <typename T>
Expand Down Expand Up @@ -283,6 +285,10 @@ class Tensor : public TensorBase {
return static_cast<TT*>(buffer);
}

std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
return reinterpret_cast<std::byte*>(Allocate(shape));
}

const Span<T>& AsSpan() {
if (!storage_)
ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
Expand Down Expand Up @@ -448,6 +454,10 @@ class Tensor<std::string> : public TensorBase {
return ss[0].size();
}

std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
}

void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
storage_->SetStringOutput(ss, dims);
}
Expand Down Expand Up @@ -522,6 +532,10 @@ class Tensor<std::string_view> : public TensorBase {
return ss[0].size();
}

std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
}

void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
storage_->SetStringOutput(ss, dims);
}
Expand Down
2 changes: 1 addition & 1 deletion include/ort_c_to_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ struct BaseKernel {
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;

const OrtApi& api_;
OrtW::CustomOpApi ort_;
const OrtKernelInfo& info_;
OrtW::CustomOpApi ort_;
};

// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
Expand Down
97 changes: 97 additions & 0 deletions include/ortx_c_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "ortx_utils.h"

namespace ort_extensions {

template <typename T>
class OrtxDeleter {
public:
void operator()(T* p) const {
if (p) {
OrtxDisposeOnly(p);
}
}
};

/**
* @brief A smart pointer class that manages the lifetime of an OrtxObject.
*
* This class is derived from std::unique_ptr and provides additional functionality
* specific to OrtxObject. It automatically calls the OrtxDeleter to release the
* owned object when it goes out of scope.
*
* @tparam T The type of the object being managed.
*/
template <typename T>
class OrtxObjectPtr : public std::unique_ptr<T, OrtxDeleter<T>> {
public:
/**
* @brief Default constructor.
*
* Constructs an OrtxObjectPtr with a null pointer.
*/
OrtxObjectPtr() : std::unique_ptr<T, OrtxDeleter<T>>(nullptr) {}

/**
* @brief Constructor that creates an OrtxObjectPtr from a function call.
*
* This constructor calls the specified function with the given arguments to
* create an OrtxObject. If the function call succeeds, the created object is
* owned by the OrtxObjectPtr.
*
* @tparam TFn The type of the function pointer or function object.
* @tparam Args The types of the arguments to be passed to the function.
* @param fn The function pointer or function object used to create the OrtxObject.
* @param args The arguments to be passed to the function.
*/
template <typename TFn, typename... Args>
OrtxObjectPtr(TFn fn, Args&&... args) {
OrtxObject* proc = nullptr;
err_ = fn(&proc, std::forward<Args>(args)...);
if (err_ == kOrtxOK) {
this->reset(static_cast<T*>(proc));
}
}

/**
* @brief Get the error code associated with the creation of the OrtxObject.
*
* @return The error code.
*/
extError_t Code() const { return err_; }

private:
extError_t err_ = kOrtxOK; /**< The error code associated with the creation of the OrtxObject. */
};

template <typename T>
struct PointerAssigner {
OrtxObject* obj_{};
OrtxObjectPtr<T>& ptr_;
PointerAssigner(OrtxObjectPtr<T>& ptr) : ptr_(ptr){};

~PointerAssigner() { ptr_.reset(static_cast<T*>(obj_)); };

operator T**() { return reinterpret_cast<T**>(&obj_); };
};

/**
* @brief A wrapper function for OrtxObjectPtr that can be used as a function parameter on creation.
*
* This function creates a PointerAssigner object for the given OrtxObjectPtr. The PointerAssigner
* object can be used to assign a pointer value to the OrtxObjectPtr.
*
* @tparam T The type of the object pointed to by the OrtxObjectPtr.
* @param ptr The OrtxObjectPtr to create the PointerAssigner for.
* @return A PointerAssigner object for the given OrtxObjectPtr.
*/
template <typename T>
PointerAssigner<T> ptr(OrtxObjectPtr<T>& ptr) {
return PointerAssigner<T>{ptr};
};

} // namespace ort_extensions
53 changes: 51 additions & 2 deletions include/ortx_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxProcessor;
typedef OrtxObject OrtxRawImages;
typedef OrtxObject OrtxImageProcessorResult;

#ifdef __cplusplus
extern "C" {
Expand All @@ -17,11 +19,58 @@ extern "C" {
/** \brief Create a processor object with the specified processor definition
*
* \param processor Pointer to store the created processor object
* \param processor_def The processor definition, either a path to the processor directory or a JSON string, and is utf-8 encoded.
* \return Error code indicating the success or failure of the operation
* \param processor_def The processor definition, either a path to the processor directory or a JSON string, and is
* utf-8 encoded. \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const char* processor_def);

/**
* @brief Loads a set of images from the specified image paths.
*
* This function loads a set of images from the given image paths and returns a pointer to the loaded images.
* The number of images loaded is also returned through the `num_images_loaded` parameter.
*
* @param[out] images A pointer to a pointer that will be set to the loaded images.
* @param[in] image_paths An array of image paths.
* @param[in] num_images The number of images to load.
* @param[out] num_images_loaded A pointer to a variable that will be set to the number of images loaded.
*
* @return An error code indicating the status of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
size_t* num_images_loaded);

/**
* @brief Preprocesses the given raw images using the specified processor.
*
* This function applies preprocessing operations on the raw images using the provided processor.
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
*
* @param processor A pointer to the `OrtxProcessor` object used for preprocessing.
* @param images A pointer to the `OrtxRawImages` object containing the raw images to be processed.
* @param result A pointer to the `OrtxImageProcessorResult` object to store the preprocessing result.
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
*/
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result);

/**
* @brief Retrieves the image processor result at the specified index.
*
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
* @param index The index of the result to retrieve.
* @return extError_t The error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Clear the outputs of the processor
*
* \param processor The processor object
* \param result The result object to clear
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);

#ifdef __cplusplus
}
#endif
44 changes: 42 additions & 2 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "ortx_types.h"

const int API_VERSION = 1;

typedef enum {
kOrtxKindUnknown = 0,

Expand All @@ -14,7 +16,10 @@ typedef enum {
kOrtxKindTokenId2DArray = 0x778A,
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindProcessor = 0x778C,
kOrtxKindProcessorResult = 0x778D,
kOrtxKindRawImages = 0x778D,
kOrtxKindImageProcessorResult = 0x778E,
kOrtxKindProcessorResult = 0x778F,
kOrtxKindTensor = 0x7790,
kOrtxKindEnd = 0x7999
} extObjectKind_t;

Expand All @@ -24,7 +29,7 @@ typedef struct {
int ext_kind_;
} OrtxObject;

const int API_VERSION = 1;
typedef OrtxObject OrtxTensor;

// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
Expand Down Expand Up @@ -72,6 +77,41 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);

/** \brief Get the data from the tensor
*
* \param tensor The tensor object
* \param data Pointer to store the data
* \param shape Pointer to store the shape
* \param num_dims Pointer to store the number of dimensions
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
size_t* num_dims);
/**
* \brief Get the data from the tensor as int64_t type
*
* \param tensor The tensor object
* \param data Pointer to store the data
* \param shape Pointer to store the shape
* \param num_dims Pointer to store the number of dimensions
* \return Error code indicating the success or failure of the operation
*/

extError_t ORTX_API_CALL OrtxGetTensorDataInt64(OrtxTensor* tensor, const int64_t** data, const int64_t** shape,
size_t* num_dims);

/**
* \brief Get the data from the tensor as float type
*
* \param tensor The tensor object
* \param data Pointer to store the data
* \param shape Pointer to store the shape
* \param num_dims Pointer to store the number of dimensions
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxGetTensorDataFloat(OrtxTensor* tensor, const float** data, const int64_t** shape,
size_t* num_dims);

#ifdef __cplusplus
}
#endif
Loading

0 comments on commit d70a588

Please sign in to comment.