Skip to content

Commit

Permalink
Merge branch 'main' into rotary
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre authored Jul 12, 2024
2 parents dfeafa5 + 8153bc1 commit 8359f41
Show file tree
Hide file tree
Showing 41 changed files with 2,824 additions and 1,525 deletions.
13 changes: 11 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ if (MSVC)
endif()
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")

# DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()

if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
Expand Down Expand Up @@ -442,7 +446,9 @@ endif()
if(OCOS_ENABLE_BERT_TOKENIZER)
# Bert
set(_HAS_TOKENIZER ON)
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*"
"operators/tokenizer/bert_tokenizer.*"
"operators/tokenizer/bert_tokenizer_decoder.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()

Expand Down Expand Up @@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE)
endif()

target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")

target_link_libraries(ortcustomops PUBLIC ocos_operators)

if(_BUILD_SHARED_LIBRARY)
Expand All @@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY)
standardize_output_folder(extensions_shared)

if(LINUX OR ANDROID)
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS
" -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
# strip if not a debug build
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")
Expand Down
2 changes: 0 additions & 2 deletions cmake/ext_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no

add_compile_definitions(USE_CUDA)

set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use
set(OCOS_USE_FLASH_ATTENTION OFF)
if (OCOS_USE_FLASH_ATTENTION)
message(STATUS "Enable flash attention")
add_compile_definitions(OCOS_USE_FLASH_ATTENTION)
Expand Down
60 changes: 60 additions & 0 deletions docs/How_to_write_custom_op.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# How to write custom ops

Custom Ops are based on ONNXRuntime-extensions API, especially **OrtLiteCustomOp** and **Tensor** class. C++ template metaprogramming is heavily used under the hood to provide big flexibility to the Custom Op authors on the parameter's count, type and order.

## Basic scenario

You have 2 ways to write a custom op: by writing a function, or by writing a structure.

### Custom op in the form of function

If your kernel is simple, you can use this option by just providing a function to compute the customized kernel. That function can have arbitrary number of inputs and outputs. For the inputs that are mandatory, their type would be like:

```C++
const Ort::Custom::Tensor<T>&
// or
const Ort::Custom::Tensor<T>*
```

For the inputs that are optional, their type would be like:

```C++
std::optional<const Ort::Custom::Tensor<T>*>
```

The function can also accept the pointer of **CUDAKernelContext**, where you can retrieve CUDA stream and other CUDA resources, if it requires to be run in CUDA GPU.

The function will return the type **OrtStatusPtr**

Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) as an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types.

### Custom op in the form of structure

If the kernel is complicated and there are extra properties of the custom op, you can use this option by providing a C++ structure where you can put these properties as the structure's member variables. Besides that, you also need to provide the following member functions:

```C++
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initialize the properties of the custom op

OrtStatusPtr Compute(...) const // This function computes the customized kernel.
```
The specification of the parameters of the Compute function is the same as the first way (custom op in the form of function)
## Advanced scenario
In some cases you need more control on the parameters, in this case you have to use the structure form, which you need to provide the implementations of the following member functions such as:
```C++
// By default the function will return OrtMemType::OrtMemTypeDefault for all the inputs,
// you can provide your own implementation to specify the ith input is in CPU or GPU.
static OrtMemType GetInputMemoryType(size_t input_index)
// You can specify input i shares the same memory with output j if possible, by allocating
// two array with same length for the pointer input_index and output_index seperately, and
// then let (*input_index)[k] = i and (*output_index)[k] = j.
// The return value is the length of the allocated array.
static size_t GetMayInplace(int** input_index, int** output_index)
// Release the allocated array from the GetMayInplace() function.
static void ReleaseMayInplace(int* input_index, int* output_index)
```
19 changes: 19 additions & 0 deletions docs/c_api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ONNXRuntime Extensions C ABI

ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs.

The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available:

- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models.
- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels.
- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model.

## ABI QuickStart

Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization.

**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c).

**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).

**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).
7 changes: 7 additions & 0 deletions include/custom_op/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return INPUT_OUTPUT_OPTIONAL;
};
#endif

#if ORT_API_VERSION >= 18
OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
return 0;
};
OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
#endif
}

const std::string op_name_;
Expand Down
25 changes: 25 additions & 0 deletions include/op_def_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {};
template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};

template <typename T, typename = void>
struct CustomOp_defined_getMayInplace : std::false_type {};

template <typename T>
struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>> : std::true_type {};

template <typename T, typename = void>
struct CustomOp_defined_releaseMayInplace : std::false_type {};

template <typename T>
struct CustomOp_defined_releaseMayInplace<T, std::void_t<decltype(&T::ReleaseMayInplace)>> : std::true_type {};

template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
using ComputeFunction = decltype(&CustomOpKernel::Compute);
Expand Down Expand Up @@ -192,6 +204,19 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
};
}

#if ORT_API_VERSION >= 18
if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
return CustomOpKernel::GetMayInplace(input_index, output_index);
};
}
if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
CustomOpKernel::ReleaseMayInplace(input_index, output_index);
};
}
#endif

OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
if (api == nullptr) {
Expand Down
File renamed without changes.
75 changes: 75 additions & 0 deletions include/ortx_extractor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// C ABI header file for the onnxruntime-extensions tokenization module

#pragma once

#include "ortx_utils.h"

typedef OrtxObject OrtxFeatureExtractor;
typedef OrtxObject OrtxRawAudios;
typedef OrtxObject OrtxTensorResult;

#ifdef __cplusplus
extern "C" {
#endif

/**
* @brief Creates a feature extractor object.
*
* This function creates a feature extractor object based on the provided feature definition.
*
* @param[out] extractor Pointer to a pointer to the created feature extractor object.
* @param[in] fe_def The feature definition used to create the feature extractor.
*
* @return An error code indicating the result of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def);

/**
* Loads a collection of audio files into memory.
*
* This function loads a collection of audio files specified by the `audio_paths` array
* into memory and returns a pointer to the loaded audio data in the `audios` parameter.
*
* @param audios A pointer to a pointer that will be updated with the loaded audio data.
* The caller is responsible for freeing the memory allocated for the audio data.
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
* @param num_audios The number of audio files to be loaded.
*
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);

/**
* @brief Creates an array of raw audio objects.
*
* This function creates an array of raw audio objects based on the provided data and sizes.
*
* @param audios Pointer to the variable that will hold the created raw audio objects.
* @param data Array of pointers to the audio data.
* @param sizes Array of pointers to the sizes of the audio data.
* @param num_audios Number of audio objects to create.
*
* @return extError_t Error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios);

/**
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
*
* This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct,
* and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using
* the specified feature extractor and stores the result in the provided log_mel pointer.
*
* @param extractor The feature extractor to use for calculating the log mel spectrogram.
* @param audio The raw audio data to process.
* @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);

#ifdef __cplusplus
}
#endif
36 changes: 16 additions & 20 deletions include/ortx_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxProcessor;
typedef OrtxObject OrtxRawImages;
typedef OrtxObject OrtxImageProcessorResult;

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
size_t* num_images_loaded);


/**
* @brief Preprocesses the given raw images using the specified processor.
* @brief Creates raw images from the provided data.
*
* This function creates raw images from the provided data. The raw images are stored in the `images` parameter.
*
* @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images.
* @param data Array of pointers to the data for each image.
* @param sizes Array of pointers to the sizes of each image.
* @param num_images Number of images to create.
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images);

/**
* @brief Pre-processes the given raw images using the specified processor.
*
* This function applies preprocessing operations on the raw images using the provided processor.
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
Expand All @@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
*/
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result);

/**
* @brief Retrieves the image processor result at the specified index.
*
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
* @param index The index of the result to retrieve.
* @return extError_t The error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Clear the outputs of the processor
*
* \param processor The processor object
* \param result The result object to clear
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);
OrtxTensorResult** result);

#ifdef __cplusplus
}
Expand Down
19 changes: 17 additions & 2 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ typedef enum {
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindProcessor = 0x778C,
kOrtxKindRawImages = 0x778D,
kOrtxKindImageProcessorResult = 0x778E,
kOrtxKindTensorResult = 0x778E,
kOrtxKindProcessorResult = 0x778F,
kOrtxKindTensor = 0x7790,
kOrtxKindFeatureExtractor = 0x7791,
kOrtxKindRawAudios = 0x7792,
kOrtxKindEnd = 0x7999
} extObjectKind_t;

// all object managed by the library should be 'derived' from this struct
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
typedef struct {
int ext_kind_;
extObjectKind_t ext_kind_;
} OrtxObject;

typedef OrtxObject OrtxTensor;
typedef OrtxObject OrtxTensorResult;

// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
Expand Down Expand Up @@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);

/**
* @brief Retrieves the tensor at the specified index from the given tensor result.
*
* This function allows you to access a specific tensor from a tensor result object.
*
* @param result The tensor result object from which to retrieve the tensor.
* @param index The index of the tensor to retrieve.
* @param tensor A pointer to a variable that will hold the retrieved tensor.
* @return An error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Get the data from the tensor
*
* \param tensor The tensor object
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

__author__ = "Microsoft"


from ._version import __version__
from ._ocos import get_library_path
from ._ocos import Opdef, PyCustomOpDef
Expand Down Expand Up @@ -66,6 +65,10 @@ def _unimplemented(*args, **kwargs):
gen_processing_models = _unimplemented
OrtPyFunction = _unimplemented
ort_inference = _unimplemented
PyOrtFunction = _unimplemented
optimize_model = _unimplemented
make_onnx_model = _unimplemented
ONNXRuntimeError = _unimplemented

else:
__all__ += _offline_api
Expand Down
Loading

0 comments on commit 8359f41

Please sign in to comment.