diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f0e8440a..1a6a7e25f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,6 +171,10 @@ if (MSVC) endif() message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}") + # DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler + # See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856 + # Remove this definition once the conda msvcp140.dll dll is updated. + add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR) endif() if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON) @@ -442,7 +446,9 @@ endif() if(OCOS_ENABLE_BERT_TOKENIZER) # Bert set(_HAS_TOKENIZER ON) - file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*") + file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" + "operators/tokenizer/bert_tokenizer.*" + "operators/tokenizer/bert_tokenizer_decoder.*") list(APPEND TARGET_SRC ${bert_TARGET_SRC}) endif() @@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE) endif() target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS}) +target_include_directories(ortcustomops PUBLIC "$") target_include_directories(ortcustomops PUBLIC "$") + target_link_libraries(ortcustomops PUBLIC ocos_operators) if(_BUILD_SHARED_LIBRARY) @@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY) standardize_output_folder(extensions_shared) if(LINUX OR ANDROID) - set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") + set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS + " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") # strip if not a debug build if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s") diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake index aa7d3282c..ead468c9f 100644 --- a/cmake/ext_cuda.cmake +++ b/cmake/ext_cuda.cmake @@ -30,8 +30,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no add_compile_definitions(USE_CUDA) -set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use -set(OCOS_USE_FLASH_ATTENTION OFF) if (OCOS_USE_FLASH_ATTENTION) message(STATUS "Enable flash attention") add_compile_definitions(OCOS_USE_FLASH_ATTENTION) diff --git a/docs/How_to_write_custom_op.md b/docs/How_to_write_custom_op.md new file mode 100644 index 000000000..40832c834 --- /dev/null +++ b/docs/How_to_write_custom_op.md @@ -0,0 +1,60 @@ +# How to write custom ops + +Custom Ops are based on ONNXRuntime-extensions API, especially **OrtLiteCustomOp** and **Tensor** class. C++ template metaprogramming is heavily used under the hood to provide big flexibility to the Custom Op authors on the parameter's count, type and order. + +## Basic scenario + +You have 2 ways to write a custom op: by writing a function, or by writing a structure. + +### Custom op in the form of function + +If your kernel is simple, you can use this option by just providing a function to compute the customized kernel. That function can have arbitrary number of inputs and outputs. 
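For example, a minimal function-form kernel might look like the hedged sketch below. The kernel and its name are illustrative only and not part of this change; it simply shows a tensor input, a tensor output, and the `OrtStatusPtr` return convention, where returning `nullptr` signals success. The parameter conventions it relies on are spelled out next.

```C++
// Illustrative sketch only: a function-form custom op that doubles every element.
#include "ocos.h"

OrtStatusPtr KernelDouble(const Ort::Custom::Tensor<float>& input,
                          Ort::Custom::Tensor<float>& output) {
  const float* in_data = input.Data();
  float* out_data = output.Allocate(input.Shape());  // output takes the same shape as the input
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    out_data[i] = in_data[i] * 2.0f;
  }
  return nullptr;  // nullptr means success
}
```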
For the inputs that are mandatory, their type would be like: + +```C++ +const Ort::Custom::Tensor& +// or +const Ort::Custom::Tensor* +``` + +For the inputs that are optional, their type would be like: + +```C++ +std::optional*> +``` + +The function can also accept the pointer of **CUDAKernelContext**, where you can retrieve CUDA stream and other CUDA resources, if it requires to be run in CUDA GPU. + +The function will return the type **OrtStatusPtr** + +Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) as an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types. + +### Custom op in the form of structure + +If the kernel is complicated and there are extra properties of the custom op, you can use this option by providing a C++ structure where you can put these properties as the structure's member variables. Besides that, you also need to provide the following member functions: + +```C++ +OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initialize the properties of the custom op + +OrtStatusPtr Compute(...) const // This function computes the customized kernel. +``` + +The specification of the parameters of the Compute function is the same as the first way (custom op in the form of function) + +## Advanced scenario + +In some cases you need more control on the parameters, in this case you have to use the structure form, which you need to provide the implementations of the following member functions such as: + +```C++ +// By default the function will return OrtMemType::OrtMemTypeDefault for all the inputs, +// you can provide your own implementation to specify the ith input is in CPU or GPU. +static OrtMemType GetInputMemoryType(size_t input_index) + +// You can specify input i shares the same memory with output j if possible, by allocating +// two array with same length for the pointer input_index and output_index seperately, and +// then let (*input_index)[k] = i and (*output_index)[k] = j. +// The return value is the length of the allocated array. +static size_t GetMayInplace(int** input_index, int** output_index) + +// Release the allocated array from the GetMayInplace() function. +static void ReleaseMayInplace(int* input_index, int* output_index) +``` \ No newline at end of file diff --git a/docs/c_api.md b/docs/c_api.md new file mode 100644 index 000000000..1a3d4613b --- /dev/null +++ b/docs/c_api.md @@ -0,0 +1,19 @@ +# ONNXRuntime Extensions C ABI + +ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs. + +The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available: + +- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models. +- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels. +- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model. + +## ABI QuickStart + +Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization. 
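To make the call pattern concrete, here is a minimal hedged sketch (written as C++ here; the same calls work from plain C) of the create/run/read/dispose flow using the speech feature-extraction entry points declared in `ortx_extractor.h` and `ortx_utils.h` in this change. The file paths are placeholders and error handling is elided.

```C++
// Hedged sketch: create -> run -> read -> dispose with the extensions C ABI.
#include "ortx_extractor.h"  // also pulls in ortx_utils.h

void ExtractLogMel(const char* feature_def, const char* audio_path) {
  OrtxFeatureExtractor* extractor = nullptr;
  OrtxCreateSpeechFeatureExtractor(&extractor, feature_def);  // feature definition (placeholder)

  OrtxRawAudios* audios = nullptr;
  OrtxLoadAudios(&audios, &audio_path, 1);                    // load a single audio file

  OrtxTensorResult* result = nullptr;
  OrtxSpeechLogMel(extractor, audios, &result);               // compute the log mel spectrogram

  OrtxTensor* log_mel = nullptr;
  OrtxTensorResultGetAt(result, 0, &log_mel);                 // first tensor of the result

  // ... hand `log_mel` to the model, then release the objects (release order is illustrative).
  OrtxDisposeOnly(result);
  OrtxDisposeOnly(audios);
  OrtxDisposeOnly(extractor);
}
```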
+ +**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c). + +**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75). + +**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16). diff --git a/include/custom_op/custom_op_lite.h b/include/custom_op/custom_op_lite.h index bcb746b91..cba6beae6 100644 --- a/include/custom_op/custom_op_lite.h +++ b/include/custom_op/custom_op_lite.h @@ -886,6 +886,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { return INPUT_OUTPUT_OPTIONAL; }; #endif + +#if ORT_API_VERSION >= 18 + OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t { + return 0; + }; + OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {}; +#endif } const std::string op_name_; diff --git a/include/op_def_struct.h b/include/op_def_struct.h index 0fc7b233c..8076204a4 100644 --- a/include/op_def_struct.h +++ b/include/op_def_struct.h @@ -106,6 +106,18 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {}; template struct CustomOp_defined_getInputMemoryType> : std::true_type {}; +template +struct CustomOp_defined_getMayInplace : std::false_type {}; + +template +struct CustomOp_defined_getMayInplace> : std::true_type {}; + +template +struct CustomOp_defined_releaseMayInplace : std::false_type {}; + +template +struct CustomOp_defined_releaseMayInplace> : std::true_type {}; + template struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { using ComputeFunction = decltype(&CustomOpKernel::Compute); @@ -192,6 +204,19 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { }; } +#if ORT_API_VERSION >= 18 + if constexpr (CustomOp_defined_getMayInplace::value) { + OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t { + return CustomOpKernel::GetMayInplace(input_index, output_index); + }; + } + if constexpr (CustomOp_defined_releaseMayInplace::value) { + OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void { + CustomOpKernel::ReleaseMayInplace(input_index, output_index); + }; + } +#endif + OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { if (api == nullptr) { diff --git a/include/ortx_c_helper.h b/include/ortx_cpp_helper.h similarity index 100% rename from include/ortx_c_helper.h rename to include/ortx_cpp_helper.h diff --git a/include/ortx_extractor.h b/include/ortx_extractor.h new file mode 100644 index 000000000..13901666b --- /dev/null +++ b/include/ortx_extractor.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// C ABI header file for the onnxruntime-extensions tokenization module + +#pragma once + +#include "ortx_utils.h" + +typedef OrtxObject OrtxFeatureExtractor; +typedef OrtxObject OrtxRawAudios; +typedef OrtxObject OrtxTensorResult; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Creates a feature extractor object. 
+ * + * This function creates a feature extractor object based on the provided feature definition. + * + * @param[out] extractor Pointer to a pointer to the created feature extractor object. + * @param[in] fe_def The feature definition used to create the feature extractor. + * + * @return An error code indicating the result of the operation. + */ +extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def); + +/** + * Loads a collection of audio files into memory. + * + * This function loads a collection of audio files specified by the `audio_paths` array + * into memory and returns a pointer to the loaded audio data in the `audios` parameter. + * + * @param audios A pointer to a pointer that will be updated with the loaded audio data. + * The caller is responsible for freeing the memory allocated for the audio data. + * @param audio_paths An array of strings representing the paths to the audio files to be loaded. + * @param num_audios The number of audio files to be loaded. + * + * @return An `extError_t` value indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios); + +/** + * @brief Creates an array of raw audio objects. + * + * This function creates an array of raw audio objects based on the provided data and sizes. + * + * @param audios Pointer to the variable that will hold the created raw audio objects. + * @param data Array of pointers to the audio data. + * @param sizes Array of pointers to the sizes of the audio data. + * @param num_audios Number of audio objects to create. + * + * @return extError_t Error code indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios); + +/** + * @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor. + * + * This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct, + * and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using + * the specified feature extractor and stores the result in the provided log_mel pointer. + * + * @param extractor The feature extractor to use for calculating the log mel spectrogram. + * @param audio The raw audio data to process. + * @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored. + * @return An extError_t value indicating the success or failure of the operation. 
+ */ +extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel); + +#ifdef __cplusplus +} +#endif diff --git a/include/ortx_processor.h b/include/ortx_processor.h index 6dcc5a84e..b42a6c4f2 100644 --- a/include/ortx_processor.h +++ b/include/ortx_processor.h @@ -10,7 +10,6 @@ // typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting typedef OrtxObject OrtxProcessor; typedef OrtxObject OrtxRawImages; -typedef OrtxObject OrtxImageProcessorResult; #ifdef __cplusplus extern "C" { @@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images, size_t* num_images_loaded); + /** - * @brief Preprocesses the given raw images using the specified processor. + * @brief Creates raw images from the provided data. + * + * This function creates raw images from the provided data. The raw images are stored in the `images` parameter. + * + * @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images. + * @param data Array of pointers to the data for each image. + * @param sizes Array of pointers to the sizes of each image. + * @param num_images Number of images to create. + * @return An `extError_t` value indicating the success or failure of the operation. + */ +extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images); + +/** + * @brief Pre-processes the given raw images using the specified processor. * * This function applies preprocessing operations on the raw images using the provided processor. * The result of the preprocessing is stored in the `OrtxImageProcessorResult` object. @@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima * @return An `extError_t` value indicating the success or failure of the preprocessing operation. */ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images, - OrtxImageProcessorResult** result); - -/** - * @brief Retrieves the image processor result at the specified index. - * - * @param result Pointer to the OrtxImageProcessorResult structure to store the result. - * @param index The index of the result to retrieve. - * @return extError_t The error code indicating the success or failure of the operation. 
- */ -extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor); - -/** \brief Clear the outputs of the processor - * - * \param processor The processor object - * \param result The result object to clear - * \return Error code indicating the success or failure of the operation - */ -extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result); + OrtxTensorResult** result); #ifdef __cplusplus } diff --git a/include/ortx_utils.h b/include/ortx_utils.h index e6c0af9aa..ed1ca2ec5 100644 --- a/include/ortx_utils.h +++ b/include/ortx_utils.h @@ -17,19 +17,22 @@ typedef enum { kOrtxKindDetokenizerCache = 0x778B, kOrtxKindProcessor = 0x778C, kOrtxKindRawImages = 0x778D, - kOrtxKindImageProcessorResult = 0x778E, + kOrtxKindTensorResult = 0x778E, kOrtxKindProcessorResult = 0x778F, kOrtxKindTensor = 0x7790, + kOrtxKindFeatureExtractor = 0x7791, + kOrtxKindRawAudios = 0x7792, kOrtxKindEnd = 0x7999 } extObjectKind_t; // all object managed by the library should be 'derived' from this struct // which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C typedef struct { - int ext_kind_; + extObjectKind_t ext_kind_; } OrtxObject; typedef OrtxObject OrtxTensor; +typedef OrtxObject OrtxTensorResult; // C, instead of C++ doesn't cast automatically, // so we need to use a macro to cast the object to the correct type @@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object); */ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object); +/** + * @brief Retrieves the tensor at the specified index from the given tensor result. + * + * This function allows you to access a specific tensor from a tensor result object. + * + * @param result The tensor result object from which to retrieve the tensor. + * @param index The index of the tensor to retrieve. + * @param tensor A pointer to a variable that will hold the retrieved tensor. + * @return An error code indicating the success or failure of the operation. 
+ */ +extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor); + /** \brief Get the data from the tensor * * \param tensor The tensor object diff --git a/onnxruntime_extensions/__init__.py b/onnxruntime_extensions/__init__.py index 872c5a2b5..de6e1b68a 100644 --- a/onnxruntime_extensions/__init__.py +++ b/onnxruntime_extensions/__init__.py @@ -10,7 +10,6 @@ __author__ = "Microsoft" - from ._version import __version__ from ._ocos import get_library_path from ._ocos import Opdef, PyCustomOpDef @@ -66,6 +65,10 @@ def _unimplemented(*args, **kwargs): gen_processing_models = _unimplemented OrtPyFunction = _unimplemented ort_inference = _unimplemented + PyOrtFunction = _unimplemented + optimize_model = _unimplemented + make_onnx_model = _unimplemented + ONNXRuntimeError = _unimplemented else: __all__ += _offline_api diff --git a/onnxruntime_extensions/_torch_cvt.py b/onnxruntime_extensions/_torch_cvt.py index a17b1bb2d..10b85c1a6 100644 --- a/onnxruntime_extensions/_torch_cvt.py +++ b/onnxruntime_extensions/_torch_cvt.py @@ -17,7 +17,7 @@ from ._ortapi2 import make_onnx_model from ._cuops import SingleOpGraph from ._hf_cvt import HFTokenizerConverter -from .util import remove_unused_initializers +from .util import remove_unused_initializers, mel_filterbank class _WhisperHParams: @@ -30,53 +30,15 @@ class _WhisperHParams: N_FRAMES = N_SAMPLES // HOP_LENGTH -def _mel_filterbank( - n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32): - """ - Compute a Mel-filterbank. The filters are stored in the rows, the columns, - and it is Slaney normalized mel-scale filterbank. - """ - fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype) - - # the centers of the frequency bins for the DFT - freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr) - - mel = np.linspace(min_mel, max_mel, n_mels + 2) - # Fill in the linear scale - f_min = 0.0 - f_sp = 200.0 / 3 - freqs = f_min + f_sp * mel - - # And now the nonlinear scale - min_log_hz = 1000.0 # beginning of log region (Hz) - min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) - logstep = np.log(6.4) / 27.0 # step size for log region - - log_t = mel >= min_log_mel - freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel)) - mel_bins = freqs - - mel_spacing = np.diff(mel_bins) - - ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1) - for i in range(n_mels): - left = -ramps[i] / mel_spacing[i] - right = ramps[i + 2] / mel_spacing[i + 1] - - # intersect them with each other and zero - fbank[i] = np.maximum(0, np.minimum(left, right)) - - energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels]) - fbank *= energy_norm[:, np.newaxis] - return fbank - - class CustomOpStftNorm(torch.autograd.Function): @staticmethod def symbolic(g, self, n_fft, hop_length, window): - t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64)) - t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64)) - t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64)) + t_n_fft = g.op('Constant', value_t=torch.tensor( + n_fft, dtype=torch.int64)) + t_hop_length = g.op('Constant', value_t=torch.tensor( + hop_length, dtype=torch.int64)) + t_frame_size = g.op( + 'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64)) return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size) @staticmethod @@ -97,7 +59,7 @@ def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, 
n_fft=_WhisperHParams.N_FFT, self.n_fft = n_fft self.window = torch.hann_window(n_fft) self.mel_filters = torch.from_numpy( - _mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels)) + mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels)) def forward(self, audio_pcm: torch.Tensor): stft_norm = CustomOpStftNorm.apply(audio_pcm, @@ -112,7 +74,8 @@ def forward(self, audio_pcm: torch.Tensor): spec_shape = log_spec.shape padding_spec = torch.ones(spec_shape[0], spec_shape[1], - self.n_samples // self.hop_length - spec_shape[2], + self.n_samples // self.hop_length - + spec_shape[2], dtype=torch.float) padding_spec *= spec_min log_spec = torch.cat((log_spec, padding_spec), dim=2) @@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft): make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0', 'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'], name='slice_1'), - make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0), - make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1), + make_node('Constant', inputs=[], outputs=[ + 'const0_output_0'], name='const0', value_int=0), + make_node('Constant', inputs=[], outputs=[ + 'const1_output_0'], name='const1', value_int=1), make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'], name='gather_4', axis=3), make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'], name='gather_5', axis=3), - make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'), - make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'), - make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'), + make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[ + 'mul_output_0'], name='mul0'), + make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[ + 'mul_1_output_0'], name='mul1'), + make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[ + stft_norm_node.output[0]], name='add0'), ] new_stft_nodes.extend(onnx_model.graph.node[:node_idx]) new_stft_nodes.extend(replaced_nodes) @@ -253,9 +221,11 @@ def post_processing(self, **kwargs): del g.node[:] g.node.extend(nodes) - inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])] + inputs = [onnx.helper.make_tensor_value_info( + "sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])] del g.input[:] g.input.extend(inputs) - g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text'])) + g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto( + onnx.TensorProto.STRING, ['N', 'text'])) return make_onnx_model(g, opset_version=self.opset_version) diff --git a/operators/audio/audio.cc b/operators/audio/audio.cc index ecd7d4cf3..2a0d3eb0f 100644 --- a/operators/audio/audio.cc +++ b/operators/audio/audio.cc @@ -3,17 +3,15 @@ #include "ocos.h" #ifdef ENABLE_DR_LIBS -#include "audio_decoder.hpp" +#include "audio_decoder.h" #endif // ENABLE_DR_LIBS -FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& { +FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& { static OrtOpLoader op_loader( - []() { return nullptr; } #ifdef ENABLE_DR_LIBS - , - CustomCpuStructV2("AudioDecoder", AudioDecoder) + 
CustomCpuStructV2("AudioDecoder", AudioDecoder), #endif - ); + []() { return nullptr; }); return op_loader.GetCustomOps(); }; diff --git a/operators/audio/audio_decoder.cc b/operators/audio/audio_decoder.cc new file mode 100644 index 000000000..b9e92dcd9 --- /dev/null +++ b/operators/audio/audio_decoder.cc @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include +#include +#include + +#include "audio_decoder.h" + +#define DR_FLAC_IMPLEMENTATION +#include "dr_flac.h" +#define DR_MP3_IMPLEMENTATION 1 +#define DR_MP3_FLOAT_OUTPUT 1 +#include "dr_mp3.h" +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" + +#include "narrow.h" +#include "string_utils.h" +#include "string_tensor.h" +#include "sampling.h" + +OrtStatusPtr AudioDecoder::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { + auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_); + if (!status) { + status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_); + } + + return status; +} + +AudioDecoder::AudioStreamType AudioDecoder::ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, + OrtxStatus& status) const { + const std::map format_mapping = {{"default", AudioStreamType::kDefault}, + {"wav", AudioStreamType::kWAV}, + {"mp3", AudioStreamType::kMP3}, + {"flac", AudioStreamType::kFLAC}}; + + AudioStreamType stream_format = AudioStreamType::kDefault; + if (str_format.length() > 0) { + auto pos = format_mapping.find(str_format); + if (pos == format_mapping.end()) { + status = {kOrtxErrorInvalidArgument, + MakeString("[AudioDecoder]: Unknown audio stream format: ", str_format).c_str()}; + return stream_format; + } + stream_format = pos->second; + } + + if (stream_format == AudioStreamType::kDefault) { + auto p_stream = reinterpret_cast(p_data); + std::string_view marker(p_stream, 4); + if (marker == "fLaC") { + stream_format = AudioStreamType::kFLAC; + } else if (marker == "RIFF") { + stream_format = AudioStreamType::kWAV; + } else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) { + // http://www.mp3-tech.org/programmer/frame_header.html + // only detect the 8 + 3 bits sync word + stream_format = AudioStreamType::kMP3; + } else { + status = {kOrtxErrorInvalidArgument, "[AudioDecoder]: Cannot detect audio stream format"}; + } + } + + return stream_format; +} + +template +static size_t DrReadFrames(std::list>& frames, FX_DECODER fx, TY_AUDIO& obj) { + const size_t default_chunk_size = 1024 * 256; + int64_t total_buf_size = 0; + + for (;;) { + std::vector buf; + buf.resize(default_chunk_size * obj.channels); + auto n_frames = fx(&obj, default_chunk_size, buf.data()); + if (n_frames <= 0) { + break; + } + auto data_size = n_frames * obj.channels; + total_buf_size += data_size; + buf.resize(data_size); + frames.emplace_back(std::move(buf)); + } + + return total_buf_size; +} + +OrtxStatus AudioDecoder::Compute(const ortc::Tensor& input, const std::optional format, + ortc::Tensor& output0) const { + const uint8_t* p_data = input.Data(); + auto input_dim = input.Shape(); + OrtxStatus status; + if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) { + return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Expect input dimension [n] or [1,n]."}; + } + + std::string str_format; + if (format) { + str_format = *format; + } + auto stream_format = ReadStreamFormat(p_data, str_format, status); + if (status) { + return status; + } + + int64_t total_buf_size = 0; + std::list> lst_frames; + int64_t 
orig_sample_rate = 0; + int64_t orig_channels = 0; + + if (stream_format == AudioStreamType::kMP3) { + auto mp3_obj_ptr = std::make_unique(); + if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on MP3 stream."}; + return status; + } + orig_sample_rate = mp3_obj_ptr->sampleRate; + orig_channels = mp3_obj_ptr->channels; + total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr); + + } else if (stream_format == AudioStreamType::kFLAC) { + drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr); + auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); }); + if (flac_obj == nullptr) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on FLAC stream."}; + return status; + } + orig_sample_rate = flac_obj->sampleRate; + orig_channels = flac_obj->channels; + total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj); + + } else { + drwav wav_obj; + if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on WAV stream."}; + return status; + } + orig_sample_rate = wav_obj.sampleRate; + orig_channels = wav_obj.channels; + total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj); + } + + if (downsample_rate_ != 0 && orig_sample_rate < downsample_rate_) { + status = {kOrtxErrorCorruptData, "[AudioDecoder]: only down-sampling supported."}; + return status; + } + + // join all frames + std::vector buf; + buf.resize(total_buf_size); + int64_t offset = 0; + for (auto& _b : lst_frames) { + std::copy(_b.begin(), _b.end(), buf.begin() + offset); + offset += _b.size(); + } + + // mix the stereo channels into mono channel + if (stereo_mixer_ && orig_channels > 1) { + if (buf.size() > 1) { + for (size_t i = 0; i < buf.size() / 2; ++i) { + buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2; + } + buf.resize(buf.size() / 2); + } + } + + if (downsample_rate_ != 0 && downsample_rate_ != orig_sample_rate) { + // A lowpass filter on buf audio data to remove high frequency noise + ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate); + std::vector filtered_buf = filter.Process(buf); + // downsample the audio data + KaiserWindowInterpolation::Process(filtered_buf, buf, 1.0f * orig_sample_rate, 1.0f * downsample_rate_); + } + + std::vector dim_out = {1, ort_extensions::narrow(buf.size())}; + float* p_output = output0.Allocate(dim_out); + std::copy(buf.begin(), buf.end(), p_output); + return status; +} diff --git a/operators/audio/audio_decoder.h b/operators/audio/audio_decoder.h new file mode 100644 index 000000000..cecfcffdb --- /dev/null +++ b/operators/audio/audio_decoder.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#pragma once + +#include "ocos.h" + +#include +#include + +struct AudioDecoder { + public: + OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info); + + template + OrtxStatus Init(const DictT& attrs) { + // in API mode, the default value is 1 + downsample_rate_ = 16000; + stereo_mixer_ = 1; + for (const auto& [key, value] : attrs) { + if (key == "target_sample_rate") { + downsample_rate_ = std::get(value); + } else if (key == "stereo_to_mono") { + stereo_mixer_ = std::get(value); + } else { + return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Invalid argument"}; + } + } + + return {}; + } + + enum class AudioStreamType { kDefault = 0, kWAV, kMP3, kFLAC }; + + AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtxStatus& status) const; + OrtxStatus Compute(const ortc::Tensor& input, const std::optional format, + ortc::Tensor& output0) const; + OrtxStatus ComputeNoOpt(const ortc::Tensor& input, ortc::Tensor& output0) { + return Compute(input, std::nullopt, output0); + } + + private: + int64_t downsample_rate_{}; + int64_t stereo_mixer_{}; +}; diff --git a/operators/audio/audio_decoder.hpp b/operators/audio/audio_decoder.hpp deleted file mode 100644 index 06e61172e..000000000 --- a/operators/audio/audio_decoder.hpp +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -#pragma once - -#include "ocos.h" - -#include -#include -#include -#include -#define DR_FLAC_IMPLEMENTATION -#include "dr_flac.h" -#define DR_MP3_IMPLEMENTATION 1 -#define DR_MP3_FLOAT_OUTPUT 1 -#include "dr_mp3.h" -#define DR_WAV_IMPLEMENTATION -#include "dr_wav.h" - -#include -#include "narrow.h" -#include "string_utils.h" -#include "string_tensor.h" -#include "sampling.h" - -struct AudioDecoder{ - public: - - OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { - auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_); - if (!status) { - status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_); - } - - return status; - } - - enum class AudioStreamType { - kDefault = 0, - kWAV, - kMP3, - kFLAC - }; - - AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtStatusPtr& status) const { - static const std::map format_mapping = { - {"default", AudioStreamType::kDefault}, - {"wav", AudioStreamType::kWAV}, - {"mp3", AudioStreamType::kMP3}, - {"flac", AudioStreamType::kFLAC}}; - - AudioStreamType stream_format = AudioStreamType::kDefault; - if (str_format.length() > 0) { - auto pos = format_mapping.find(str_format); - if (pos == format_mapping.end()) { - status = OrtW::CreateStatus(MakeString( - "[AudioDecoder]: Unknown audio stream format: ", str_format) - .c_str(), - ORT_INVALID_ARGUMENT); - return stream_format; - } - stream_format = pos->second; - } - - if (stream_format == AudioStreamType::kDefault) { - auto p_stream = reinterpret_cast(p_data); - std::string_view marker(p_stream, 4); - if (marker == "fLaC") { - stream_format = AudioStreamType::kFLAC; - } else if (marker == "RIFF") { - stream_format = AudioStreamType::kWAV; - } else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) { - // http://www.mp3-tech.org/programmer/frame_header.html - // only detect the 8 + 3 bits sync word - stream_format = AudioStreamType::kMP3; - } else { - status = OrtW::CreateStatus("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT); - } - } - - return stream_format; - } - - template - static size_t 
DrReadFrames(std::list>& frames, FX_DECODER fx, TY_AUDIO& obj) { - const size_t default_chunk_size = 1024 * 256; - int64_t total_buf_size = 0; - - for (;;) { - std::vector buf; - buf.resize(default_chunk_size * obj.channels); - auto n_frames = fx(&obj, default_chunk_size, buf.data()); - if (n_frames <= 0) { - break; - } - auto data_size = n_frames * obj.channels; - total_buf_size += data_size; - buf.resize(data_size); - frames.emplace_back(std::move(buf)); - } - - return total_buf_size; - } - - OrtStatusPtr Compute(const ortc::Tensor& input, - const std::optional format, - ortc::Tensor& output0) const { - const uint8_t* p_data = input.Data(); - auto input_dim = input.Shape(); - OrtStatusPtr status = nullptr; - if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) { - status = OrtW::CreateStatus("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT); - return status; - } - - std::string str_format; - if (format) { - str_format = *format; - } - auto stream_format = ReadStreamFormat(p_data, str_format, status); - if (status) { - return status; - } - - int64_t total_buf_size = 0; - std::list> lst_frames; - int64_t orig_sample_rate = 0; - int64_t orig_channels = 0; - - if (stream_format == AudioStreamType::kMP3) { - auto mp3_obj_ptr = std::make_unique(); - if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) { - status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION); - return status; - } - orig_sample_rate = mp3_obj_ptr->sampleRate; - orig_channels = mp3_obj_ptr->channels; - total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr); - - } else if (stream_format == AudioStreamType::kFLAC) { - drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr); - auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); }); - if (flac_obj == nullptr) { - status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION); - return status; - } - orig_sample_rate = flac_obj->sampleRate; - orig_channels = flac_obj->channels; - total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj); - - } else { - drwav wav_obj; - if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) { - status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION); - return status; - } - orig_sample_rate = wav_obj.sampleRate; - orig_channels = wav_obj.channels; - total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj); - } - - if (downsample_rate_ != 0 && - orig_sample_rate < downsample_rate_) { - status = OrtW::CreateStatus("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT); - return status; - } - - // join all frames - std::vector buf; - buf.resize(total_buf_size); - int64_t offset = 0; - for (auto& _b : lst_frames) { - std::copy(_b.begin(), _b.end(), buf.begin() + offset); - offset += _b.size(); - } - - // mix the stereo channels into mono channel - if (stereo_mixer_ && orig_channels > 1) { - if (buf.size() > 1) { - for (size_t i = 0; i < buf.size() / 2; ++i) { - buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2; - } - buf.resize(buf.size() / 2); - } - } - - if (downsample_rate_ != 0 && - downsample_rate_ != orig_sample_rate) { - // A lowpass filter on buf audio data to remove high frequency noise - ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate); - std::vector 
filtered_buf = filter.Process(buf); - // downsample the audio data - KaiserWindowInterpolation::Process(filtered_buf, buf, - 1.0f * orig_sample_rate, 1.0f * downsample_rate_); - } - - std::vector dim_out = {1, ort_extensions::narrow(buf.size())}; - float* p_output = output0.Allocate(dim_out); - std::copy(buf.begin(), buf.end(), p_output); - return status; - } - - private: - int64_t downsample_rate_{}; - int64_t stereo_mixer_{}; -}; diff --git a/operators/cuda/attention_lib/flash_attention/flash.h b/operators/cuda/attention_lib/flash_attention/flash.h index 603a6e068..5f5be4078 100644 --- a/operators/cuda/attention_lib/flash_attention/flash.h +++ b/operators/cuda/attention_lib/flash_attention/flash.h @@ -87,6 +87,13 @@ struct Flash_fwd_params : public Qkv_params { // The indices to index into the KV cache. int* __restrict__ cache_batch_idx = nullptr; + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + float rp_dropout; + // Local window size int window_size_left = -1; int window_size_right = -1; @@ -102,6 +109,9 @@ struct Flash_fwd_params : public Qkv_params { int num_splits = 0; // For split-KV version + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + const cudaDeviceProp* dprops = nullptr; }; diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.cc b/operators/cuda/attention_lib/flash_attention/flash_api.cc index 46812b560..586a7a471 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_api.cc +++ b/operators/cuda/attention_lib/flash_attention/flash_api.cc @@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params, bool is_bf16, bool kv_bsnh = true, int window_size_left = -1, - int window_size_right = -1) { + int window_size_right = -1, + bool paged_KV = false, + int page_block_size = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -64,8 +66,8 @@ void set_params_fprop(Flash_fwd_params& params, if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) - params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) - params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0) + params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0) params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) } else { params.q_batch_stride = 0; @@ -99,6 +101,10 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + params.rp_dropout = 1.f; + params.alibi_slopes_ptr = nullptr; + params.alibi_slopes_batch_stride = 0; + // In our API, causal/unidirectional determines if we only look at prior tokens. 
However, the flash API seperates // local and causal, meaning when we have local window size params.is_causal = is_causal; @@ -349,8 +355,8 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size - void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size @@ -374,7 +380,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded int local_window_size, bool is_rotary_interleaved, - bool is_packed_qkv) { + bool is_packed_qkv, + int32_t* block_table, // batch_size x max_num_blocks_per_seq + int32_t max_num_blocks_per_seq, + int32_t page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -398,7 +407,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, is_bf16, past_bsnh, local_window_size, - is_causal ? 0 : -1); + is_causal ? 
0 : -1, + block_table != nullptr, + page_block_size); params.dprops = &dprops; if (k_new != nullptr && v_new != nullptr) { @@ -454,6 +465,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, params.oaccum_ptr = nullptr; } + params.block_table = block_table; + params.block_table_batch_stride = max_num_blocks_per_seq; + params.page_block_size = page_block_size; + // Only split kernel supports appending to KV cache run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.h b/operators/cuda/attention_lib/flash_attention/flash_api.h index 4ad1b76e1..07640d4c8 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_api.h +++ b/operators/cuda/attention_lib/flash_attention/flash_api.h @@ -53,8 +53,8 @@ OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table void* k, // batch_size x seqlen_k_new x num_heads_k x head_size void* v, // batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size @@ -78,7 +78,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded int local_window_size = -1, bool is_rotary_interleaved = false, - bool is_packed_qkv = false); + bool is_packed_qkv = false, + int32_t* block_table = nullptr, // batch_size x max_num_blocks_per_seq + int32_t max_num_blocks_per_seq = -1, + int32_t page_block_size = 1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h index c44a470f6..47263d411 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h @@ -28,1027 +28,1006 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, - Tensor2& acc_o, float softmax_scale_log2) { - if (Is_first) { - flash::template reduce_max(scores, scores_max); - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - flash::reduce_sum(scores, scores_sum); - } else { - cute::Tensor scores_max_prev = make_fragment_like(scores_max); - cute::copy(scores_max, scores_max_prev); - flash::template reduce_max(scores, scores_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); -#pragma unroll - for (int mi = 0; mi < 
cute::size(scores_max); ++mi) { - float scores_max_cur = !Check_inf - ? scores_max(mi) - : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scores_sum(mi) *= scores_scale; -#pragma unroll - for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scores_scale; - } - } - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); - flash::reduce_sum(scores, scores_sum_cur); -#pragma unroll - for (int mi = 0; mi < cute::size(scores_sum); ++mi) { - scores_sum(mi) += scores_sum_cur(mi); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void write_softmax_to_gmem( - cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { - // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) - cute::Layout l = tOrP.layout(); - cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); - CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); - CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); -#pragma unroll - for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { - cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; - - const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); - int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal || Is_local) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - // We exit early and write 0 to gO and gLSE. + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 
0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= n_block_min) { - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_tensor(shape(tOgO)); - clear(tOrO); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); -#pragma unroll - for (int m = 0; m < 
size<1>(tOgO); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSE(row) = INFINITY; + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } } - } - return; + return; } - } - - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - cute::Shape, cute::Int>{}, - make_stride(params.q_row_stride, _1{})); - cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - cute::Shape, cute::Int>{}, - make_stride(params.k_row_stride, _1{})); - cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - cute::Shape, cute::Int>{}, - make_stride(params.v_row_stride, _1{})); - cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), - cute::Shape, cute::Int>{}, - make_stride(params.seqlen_k_rounded, _1{})); - - cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; - cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : cute::size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); - cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; - auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); - - cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - - auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // TODO: this might need to change if we change the mma instruction in SM70 - cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); - cute::Tensor scores_sum = make_fragment_like(scores_max); - - // - // PREDICATES - // - - // Construct identity layout for sQ and sK - cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
+ + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = 
smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } - // Repeat the partitioning with identity layouts - cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - // Allocate predicate tensors for k - cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); - cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - // Set predicates for k bounds - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < cute::size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } -#pragma unroll - for (int k = 0; k < cute::size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } - } - - // Prologue - - cute::Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - if (Kernel_traits::Is_Q_in_regs) { - cute::cp_async_fence(); - } - - if (Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<0>(); - __syncthreads(); - cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - __syncthreads(); - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - - if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<1>(); - __syncthreads(); - cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - } - clear(acc_o); + // Prologue - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = (!Is_causal && !Is_local) - ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); -#pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - - // Advance gV - if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); } - cute::cp_async_fence(); - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - // if (cute::thread0()) { print(acc_s); } - - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. 
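The comment above explains why masking happens after `S = Q K^T` rather than before: rows of `sK` beyond `actual_seqlen_k` are never cleared, so the matmul can produce Inf/NaN there, and those score columns are simply overwritten with `-inf` before the softmax. A minimal single-row C++ sketch of that idea, with made-up sizes (`kBlockN = 8`, 5 valid columns); the kernel does the equivalent per MMA fragment.

```C++
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int kBlockN = 8;
    const int valid_cols = 5;                 // actual_seqlen_k - n_block * kBlockN
    std::vector<float> scores(kBlockN, 1.0f);
    scores[6] = NAN;                          // stand-in for an uncleared smem row

    // Overwrite out-of-range columns with -inf so they vanish in the softmax,
    // regardless of whatever garbage the matmul produced there.
    for (int c = 0; c < kBlockN; ++c) {
        if (c >= valid_cols) scores[c] = -INFINITY;
    }
    for (float s : scores) std::printf("%g ", s);   // 1 1 1 1 1 -inf -inf -inf
    std::printf("\n");
    return 0;
}
```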
- if (!Is_causal && !Is_local) { - if (!Is_even_MN) { - flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); - } - } else { - // I can't get the stride from idx_row - flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } - // TODO: when we have key_padding_mask we'll need to Check_inf - masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - // Convert scores from fp32 to fp16/bf16 - cute::Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - // if (Return_softmax) { - // cute::Tensor tOrP_copy = make_fragment_like(tOrP); - // copy(tOrP, tOrP_copy); - // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - // tPgP.data() = tPgP.data() + (-kBlockN); - // } - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= n_block_min) { - --n_block; - break; + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); +// if (Return_softmax) { +// Tensor rP_drop = make_fragment_like(rP); +// cute::copy(rP, rP_drop); +// dropout.template apply_dropout( +// rP_drop, block_row_idx, block_col_idx, kNWarps +// ); +// cute::copy(rP_drop, tSgS); +// tSgS.data() = tSgS.data() + (-kBlockN); +// } +// if (Is_dropout) { +// dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); +// } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } } - } - - // These are the iterations where we don't need masking on S - for (; n_block >= n_block_min; --n_block) { - cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); +// if (Return_softmax) { +// Tensor rP_drop = make_fragment_like(rP); +// cute::copy(rP, rP_drop); +// dropout.template apply_dropout( +// rP_drop, block_row_idx, block_col_idx, kNWarps +// ); +// cute::copy(rP_drop, tSgS); +// tSgS.data() = tSgS.data() + (-kBlockN); +// } +// if (Is_dropout) { +// dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); +// } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - cute::Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - // if (Return_softmax) { - // cute::Tensor tOrP_copy = make_fragment_like(tOrP); - // copy(tOrP, tOrP_copy); - // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - // tPgP.data() = tPgP.data() + (-kBlockN); - // } + // Epilogue - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - } + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); - // Epilogue + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - cute::Tensor lse = make_fragment_like(scores_sum); -#pragma unroll - for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; -#pragma unroll - for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; - } - } + // sO has the same size as sQ, so we don't need to sync here. 
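A minimal sketch of the per-row epilogue normalization that the removed loop above performed inline and that the new `softmax.normalize_softmax_lse(...)` call encapsulates: `lse = max * scale_softmax + log(sum)` and the accumulator row is scaled by `1 / sum`, with a guard for an empty or NaN row. Plain C++, a single row, made-up values; it mirrors the removed code rather than the exact internals of the new helper.

```C++
#include <cmath>
#include <cstdio>

int main() {
    // Example per-row softmax statistics and accumulator values (made up).
    float row_max = 2.0f, row_sum = 3.5f, scale_softmax = 0.125f;
    float acc_o[4] = {0.7f, 1.4f, 0.9f, 0.5f};

    const bool empty_row = (row_sum == 0.f) || (row_sum != row_sum);   // sum == 0 or NaN
    const float lse = empty_row ? INFINITY
                                : row_max * scale_softmax + std::log(row_sum);
    const float inv_sum = empty_row ? 1.f : 1.f / row_sum;
    for (float& v : acc_o) v *= inv_sum;   // rescale the output row by 1 / sum

    std::printf("lse=%f o0=%f\n", lse, acc_o[0]);
    return 0;
}
```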
+ if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - // Convert acc_o from fp32 to fp16/bf16 - cute::Tensor rO = flash::convert_type(acc_o); - cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sO has the same size as sQ, so we don't need to sync here. - if (Kernel_traits::Share_Q_K_smem) { - __syncthreads(); - } + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - cute::Shape, cute::Int>{}, - make_stride(params.o_row_stride, _1{})); - cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - cute::Shape>{}, cute::Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) - cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - - __syncthreads(); + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + __syncthreads(); - cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(cute::size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
- cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { -#pragma unroll - for (int mi = 0; mi < cute::size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { - gLSE(row) = lse(mi); - } + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } } - } - // Construct identity layout for sO - cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < cute::size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. 
- const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - - using GmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::GmemTiledCopyOaccum, - typename Kernel_traits::GmemTiledCopyO>; - using ElementO = std::conditional_t; - - const BlockInfo binfo(params, bidb); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } - // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = !Is_local - ? n_split_idx * n_blocks_per_split - : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); - int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal || Is_local) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - } - if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 - // We exit early and write 0 to gOaccum and -inf to gLSEaccum. - // Otherwise we might read OOB elements from gK and gV, - // or get wrong results when we combine gOaccum from different blocks. - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), - Shape>{}, Stride<_1>{}); - - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrOaccum); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; - } + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum + >; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); -#pragma unroll - for (int m = 0; m < size<1>(tOgOaccum); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSEaccum(row) = Split ? -INFINITY : INFINITY; - } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. 
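The early-exit comment above is worth unpacking: a split that owns no K/V blocks writes `O = 0` and `LSE = -inf` (the non-split path writes `+inf` instead, as the code below shows), so that when the per-split partial results are later combined with log-sum-exp weights, the empty split contributes exactly nothing. A small C++ sketch of that combine arithmetic, with one scalar output per split and made-up values; the actual combine kernel is not shown in this diff and operates per row and head dimension.

```C++
#include <cmath>
#include <cstdio>

int main() {
    // (lse_i, o_i) per split; split 2 had no K/V blocks assigned to it.
    float lse[3] = {1.2f, 0.3f, -INFINITY};
    float o[3]   = {0.8f, 0.5f, 0.0f};

    // Combine LSEs: lse_total = max + log(sum_i exp(lse_i - max)).
    float lse_max = std::fmax(std::fmax(lse[0], lse[1]), lse[2]);
    float sum = 0.f;
    for (int i = 0; i < 3; ++i) sum += std::exp(lse[i] - lse_max);
    const float lse_total = lse_max + std::log(sum);

    // Combine outputs with weights exp(lse_i - lse_total); the empty split's
    // weight is exp(-inf) == 0, so its zero output is harmless.
    float o_total = 0.f;
    for (int i = 0; i < 3; ++i) o_total += std::exp(lse[i] - lse_total) * o[i];

    std::printf("lse_total=%f o_total=%f\n", lse_total, o_total);
    return 0;
}
```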
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; } - return; - } - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const int bidb_cache = params.cache_batch_idx == nullptr ? 
bidb : params.cache_batch_idx[bidb]; - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - - auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // TODO: this might need to change if we change the mma instruction in SM70 - Tensor scores_max = make_tensor(Shape(acc_o)>>{}); - Tensor scores_sum = make_fragment_like(scores_max); - - // - // PREDICATES - // - - // // Allocate predicate tensors for m and n - // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); - // Tensor tKVpKV = 
make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); - - // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + 
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - // Set predicates for k bounds - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } -#pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } - } - // Prologue - // Copy from Knew to K, optionally apply rotary embedding. 
- typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; - auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; - auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); - if constexpr (Append_KV) { - // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to - // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. - // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); - Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. 
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); - Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); - // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } - // if (cute::thread(8, 0)) { print_tensor(gCos); } - // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. 
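The "line up" comment above (gK with 128 cached rows, gKnew with 64 appended rows) comes down to shifting gKnew's base pointer back by `seqlen_k_cache` rows, so that a global row index in the concatenated cache-plus-new layout maps onto the real knew buffer. A tiny index-arithmetic sketch of that shift in plain C++; the sizes and `knew_row_stride` value are illustrative only.

```C++
#include <cstdio>

int main() {
    const int seqlen_k_cache  = 128;  // rows already present in the KV cache (gK)
    const int seqlen_knew     = 64;   // rows being appended (gKnew)
    const int knew_row_stride = 256;  // elements per knew row; value is made up

    // gKnew is built from knew_ptr - seqlen_k_cache * knew_row_stride, so indexing
    // it at a "global" row r in [seqlen_k_cache, seqlen_k_cache + seqlen_knew)
    // lands on row r - seqlen_k_cache of the real knew buffer.
    for (int r = seqlen_k_cache; r < seqlen_k_cache + seqlen_knew; r += 32) {
        const long elem_offset =
            (long)r * knew_row_stride - (long)seqlen_k_cache * knew_row_stride;
        std::printf("global row %d -> knew row %ld\n", r, elem_offset / knew_row_stride);
    }
    return 0;
}
```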
- Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) - - const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); - for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { - flash::copy_w_min_idx( - tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - if (params.rotary_dim == 0) { - flash::copy_w_min_idx( - tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } else { - if (params.is_rotary_interleaved) { - // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_interleaved( - tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, - binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); - tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); - tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); - } else { - // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_contiguous( - tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, - binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); - tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); - tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * 
params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * 
params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } } - } - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; } - // Need this before we can read in K again, so that we'll see the updated K values. - __syncthreads(); - if (n_block_max > n_block_copy_min) { - tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; - tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; - } - } - // Read Q from gmem to smem, optionally apply rotary embedding. - Tensor tQrQ = make_fragment_like(tQgQ); - if (!Append_KV || params.rotary_dim == 0) { - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); - // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. - // We do this by setting the row stride of gCos / gSin to 0. - Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(Is_causal || Is_local ? 
params.rotary_dim / 2 : 0, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); - if (params.is_rotary_interleaved) { - flash::copy_rotary_interleaved( - tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, - 0, params.d, params.rotary_dim); - } else { - flash::copy_rotary_contiguous( - tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, - 0, params.d, params.rotary_dim); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } } - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - - // flash::cp_async_wait<0>(); - // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } - // __syncthreads(); - - clear(acc_o); - - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = (!Is_causal && !Is_local) - ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); -#pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); - } + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
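The added Q path above applies rotary position embedding while copying Q from gmem to smem, choosing between an interleaved pairing and a contiguous (NeoX-style) pairing. Below is a minimal scalar reference sketch of the two layouts, not the kernel's cute-based helpers; the function names and the `std::vector` interface are purely illustrative, assuming one row `x` of length `rotary_dim` and precomputed `cos`/`sin` tables of length `rotary_dim / 2`.

```C++
#include <cstddef>
#include <vector>

// Interleaved layout: rotate adjacent pairs (x[2i], x[2i+1]) by angle i.
void rotary_interleaved(std::vector<float>& x, const std::vector<float>& cos,
                        const std::vector<float>& sin, size_t rotary_dim) {
  for (size_t i = 0; i < rotary_dim / 2; ++i) {
    const float x0 = x[2 * i], x1 = x[2 * i + 1];
    x[2 * i]     = x0 * cos[i] - x1 * sin[i];
    x[2 * i + 1] = x0 * sin[i] + x1 * cos[i];
  }
}

// Contiguous (NeoX-style) layout: rotate pairs (x[i], x[i + rotary_dim/2]).
void rotary_contiguous(std::vector<float>& x, const std::vector<float>& cos,
                       const std::vector<float>& sin, size_t rotary_dim) {
  const size_t half = rotary_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float x0 = x[i], x1 = x[i + half];
    x[i]        = x0 * cos[i] - x1 * sin[i];
    x[i + half] = x0 * sin[i] + x1 * cos[i];
  }
}
```

In the kernel, `row_offset_cossin` starts at `seqlen_k_cache` (plus `m_block * kBlockM` for causal/local), which is why the non-causal case can set the cos/sin row stride to 0 and share a single cos/sin row across all queries.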
+ flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - // if (cute::thread0()) { print(acc_s); } - - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal && !Is_local) { - if (!Is_even_MN) { - flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); - } - } else { - flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } - - flash::cp_async_wait<0>(); - __syncthreads(); - // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } // __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - // We have key_padding_mask so we'll need to Check_inf - masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - - // Convert scores from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
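The gK/gV advances above handle the paged KV cache: when `block_table` is non-null, a logical key row is translated into a (physical block, in-block offset) pair and the pointer is moved by the delta between consecutive iterations. A small host-side sketch of that mapping follows; the function name and signature are hypothetical, but the arithmetic mirrors the `block_table_idx` / `block_table_offset` computation in the diff.

```C++
#include <cstdint>

// A logical row r of the paged K (or V) cache lives in physical block
// block_table[r / page_block_size], at in-block row r % page_block_size.
int64_t paged_kv_offset(const int* block_table, int logical_row, int page_block_size,
                        int64_t batch_stride, int64_t row_stride) {
  const int block_idx = logical_row / page_block_size;
  const int block_off = logical_row - block_idx * page_block_size;
  return int64_t(block_table[block_idx]) * batch_stride + int64_t(block_off) * row_stride;
}

// The kernel applies the difference between iteration n_block and n_block - 1:
//   ptr += paged_kv_offset(bt, (n_block - 1) * kBlockN, ...)
//        - paged_kv_offset(bt, n_block * kBlockN, ...);
```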
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // if (cute::thread0()) { print(scores); } - - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= n_block_min) { - --n_block; - break; - } - } + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - // These are the iterations where we don't need masking on S - for (; n_block >= n_block_min; --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
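The `softmax_rescale_o` calls above are the standard online-softmax recurrence: each new block of scores updates a running row maximum and row sum, and the partial output accumulator is rescaled so earlier contributions remain consistent. Per row, with running max `m`, running sum `l`, raw scores `s_j`, and partial output `O`, the update is (a sketch of the math only; the kernel works in base 2 via `exp2f` and `scale_softmax_log2`, which is numerically equivalent):

```latex
m_{\text{new}} = \max\bigl(m_{\text{old}}, \max_j s_j\bigr), \qquad
l_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}}\, l_{\text{old}} + \sum_j e^{s_j - m_{\text{new}}}, \qquad
O_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}}\, O_{\text{old}} + \sum_j e^{s_j - m_{\text{new}}} V_j
```

The cross-thread reduction of the row sum is deferred to `normalize_softmax_lse` at the end, which is why `reduce_sum` here only reduces within a thread.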
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - } + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - // Epilogue + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - // if (cute::thread0()) { print(acc_o_rowcol); } - Tensor lse = make_fragment_like(scores_sum); -#pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? 
-INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; -#pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } - } - // if (cute::thread0()) { print(lse); } - // if (cute::thread0()) { print(acc_o_rowcol); } - - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum>; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); - auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(acc_o); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sOaccum is larger than sQ, so we need to syncthreads here - // TODO: allocate enough smem for sOaccum - if constexpr (Split) { - __syncthreads(); - } - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), - Shape>{}, Stride<_1>{}); - // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } - __syncthreads(); + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + __syncthreads(); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
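`normalize_softmax_lse` above produces the per-row log-sum-exp consumed later by the split-KV combine step and applies the final 1/sum scaling to the accumulator. Per row, with running max `m` and reduced sum `l` (ignoring the dropout rescale, which this build compiles out):

```latex
\mathrm{lse} = m \cdot \text{scale\_softmax} + \log l, \qquad O \leftarrow O / l
```

Rows whose sum is 0 or NaN get `-INFINITY` when `Split` is true, so the combine kernel can recognize and skip empty splits, and `+INFINITY` otherwise.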
- Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { -#pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { - gLSEaccum(row) = lse(mi); - } + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } } - } - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); - // __syncthreads(); - // if (cute::thread0()) { print(tOgOaccum); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1064,12 +1043,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. 
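`compute_attn_splitkv` above (its index decoding continues in the next hunk) packs the query block, the KV split, and the flattened (batch, head) pair into the launch grid. A host-side sketch of that decomposition is below; it is an inference from the `bidh`/`n_split_idx` lines shown here, since the `bidb` line itself falls outside this hunk.

```C++
// Hypothetical view of the split-KV grid decoding.
// blockIdx.x -> query block (m_block)
// blockIdx.y -> KV split index when Split, else batch
// blockIdx.z -> flattened batch * num_heads when Split, else head
struct SplitKvIds { int m_block, bidb, bidh, n_split_idx, num_n_splits; };

SplitKvIds decode_split_kv(int bx, int by, int bz, int grid_y, int num_heads, bool split) {
  SplitKvIds ids{};
  ids.m_block = bx;
  ids.bidb = split ? bz / num_heads : by;          // inferred from bidh = bz - bidb * h
  ids.bidh = split ? bz - ids.bidb * num_heads : bz;
  ids.n_split_idx = split ? by : 0;
  ids.num_n_splits = split ? grid_y : 1;
  return ids;
}
```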
@@ -1078,7 +1057,7 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h index e2f2505a7..750305fd4 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h @@ -9,20 +9,20 @@ namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif @@ -38,7 +38,7 @@ __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #endif } -template +template void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { constexpr size_t smem_size = Kernel_traits::kSmemSize; @@ -53,23 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); }); @@ -90,16 +92,18 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. 
+ // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); @@ -143,7 +147,7 @@ template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); }); } @@ -154,7 +158,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); }); @@ -168,12 +172,12 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -192,12 +196,12 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. if (is_sm8x) { if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -220,12 +224,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { // and 128 x 64 with 8 warps is the fastest for non-causal. 
if (is_sm8x) { if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -241,7 +245,7 @@ template void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { constexpr int Headdim = 192; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd>(params, stream); @@ -257,9 +261,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -280,9 +284,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // 64 KB // run_flash_fwd, Is_causal>(params, stream); diff --git a/operators/cuda/attention_lib/flash_attention/softmax.h b/operators/cuda/attention_lib/flash_attention/softmax.h index 9c31336c9..a70406aed 100644 --- a/operators/cuda/attention_lib/flash_attention/softmax.h +++ b/operators/cuda/attention_lib/flash_attention/softmax.h @@ -54,10 +54,10 @@ __device__ inline void reduce_max(Tensor const& tensor, Tensor reduce_(tensor, max, max_op); } -template -__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { - SumOp sum_op; - reduce_(tensor, sum, sum_op); +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. 
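The launch wrappers above thread the new `Has_alibi` flag (alongside the Is_local / even-MN / even-K flags) from runtime parameters into template arguments through nested `BOOL_SWITCH` blocks. The following is a minimal sketch of that dispatch idiom using a hypothetical stand-in, not the repository's actual macro.

```C++
#include <type_traits>

// Turn a runtime bool into a compile-time constant so the lambda can use it
// as a template argument.
template <typename F>
void bool_switch(bool cond, F&& f) {
  if (cond) {
    f(std::integral_constant<bool, true>{});
  } else {
    f(std::integral_constant<bool, false>{});
  }
}

// Usage sketch (launch_kernel is hypothetical). Each nested switch doubles the
// number of kernel instantiations, which is why flags such as IsEvenMNConst are
// forced to false in some branches to limit compile time and binary size.
//
// bool_switch(params.alibi_slopes_ptr != nullptr, [&](auto has_alibi) {
//   launch_kernel<decltype(has_alibi)::value>(params, stream);
// });
```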
@@ -212,4 +212,168 @@ inline __device__ void apply_mask_causal_w_idx( } } +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +template +struct Mask { + + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope=0.f) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , alibi_slope(!Has_alibi ? 
0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + +}; + } // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/utils.h b/operators/cuda/attention_lib/flash_attention/utils.h index cd10bd534..f638a232a 100644 --- 
a/operators/cuda/attention_lib/flash_attention/utils.h +++ b/operators/cuda/attention_lib/flash_attention/utils.h @@ -198,6 +198,28 @@ inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) template inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { @@ -212,6 +234,25 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
template diff --git a/operators/math/dlib/stft_norm.hpp b/operators/math/dlib/stft_norm.hpp index 4a16b1d9c..9e8d59051 100644 --- a/operators/math/dlib/stft_norm.hpp +++ b/operators/math/dlib/stft_norm.hpp @@ -6,37 +6,32 @@ #include "ocos.h" #include -struct StftNormal{ +struct StftNormal { StftNormal() = default; OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { return OrtW::GetOpAttribute(info, "onesided", onesided_); } - OrtStatusPtr Compute(const ortc::Tensor& input0, - int64_t n_fft, - int64_t hop_length, - const ortc::Span& input3, - int64_t frame_length, - ortc::Tensor& output0) const { + OrtxStatus Compute(const ortc::Tensor& input0, int64_t n_fft, int64_t hop_length, + const ortc::Span& input3, int64_t frame_length, ortc::Tensor& output0) const { auto X = input0.Data(); auto window = input3.data_; auto dimensions = input0.Shape(); auto win_length = input3.size(); if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) { - return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT); + return {kOrtxErrorInvalidArgument, "[Stft] Only batch == 1 tensor supported."}; } if (frame_length != n_fft) { - return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT); + return {kOrtxErrorInvalidArgument, "[Stft] Only support size of FFT equals the frame length."}; } dlib::matrix dm_x = dlib::mat(X, 1, dimensions[1]); dlib::matrix hann_win = dlib::mat(window, 1, win_length); - auto m_stft = dlib::stft( - dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, - n_fft, win_length, hop_length); + auto m_stft = + dlib::stft(dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, n_fft, win_length, hop_length); if (onesided_) { m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1); @@ -49,7 +44,7 @@ struct StftNormal{ auto out0 = output0.Allocate(outdim); memcpy(out0, result.steal_memory().get(), result_size * sizeof(float)); - return nullptr; + return {}; } private: diff --git a/shared/api/c_api_feature_extraction.cc b/shared/api/c_api_feature_extraction.cc new file mode 100644 index 000000000..8ffde2455 --- /dev/null +++ b/shared/api/c_api_feature_extraction.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "speech_extractor.h" + +#include "c_api_utils.hpp" + +using namespace ort_extensions; + +class RawAudiosObject : public OrtxObjectImpl { + public: + RawAudiosObject() : OrtxObjectImpl(extObjectKind_t::kOrtxKindRawAudios) {} + ~RawAudiosObject() override = default; + + std::unique_ptr audios_; + size_t num_audios_; +}; + +extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** raw_audios, const char* const* audio_paths, size_t num_audios) { + if (raw_audios == nullptr || audio_paths == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto audios_obj = std::make_unique(); + auto [audios, num] = + ort_extensions::LoadRawData(audio_paths, audio_paths + num_audios); + audios_obj->audios_ = std::move(audios); + audios_obj->num_audios_ = num; + + *raw_audios = static_cast(audios_obj.release()); + return extError_t(); +} + +extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* def) { + if (extractor == nullptr || def == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto extractor_ptr = std::make_unique(); + ReturnableStatus status = extractor_ptr->Init(def); + if (status.IsOk()) { + *extractor = static_cast(extractor_ptr.release()); + } else { + *extractor = nullptr; + } + + return status.Code(); +} + +extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* raw_audios, + OrtxTensorResult** result) { + if (extractor == nullptr || raw_audios == nullptr || result == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto extractor_ptr = static_cast(extractor); + auto audios_obj = static_cast(raw_audios); + + auto ts_result = std::make_unique(); + std::unique_ptr> log_mel[1]; + ReturnableStatus status = + extractor_ptr->DoCall(ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), log_mel[0]); + if (status.IsOk()) { + std::vector> tensors; + std::transform(log_mel, log_mel + 1, std::back_inserter(tensors), + [](auto& ts) { return std::unique_ptr(ts.release()); }); + ts_result->SetTensors(std::move(tensors)); + *result = ts_result.release(); + } else { + *result = nullptr; + } + + return status.Code(); +} diff --git a/shared/api/c_api_processor.cc b/shared/api/c_api_processor.cc index 2beb90a13..8e2e12598 100644 --- a/shared/api/c_api_processor.cc +++ b/shared/api/c_api_processor.cc @@ -4,6 +4,8 @@ #include "ortx_processor.h" #include "image_processor.h" +#include "c_api_utils.hpp" + using namespace ort_extensions; extError_t OrtxCreateProcessor(OrtxProcessor** processor, const char* def) { @@ -37,19 +39,19 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima } auto images_obj = std::make_unique(); - auto [img, num] = LoadRawImages(image_paths, image_paths + num_images); + auto [img, num] = LoadRawData(image_paths, image_paths + num_images); images_obj->images = std::move(img); images_obj->num_images = num; if (num_images_loaded != nullptr) { *num_images_loaded = num; } - + *images = static_cast(images_obj.release()); return extError_t(); } extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images, - OrtxImageProcessorResult** result) { + OrtxTensorResult** result) { if (processor == nullptr || images == nullptr || result == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return 
kOrtxErrorInvalidArgument; @@ -67,59 +69,14 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm return status.Code(); } - auto result_ptr = std::make_unique(); + auto result_ptr = std::make_unique(); status = processor_ptr->PreProcess(ort_extensions::span(images_ptr->images.get(), images_ptr->num_images), *result_ptr); if (status.IsOk()) { - *result = static_cast(result_ptr.release()); + *result = static_cast(result_ptr.release()); } else { *result = nullptr; } return {}; } - -extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor) { - if (result == nullptr || tensor == nullptr) { - ReturnableStatus::last_error_message_ = "Invalid argument"; - return kOrtxErrorInvalidArgument; - } - - auto result_ptr = static_cast(result); - ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult)); - if (!status.IsOk()) { - return status.Code(); - } - - if (index >= result_ptr->results.size()) { - ReturnableStatus::last_error_message_ = "Index out of range"; - return kOrtxErrorInvalidArgument; - } - - auto tensor_ptr = std::make_unique>(); - tensor_ptr->SetObject(result_ptr->results[index].get()); - *tensor = static_cast(tensor_ptr.release()); - return extError_t(); -} - -extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result) { - if (processor == nullptr || result == nullptr) { - ReturnableStatus::last_error_message_ = "Invalid argument"; - return kOrtxErrorInvalidArgument; - } - - const auto processor_ptr = static_cast(processor); - ReturnableStatus status(processor_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindProcessor)); - if (!status.IsOk()) { - return status.Code(); - } - - auto result_ptr = static_cast(result); - status = result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult); - if (!status.IsOk()) { - return status.Code(); - } - - ImageProcessor::ClearOutputs(result_ptr); - return extError_t(); -} diff --git a/shared/api/c_api_tokenizer.cc b/shared/api/c_api_tokenizer.cc index 22c24defc..e3d2fd8de 100644 --- a/shared/api/c_api_tokenizer.cc +++ b/shared/api/c_api_tokenizer.cc @@ -6,7 +6,7 @@ #include "c_api_utils.hpp" #include "tokenizer_impl.h" -namespace ort_extensions { +using namespace ort_extensions; class DetokenizerCache : public OrtxObjectImpl { public: @@ -17,29 +17,20 @@ class DetokenizerCache : public OrtxObjectImpl { std::string last_text_{}; // last detokenized text }; -template<> -OrtxObject* OrtxObjectFactory::CreateForward() { - return std::make_unique().release(); -} - -template<> -void OrtxObjectFactory::DisposeForward(OrtxObject* obj) { - Dispose(obj); +template <> +OrtxObject* OrtxObjectFactory::CreateForward() { + return Create(); } -} // namespace ort_extensions - -using namespace ort_extensions; -extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, - const char* input[], size_t batch_size, OrtxTokenId2DArray** output) { +extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, + OrtxTokenId2DArray** output) { if (tokenizer == nullptr || input == nullptr || output == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; } auto token_ptr = static_cast(tokenizer); - ReturnableStatus status = - token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer); + ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer); if (!status.IsOk()) { 
return status.Code(); } @@ -61,8 +52,8 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, return extError_t(); } -extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, - const OrtxTokenId2DArray* input, OrtxStringArray** output) { +extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input, + OrtxStringArray** output) { if (tokenizer == nullptr || input == nullptr || output == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; @@ -81,11 +72,8 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, } std::vector> t_ids; - std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), - std::back_inserter(t_ids), - [](const std::vector& vec) { - return span(vec.data(), vec.size()); - }); + std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), std::back_inserter(t_ids), + [](const std::vector& vec) { return span(vec.data(), vec.size()); }); std::vector output_text; status = token_ptr->Detokenize(t_ids, output_text); @@ -101,9 +89,7 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, ; } -extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, - const extTokenId_t* input, - size_t len, +extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len, OrtxStringArray** output) { if (tokenizer == nullptr || input == nullptr || output == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; @@ -186,8 +172,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to return extError_t(); } -extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, - size_t index, const extTokenId_t** item, size_t* length) { +extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, size_t index, + const extTokenId_t** item, size_t* length) { if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; @@ -210,9 +196,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* tok return extError_t(); } -extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, - OrtxDetokenizerCache* cache, - extTokenId_t next_id, const char** text_out) { +extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id, + const char** text_out) { if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; diff --git a/shared/api/c_api_utils.cc b/shared/api/c_api_utils.cc index 0345fdb23..3fc376efe 100644 --- a/shared/api/c_api_utils.cc +++ b/shared/api/c_api_utils.cc @@ -10,6 +10,8 @@ using namespace ort_extensions; +class DetokenizerCache; // forward definition in tokenizer_impl.cc + thread_local std::string ReturnableStatus::last_error_message_; OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const { @@ -37,7 +39,7 @@ extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, . 
va_start(args, object); if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) { - *object = OrtxObjectFactory::CreateForward(); + *object = OrtxObjectFactory::CreateForward(); } else if (kind == extObjectKind_t::kOrtxKindTokenizer) { return OrtxCreateTokenizer(static_cast(object), va_arg(args, const char*)); } @@ -80,8 +82,8 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) { return kOrtxErrorInvalidArgument; } - if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) { - OrtxObjectFactory::Dispose(object); + /* if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) { + OrtxObjectFactory::Dispose(object); } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) { OrtxObjectFactory::Dispose(object); } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) { @@ -94,6 +96,11 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) { OrtxObjectFactory::Dispose(object); } else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindProcessor) { OrtxObjectFactory::Dispose(object); + } */ + if (Ortx_object->ortx_kind() >= kOrtxKindBegin && Ortx_object->ortx_kind() < kOrtxKindEnd) { + OrtxObjectFactory::Dispose(object); + } else { + return kOrtxErrorInvalidArgument; } return extError_t(); @@ -113,6 +120,30 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) { return err; } +extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor) { + if (result == nullptr || tensor == nullptr) { + ReturnableStatus::last_error_message_ = "Invalid argument"; + return kOrtxErrorInvalidArgument; + } + + auto result_ptr = static_cast(result); + ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTensorResult)); + if (!status.IsOk()) { + return status.Code(); + } + + ortc::TensorBase* ts = result_ptr->GetAt(index); + if (ts == nullptr) { + ReturnableStatus::last_error_message_ = "Cannot get the tensor at the specified index from the result"; + return kOrtxErrorInvalidArgument; + } + + auto tensor_ptr = std::make_unique>(); + tensor_ptr->SetObject(ts); + *tensor = static_cast(tensor_ptr.release()); + return extError_t(); +} + extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape, size_t* num_dims) { if (tensor == nullptr) { @@ -120,7 +151,7 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data return kOrtxErrorInvalidArgument; } - auto tensor_impl = static_cast*>(tensor); + auto tensor_impl = static_cast*>(tensor); if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) { ReturnableStatus::last_error_message_ = "Invalid argument"; return kOrtxErrorInvalidArgument; diff --git a/shared/api/c_api_utils.hpp b/shared/api/c_api_utils.hpp index d7794b610..c1fed4727 100644 --- a/shared/api/c_api_utils.hpp +++ b/shared/api/c_api_utils.hpp @@ -3,8 +3,10 @@ #pragma once #include +#include #include "ortx_utils.h" +#include "file_sys.h" #include "ext_status.h" #include "op_def_struct.h" @@ -12,7 +14,7 @@ namespace ort_extensions { class OrtxObjectImpl : public OrtxObject { public: explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() { - ext_kind_ = static_cast(kind); + ext_kind_ = kind; }; virtual ~OrtxObjectImpl() = default; @@ -24,30 +26,21 @@ class OrtxObjectImpl : public OrtxObject { } return static_cast(ext_kind_); } - - template - struct Type2Kind { - static const extObjectKind_t value = kOrtxKindUnknown; - }; 
-}; - -template <> -struct OrtxObjectImpl::Type2Kind { - static const extObjectKind_t value = kOrtxKindTensor; }; -template +// A wrapper class to store an object pointer that is read-only, i.e. unowned. +template class OrtxObjectWrapper : public OrtxObjectImpl { public: - OrtxObjectWrapper() : OrtxObjectImpl(OrtxObjectImpl::Type2Kind::value) {} + OrtxObjectWrapper() : OrtxObjectImpl(kind) {} ~OrtxObjectWrapper() override = default; - void SetObject(T* t) { stored_object_ = t; } + void SetObject(const T* t) { stored_object_ = t; } - [[nodiscard]] T* GetObject() const { return stored_object_; } + [[nodiscard]] const T* GetObject() const { return stored_object_; } private: - T* stored_object_{}; + const T* stored_object_{}; }; template @@ -100,6 +93,35 @@ class StringArray : public OrtxObjectImpl { std::vector strings_; }; +class TensorResult : public OrtxObjectImpl { + public: + TensorResult() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensorResult) {} + ~TensorResult() override = default; + + void SetTensors(std::vector>&& tensors) { tensors_ = std::move(tensors); } + + [[nodiscard]] const std::vector>& tensors() const { return tensors_; } + + [[nodiscard]] std::vector GetTensors() const { + std::vector ts; + ts.reserve(tensors_.size()); + for (auto& t : tensors_) { + ts.push_back(t.get()); + } + return ts; + } + + ortc::TensorBase* GetAt(size_t i) const { + if (i < tensors_.size()) { + return tensors_[i].get(); + } + return nullptr; + } + + private: + std::vector> tensors_; +}; + struct ReturnableStatus { public: thread_local static std::string last_error_message_; @@ -123,24 +145,25 @@ struct ReturnableStatus { OrtxStatus status_; }; -template class OrtxObjectFactory { public: - static std::unique_ptr Create() { return std::make_unique(); } - - static OrtxObject* CreateForward(); - static void DisposeForward(OrtxObject* object); + template + static OrtxObject* Create() { + return std::make_unique().release(); + } + template static void Dispose(OrtxObject* object) { auto obj_ptr = static_cast(object); std::unique_ptr ptr(obj_ptr); ptr.reset(); } -}; - -class DetokenizerCache; // forward definition in tokenizer_impl.cc -class ProcessorResult; // forward definition in image_processor.h + // Forward declaration for creating an object whose type isn't visible to c_api_utils.cc; + // the specialization is defined in the corresponding .cc file.
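+ // For example, c_api_tokenizer.cc provides the CreateForward specialization for its DetokenizerCache type.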
+ template + static OrtxObject* CreateForward(); +}; class CppAllocator : public ortc::IAllocator { public: @@ -157,4 +180,25 @@ class CppAllocator : public ortc::IAllocator { } }; +template +std::tuple, size_t> LoadRawData(It begin, It end) { + auto raw_data = std::make_unique(end - begin); + size_t n = 0; + for (auto it = begin; it != end; ++it) { + std::ifstream ifs = path(*it).open(std::ios::binary | std::ios::in); + if (!ifs.is_open()) { + break; + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + + T& datum = raw_data[n++]; + datum.resize(size); + ifs.read(reinterpret_cast(datum.data()), size); + } + + return std::make_tuple(std::move(raw_data), n); +} } // namespace ort_extensions diff --git a/shared/api/image_processor.cc b/shared/api/image_processor.cc index 1cbab6e10..5383c10fe 100644 --- a/shared/api/image_processor.cc +++ b/shared/api/image_processor.cc @@ -7,6 +7,7 @@ #include "file_sys.h" #include "image_processor.h" +#include "c_api_utils.hpp" #include "cv2/imgcodecs/imdecode.hpp" #include "image_transforms.hpp" #include "image_transforms_phi_3.hpp" @@ -14,38 +15,11 @@ using namespace ort_extensions; using json = nlohmann::json; -namespace ort_extensions { -template -std::tuple, size_t> LoadRawImages(It begin, It end) { - auto raw_images = std::make_unique(end - begin); - size_t n = 0; - for (auto it = begin; it != end; ++it) { - std::ifstream ifs = path(*it).open(std::ios::binary); - if (!ifs.is_open()) { - break; - } - - ifs.seekg(0, std::ios::end); - size_t size = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - - ImageRawData& raw_image = raw_images[n++]; - raw_image.resize(size); - ifs.read(reinterpret_cast(raw_image.data()), size); - } - - return std::make_tuple(std::move(raw_images), n); -} - -std::tuple, size_t> LoadRawImages( - const std::initializer_list& image_paths) { - return LoadRawImages(image_paths.begin(), image_paths.end()); +std::tuple, size_t> +ort_extensions::LoadRawImages(const std::initializer_list& image_paths) { + return ort_extensions::LoadRawData(image_paths.begin(), image_paths.end()); } -template std::tuple, size_t> LoadRawImages(char const**, char const**); - -} // namespace ort_extensions - Operation::KernelRegistry ImageProcessor::kernel_registry_ = { {"DecodeImage", []() { return CreateKernelInstance(image_decoder); }}, {"Resize", []() { return CreateKernelInstance(&Resize::Compute); }}, @@ -97,9 +71,7 @@ OrtxStatus ImageProcessor::Init(std::string_view processor_def) { return {}; } -ImageProcessor::ImageProcessor() - : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) { -} +ImageProcessor::ImageProcessor() : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {} template static ortc::Tensor* StackTensor(const std::vector& arg_lists, int axis, ortc::IAllocator* allocator) { @@ -136,39 +108,6 @@ static ortc::Tensor* StackTensor(const std::vector& arg_lists, in return output.release(); } -static OrtxStatus StackTensors(const std::vector& arg_lists, std::vector& outputs, - ortc::IAllocator* allocator) { - if (arg_lists.empty()) { - return {}; - } - - size_t batch_size = arg_lists.size(); - size_t num_outputs = arg_lists[0].size(); - for (size_t axis = 0; axis < num_outputs; ++axis) { - std::vector ts_ptrs; - ts_ptrs.reserve(arg_lists.size()); - std::vector shape = arg_lists[0][axis]->Shape(); - for (auto& ts : arg_lists) { - if (shape != ts[axis]->Shape()) { - return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the 
same."}; - } - ts_ptrs.push_back(ts[axis]); - } - - std::vector output_shape = shape; - output_shape.insert(output_shape.begin(), batch_size); - std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape); - for (size_t i = 0; i < batch_size; ++i) { - auto ts = ts_ptrs[i]; - const std::byte* ts_buff = reinterpret_cast(ts->DataRaw()); - auto ts_size = ts->SizeInBytes(); - std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size); - } - } - - return {}; -} - std::tuple ImageProcessor::PreProcess(ort_extensions::span image_data, ortc::Tensor** pixel_values, ortc::Tensor** image_sizes, @@ -209,7 +148,7 @@ std::tuple ImageProcessor::PreProcess(ort_extension return {status, std::move(r)}; } -OrtxStatus ImageProcessor::PreProcess(ort_extensions::span image_data, ImageProcessorResult& r) const { +OrtxStatus ImageProcessor::PreProcess(ort_extensions::span image_data, TensorResult& r) const { std::vector inputs; inputs.resize(image_data.size()); for (size_t i = 0; i < image_data.size(); ++i) { @@ -235,9 +174,13 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span image_d } } - r.results = operations_.back()->AllocateOutputs(allocator_); - status = StackTensors(outputs, r.results, allocator_); + auto img_result = operations_.back()->AllocateOutputs(allocator_); + status = OrtxRunner::StackTensors(outputs, img_result, allocator_); operations_.back()->ResetTensors(allocator_); + if (status.IsOk()) { + r.SetTensors(std::move(img_result)); + } + return status; } @@ -257,14 +200,3 @@ void ImageProcessor::ClearOutputs(ProcessorResult* r) { r->num_img_takens = nullptr; } } - -void ort_extensions::ImageProcessor::ClearOutputs(ImageProcessorResult* r) { - if (r == nullptr) { - return; - } - - for (auto& ts : r->results) { - ts.reset(); - } - r->results.clear(); // clear the vector -} diff --git a/shared/api/image_processor.h b/shared/api/image_processor.h index 534e811d6..02eee4e32 100644 --- a/shared/api/image_processor.h +++ b/shared/api/image_processor.h @@ -16,9 +16,6 @@ namespace ort_extensions { using ImageRawData = std::vector; -template -std::tuple, size_t> LoadRawImages(It begin, It end); - std::tuple, size_t> LoadRawImages( const std::initializer_list& image_paths); @@ -29,13 +26,6 @@ class ProcessorResult : public OrtxObjectImpl { ortc::Tensor* image_sizes{}; ortc::Tensor* num_img_takens{}; }; - -class ImageProcessorResult : public OrtxObjectImpl { - public: - ImageProcessorResult() : OrtxObjectImpl(kOrtxKindImageProcessorResult) {} - std::vector results; -}; - class ImageProcessor : public OrtxObjectImpl { public: ImageProcessor(); @@ -43,15 +33,16 @@ class ImageProcessor : public OrtxObjectImpl { OrtxStatus Init(std::string_view processor_def); + // Deprecated, using the next function instead std::tuple PreProcess(ort_extensions::span image_data, ortc::Tensor** pixel_values, ortc::Tensor** image_sizes, ortc::Tensor** num_img_takens) const; - OrtxStatus PreProcess(ort_extensions::span image_data, ImageProcessorResult& r) const; + OrtxStatus PreProcess(ort_extensions::span image_data, TensorResult& r) const; + // Deprecated, using the next function instead static void ClearOutputs(ProcessorResult* r); - static void ClearOutputs(ImageProcessorResult* r); static Operation::KernelRegistry kernel_registry_; diff --git a/shared/api/runner.hpp b/shared/api/runner.hpp index 3590190bb..1b6a01ddf 100644 --- a/shared/api/runner.hpp +++ b/shared/api/runner.hpp @@ -28,7 +28,8 @@ class KernelDef { virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0; virtual OrtxStatus 
Apply(TensorArgs& inputs, TensorArgs& output) const = 0; - using AttrType = std::variant>; + using AttrType = + std::variant, std::vector, std::vector>; using AttrDict = std::unordered_map; template @@ -98,7 +99,7 @@ class KernelDef { template class KernelFunction : public KernelDef { public: - KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){}; + KernelFunction(OrtxStatus (*body)(Args...)) : body_(body) {}; virtual ~KernelFunction() = default; TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override { @@ -132,7 +133,7 @@ class KernelFunction : public KernelDef { template class KernelStruct : public KernelDef { public: - KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){}; + KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body) {}; virtual ~KernelStruct() = default; TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override { @@ -167,8 +168,18 @@ class KernelStruct : public KernelDef { attr_dict[key] = value.template get(); } else if (value.is_number_float()) { attr_dict[key] = value.template get(); - } else if (value.is_array()) { - attr_dict[key] = value.template get>(); + } else if (value.is_array() && value.size() > 0) { + auto& elem_0 = value.at(0); + if (elem_0.is_number_float()) { + attr_dict[key] = value.template get>(); + } else if (elem_0.is_string()) { + attr_dict[key] = value.template get>(); + } else if (elem_0.is_number_integer() || elem_0.is_number_unsigned()) { + attr_dict[key] = value.template get>(); + } else { + return {kOrtxErrorCorruptData, "Unsupported mix types in attribute value."}; + } + } else { return {kOrtxErrorCorruptData, "Invalid attribute type."}; } @@ -309,6 +320,39 @@ class OrtxRunner { return {}; } + static OrtxStatus StackTensors(const std::vector& arg_lists, std::vector& outputs, + ortc::IAllocator* allocator) { + if (arg_lists.empty()) { + return {}; + } + + size_t batch_size = arg_lists.size(); + size_t num_outputs = arg_lists[0].size(); + for (size_t axis = 0; axis < num_outputs; ++axis) { + std::vector ts_ptrs; + ts_ptrs.reserve(arg_lists.size()); + std::vector shape = arg_lists[0][axis]->Shape(); + for (auto& ts : arg_lists) { + if (shape != ts[axis]->Shape()) { + return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."}; + } + ts_ptrs.push_back(ts[axis]); + } + + std::vector output_shape = shape; + output_shape.insert(output_shape.begin(), batch_size); + std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape); + for (size_t i = 0; i < batch_size; ++i) { + auto ts = ts_ptrs[i]; + const std::byte* ts_buff = reinterpret_cast(ts->DataRaw()); + auto ts_size = ts->SizeInBytes(); + std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size); + } + } + + return {}; + } + private: ortc::IAllocator* allocator_; std::vector ops_; diff --git a/shared/api/speech_extractor.cc b/shared/api/speech_extractor.cc new file mode 100644 index 000000000..5cd005f8a --- /dev/null +++ b/shared/api/speech_extractor.cc @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
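+// SpeechFeatureExtractor runs the kernel sequence declared in a JSON config (AudioDecoder -> STFTNorm -> LogMelSpectrum)
+// over a span of raw audio buffers and stacks the per-audio outputs into a single batched tensor.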
+ +#include "speech_extractor.h" + +#include "audio/audio_decoder.h" +#include "speech_features.hpp" + +using namespace ort_extensions; + +Operation::KernelRegistry SpeechFeatureExtractor::kernel_registry_ = { + {"AudioDecoder", []() { return CreateKernelInstance(&AudioDecoder::ComputeNoOpt); }}, + {"STFTNorm", []() { return CreateKernelInstance(&SpeechFeatures::STFTNorm); }}, + {"LogMelSpectrum", []() { return CreateKernelInstance(&LogMel::Compute); }}, +}; + +SpeechFeatureExtractor::SpeechFeatureExtractor() + : OrtxObjectImpl(extObjectKind_t::kOrtxKindFeatureExtractor), allocator_(&CppAllocator::Instance()) {} + +OrtxStatus SpeechFeatureExtractor::Init(std::string_view extractor_def) { + std::string fe_def_str; + if (extractor_def.size() >= 5 && extractor_def.substr(extractor_def.size() - 5) == ".json") { + std::ifstream ifs = path({extractor_def.data(), extractor_def.size()}).open(); + if (!ifs.is_open()) { + return {kOrtxErrorInvalidArgument, std::string("[SpeechFeatureExtractor]: failed to open ") + std::string(extractor_def)}; + } + fe_def_str = std::string(std::istreambuf_iterator(ifs), std::istreambuf_iterator()); + extractor_def = fe_def_str.c_str(); + } + + // parse the extractor_def as json + auto fe_json = json::parse(extractor_def, nullptr, false); + if (fe_json.is_discarded()) { + return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: failed to parse extractor json configuration."}; + } + + auto fe_root = fe_json.at("feature_extraction"); + if (!fe_root.is_object()) { + return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: feature_extraction field is missing."}; + } + + auto op_sequence = fe_root.at("sequence"); + if (!op_sequence.is_array() || op_sequence.empty()) { + return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: sequence field is missing."}; + } + + operations_.reserve(op_sequence.size()); + for (auto mod_iter = op_sequence.begin(); mod_iter != op_sequence.end(); ++mod_iter) { + auto op = std::make_unique(kernel_registry_); + auto status = op->Init(mod_iter->dump()); + if (!status.IsOk()) { + return status; + } + + operations_.push_back(std::move(op)); + } + + return {}; +} + +OrtxStatus SpeechFeatureExtractor::DoCall(ort_extensions::span raw_speech, + std::unique_ptr>& log_mel) const { + // setup the input tensors + std::vector inputs; + inputs.resize(raw_speech.size()); + for (size_t i = 0; i < raw_speech.size(); ++i) { + auto& ts_input = inputs[i]; + AudioRawData& speech = raw_speech[i]; + std::vector shape = {static_cast(speech.size())}; + ts_input.push_back(std::make_unique>(shape, speech.data()).release()); + } + + std::vector outputs; + std::vector ops(operations_.size()); + std::transform(operations_.begin(), operations_.end(), ops.begin(), [](auto& op) { return op.get(); }); + OrtxRunner runner(allocator_, ops.data(), ops.size()); + auto status = runner.Run(inputs, outputs); + if (!status.IsOk()) { + return status; + } + + // clear the input tensors + for (auto& input : inputs) { + for (auto& ts : input) { + std::unique_ptr(ts).reset(); + } + } + + auto results = operations_.back()->AllocateOutputs(allocator_); + status = OrtxRunner::StackTensors(outputs, results, allocator_); + if (status.IsOk()) { + log_mel.reset(static_cast*>(results[0].release())); + operations_.back()->ResetTensors(allocator_); + } + + return status; +} diff --git a/shared/api/speech_extractor.h b/shared/api/speech_extractor.h new file mode 100644 index 000000000..3219da6eb --- /dev/null +++ b/shared/api/speech_extractor.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft 
Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "ortx_extractor.h" +#include "c_api_utils.hpp" +#include "runner.hpp" + + +namespace ort_extensions { + +typedef std::vector AudioRawData; + +class SpeechFeatureExtractor : public OrtxObjectImpl { + public: + SpeechFeatureExtractor(); + + virtual ~SpeechFeatureExtractor() = default; + + public: + OrtxStatus Init(std::string_view extractor_def); + + OrtxStatus DoCall(ort_extensions::span raw_speech, std::unique_ptr>& log_mel) const; + + static Operation::KernelRegistry kernel_registry_; + + private: + std::vector> operations_; + ortc::IAllocator* allocator_; +}; + +} // namespace ort_extensions diff --git a/shared/api/speech_features.hpp b/shared/api/speech_features.hpp new file mode 100644 index 000000000..acc368a12 --- /dev/null +++ b/shared/api/speech_features.hpp @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +namespace ort_extensions { + +class SpeechFeatures { + public: + template + OrtxStatus Init(const DictT& attrs) { + for (const auto& [key, value] : attrs) { + if (key == "n_fft") { + n_fft_ = std::get(value); + } else if (key == "hop_length") { + hop_length_ = std::get(value); + } else if (key == "frame_length") { + frame_length_ = std::get(value); + } else if (key == "hann_win") { + auto& win = std::get>(value); + hann_win_.resize(win.size()); + std::transform(win.begin(), win.end(), hann_win_.begin(), [](double x) { return static_cast(x); }); + } else if (key != "_comment") { + return {kOrtxErrorInvalidArgument, "[AudioFeatures]: Invalid key in the JSON configuration."}; + } + } + + if (hann_win_.empty()) { + hann_win_ = hann_window(frame_length_); + } + return {}; + } + + OrtxStatus STFTNorm(const ortc::Tensor& pcm, ortc::Tensor& stft_norm) { + return stft_norm_.Compute(pcm, n_fft_, hop_length_, {hann_win_.data(), hann_win_.size()}, frame_length_, stft_norm); + } + + static std::vector hann_window(int N) { + std::vector window(N); + + for (int n = 0; n < N; ++n) { + // this formula leads to more rounding errors than the one below + // window[n] = static_cast(0.5 * (1 - std::cos(2 * M_PI * n / (N - 1)))); + double n_sin = std::sin(M_PI * n / N); + window[n] = static_cast(n_sin * n_sin); + } + + return window; + } + + private: + StftNormal stft_norm_; + int64_t n_fft_{}; + int64_t hop_length_{}; + int64_t frame_length_{}; + std::vector hann_win_; +}; + +class LogMel { + public: + template + OrtxStatus Init(const DictT& attrs) { + int n_fft = 0; + int n_mel = 0; + int chunk_size = 0; + for (const auto& [key, value] : attrs) { + if (key == "hop_length") { + hop_length_ = std::get(value); + } else if (key == "n_fft") { + n_fft = std::get(value); + } else if (key == "n_mel") { + n_mel = std::get(value); + } else if (key == "chunk_size") { + chunk_size = std::get(value); + } else { + return {kOrtxErrorInvalidArgument, "[LogMel]: Invalid key in the JSON configuration."}; + } + } + + n_samples_ = n_sr_ * chunk_size; + mel_filters_ = MelFilterBank(n_fft, n_mel, n_sr_); + return {}; + } + + OrtxStatus Compute(const ortc::Tensor& stft_norm, ortc::Tensor& logmel) { + // Compute the Mel spectrogram by following Python code + /* + magnitudes = stft_norm[:, :, :-1] + mel_spec = self.mel_filters @ magnitudes + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + spec_min = log_spec.max() - 8.0 + log_spec = torch.maximum(log_spec, 
spec_min) + spec_shape = log_spec.shape + padding_spec = torch.ones(spec_shape[0], + spec_shape[1], + self.n_samples // self.hop_length - spec_shape[2], + dtype=torch.float) + padding_spec *= spec_min + log_spec = torch.cat((log_spec, padding_spec), dim=2) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + */ + assert(stft_norm.Shape().size() == 3 && stft_norm.Shape()[0] == 1); + std::vector stft_shape = stft_norm.Shape(); + dlib::matrix magnitudes(stft_norm.Shape()[1], stft_norm.Shape()[2] - 1); + for (int i = 0; i < magnitudes.nr(); ++i) { + std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1, + magnitudes.begin() + i * magnitudes.nc()); + } + + dlib::matrix mel_spec = mel_filters_ * magnitudes; + for (int i = 0; i < mel_spec.nr(); ++i) { + for (int j = 0; j < mel_spec.nc(); ++j) { + mel_spec(i, j) = std::max(1e-10f, mel_spec(i, j)); + } + } + + dlib::matrix log_spec = dlib::log10(mel_spec); + float log_spec_min = dlib::max(log_spec) - 8.0f; + for (int i = 0; i < log_spec.nr(); ++i) { + for (int j = 0; j < log_spec.nc(); ++j) { + float v = std::max(log_spec(i, j), log_spec_min); + v = (v + 4.0f) / 4.0f; + log_spec(i, j) = v; + } + } + + std::vector shape = {mel_filters_.nr(), n_samples_ / hop_length_}; + float* buff = logmel.Allocate(shape); + std::fill(buff, buff + logmel.NumberOfElement(), (log_spec_min + 4.0f) / 4.0f); + for (int i = 0; i < log_spec.nr(); ++i) { + auto row_len = log_spec.nc() * i; + std::copy(log_spec.begin() + i * log_spec.nc(), log_spec.begin() + (i + 1) * log_spec.nc(), buff + i * shape[1]); + } + + return {}; + } + + // Function to compute the Mel filterbank + static dlib::matrix MelFilterBank(int n_fft, int n_mels, int sr = 16000, float min_mel = 0, + float max_mel = 45.245640471924965) { + // Initialize the filterbank matrix + dlib::matrix fbank(n_mels, n_fft / 2 + 1); + memset(fbank.begin(), 0, fbank.size() * sizeof(float)); + + // Compute the frequency bins for the DFT + std::vector freq_bins(n_fft / 2 + 1); + for (int i = 0; i <= n_fft / 2; ++i) { + freq_bins[i] = i * sr / static_cast(n_fft); + } + + // Compute the Mel scale frequencies + std::vector mel(n_mels + 2); + for (int i = 0; i < n_mels + 2; ++i) { + mel[i] = min_mel + i * (max_mel - min_mel) / (n_mels + 1); + } + + // Fill in the linear scale + float f_min = 0.0f; + float f_sp = 200.0f / 3.0f; + std::vector freqs(n_mels + 2); + for (int i = 0; i < n_mels + 2; ++i) { + freqs[i] = f_min + f_sp * mel[i]; + } + + // Nonlinear scale + float min_log_hz = 1000.0f; + float min_log_mel = (min_log_hz - f_min) / f_sp; + float logstep = log(6.4) / 27.0; + + for (int i = 0; i < n_mels + 2; ++i) { + if (mel[i] >= min_log_mel) { + freqs[i] = min_log_hz * exp(logstep * (mel[i] - min_log_mel)); + } + } + + std::vector mel_bins = freqs; + std::vector mel_spacing(n_mels + 1); + for (int i = 0; i < n_mels + 1; ++i) { + mel_spacing[i] = mel_bins[i + 1] - mel_bins[i]; + } + + // Compute the ramps + std::vector> ramps(n_mels + 2, std::vector(n_fft / 2 + 1)); + for (int i = 0; i < n_mels + 2; ++i) { + for (int j = 0; j <= n_fft / 2; ++j) { + ramps[i][j] = mel_bins[i] - freq_bins[j]; + } + } + + for (int i = 0; i < n_mels; ++i) { + for (int j = 0; j <= n_fft / 2; ++j) { + float left = -ramps[i][j] / mel_spacing[i]; + float right = ramps[i + 2][j] / mel_spacing[i + 1]; + fbank(i, j) = std::max(0.0f, std::min(left, right)); + } + } + + // Energy normalization + for (int i = 0; i < n_mels; ++i) { + float energy_norm = 2.0f / (mel_bins[i + 2] - mel_bins[i]); + for (int j = 0; j <= 
n_fft / 2; ++j) { + fbank(i, j) *= energy_norm; + } + } + + return fbank; + } + + private: + int64_t n_samples_ = {}; // sr * chunk_size + int64_t hop_length_{}; + const int64_t n_sr_{16000}; + dlib::matrix mel_filters_; +}; + +} // namespace ort_extensions diff --git a/test/data/whisper/feature_extraction.json b/test/data/whisper/feature_extraction.json new file mode 100644 index 000000000..f7deaac0f --- /dev/null +++ b/test/data/whisper/feature_extraction.json @@ -0,0 +1,437 @@ +{ + "feature_extraction": { + "sequence": [ + { + "operation": { + "name": "audio_decoder", + "type": "AudioDecoder" + } + }, + { + "operation": { + "name": "STFT", + "type": "STFTNorm", + "attrs": { + "n_fft": 400, + "frame_length": 400, + "hop_length": 160, + "_comment": [ + 0.0, + 0.0000616908073425293, + 0.0002467334270477295, + 0.0005550682544708252, + 0.000986635684967041, + 0.0015413463115692139, + 0.0022190213203430176, + 0.0030195116996765137, + 0.003942638635635376, + 0.004988163709640503, + 0.006155818700790405, + 0.007445335388183594, + 0.008856385946273804, + 0.010388582944869995, + 0.012041628360748291, + 0.013815045356750488, + 0.01570841670036316, + 0.01772129535675049, + 0.019853144884109497, + 0.022103488445281982, + 0.02447172999382019, + 0.026957333087921143, + 0.029559612274169922, + 0.03227800130844116, + 0.03511175513267517, + 0.03806024789810181, + 0.0411226749420166, + 0.044298380613327026, + 0.04758647084236145, + 0.05098623037338257, + 0.05449673533439636, + 0.058117181062698364, + 0.06184667348861694, + 0.0656842589378357, + 0.06962898373603821, + 0.07367992401123047, + 0.0778360664844513, + 0.08209633827209473, + 0.08645972609519958, + 0.09092515707015991, + 0.09549149870872498, + 0.10015767812728882, + 0.10492250323295593, + 0.1097848117351532, + 0.11474338173866272, + 0.11979702115058899, + 0.12494447827339172, + 0.13018447160720825, + 0.1355157196521759, + 0.14093685150146484, + 0.1464466154575348, + 0.15204361081123352, + 0.1577264666557312, + 0.16349375247955322, + 0.16934409737586975, + 0.1752760112285614, + 0.18128803372383118, + 0.18737870454788208, + 0.19354650378227234, + 0.1997898817062378, + 0.20610737800598145, + 0.21249738335609436, + 0.21895831823349, + 0.2254886031150818, + 0.23208662867546082, + 0.23875075578689575, + 0.24547931551933289, + 0.2522706985473633, + 0.25912320613861084, + 0.26603513956069946, + 0.27300477027893066, + 0.2800304591655731, + 0.2871103882789612, + 0.29424285888671875, + 0.30142611265182495, + 0.30865830183029175, + 0.31593772768974304, + 0.3232625722885132, + 0.3306310474872589, + 0.3380413055419922, + 0.34549152851104736, + 0.352979838848114, + 0.3605044484138489, + 0.3680635094642639, + 0.37565508484840393, + 0.38327735662460327, + 0.3909284174442291, + 0.39860638976097107, + 0.4063093662261963, + 0.41403549909591675, + 0.42178282141685486, + 0.4295494258403778, + 0.43733343482017517, + 0.44513291120529175, + 0.45294591784477234, + 0.46077051758766174, + 0.46860480308532715, + 0.4764467775821686, + 0.4842946231365204, + 0.492146372795105, + 0.5, + 0.5078536868095398, + 0.515705406665802, + 0.5235532522201538, + 0.5313953161239624, + 0.5392295718193054, + 0.5470541715621948, + 0.5548672080039978, + 0.562666654586792, + 0.5704506635665894, + 0.5782172679901123, + 0.5859646201133728, + 0.5936906933784485, + 0.6013936996459961, + 0.609071671962738, + 0.6167227625846863, + 0.6243450045585632, + 0.6319366097450256, + 0.6394955515861511, + 0.6470202207565308, + 0.6545085310935974, + 0.6619587540626526, + 0.6693689823150635, + 0.6767374277114868, + 
0.6840623021125793, + 0.691341757774353, + 0.6985740065574646, + 0.7057572603225708, + 0.7128896713256836, + 0.719969630241394, + 0.7269952893257141, + 0.7339649796485901, + 0.7408769130706787, + 0.7477294206619263, + 0.7545207738876343, + 0.761249303817749, + 0.7679134607315063, + 0.774511456489563, + 0.7810417413711548, + 0.7875027060508728, + 0.7938927412033081, + 0.800210177898407, + 0.8064535856246948, + 0.8126214146614075, + 0.8187121152877808, + 0.8247240781784058, + 0.8306560516357422, + 0.8365063667297363, + 0.8422735929489136, + 0.8479564785957336, + 0.8535534143447876, + 0.8590631484985352, + 0.8644843101501465, + 0.8698155879974365, + 0.8750555515289307, + 0.8802030086517334, + 0.8852566480636597, + 0.8902152180671692, + 0.8950775265693665, + 0.899842381477356, + 0.9045084714889526, + 0.9090749025344849, + 0.9135403037071228, + 0.9179036617279053, + 0.9221639633178711, + 0.9263200759887695, + 0.9303710460662842, + 0.9343158006668091, + 0.9381533861160278, + 0.941882848739624, + 0.945503294467926, + 0.9490138292312622, + 0.9524135589599609, + 0.9557017087936401, + 0.9588773250579834, + 0.961939811706543, + 0.9648882746696472, + 0.9677220582962036, + 0.9704403877258301, + 0.9730427265167236, + 0.9755282998085022, + 0.9778965711593628, + 0.9801468849182129, + 0.9822787046432495, + 0.9842916131019592, + 0.9861849546432495, + 0.9879584312438965, + 0.9896113872528076, + 0.9911436438560486, + 0.9925546646118164, + 0.9938441514968872, + 0.9950118064880371, + 0.996057391166687, + 0.9969804883003235, + 0.997780978679657, + 0.9984586238861084, + 0.999013364315033, + 0.9994449615478516, + 0.9997532367706299, + 0.9999383091926575, + 1, + 0.9999383091926575, + 0.9997532367706299, + 0.9994449615478516, + 0.999013364315033, + 0.9984586238861084, + 0.997780978679657, + 0.9969804286956787, + 0.9960573315620422, + 0.9950118064880371, + 0.9938441514968872, + 0.9925546646118164, + 0.9911435842514038, + 0.9896113872528076, + 0.9879583716392517, + 0.9861849546432495, + 0.9842915534973145, + 0.9822787046432495, + 0.9801468253135681, + 0.9778964519500732, + 0.9755282402038574, + 0.9730426073074341, + 0.9704403877258301, + 0.9677219390869141, + 0.9648882150650024, + 0.9619396924972534, + 0.9588772654533386, + 0.9557015895843506, + 0.9524134397506714, + 0.9490137100219727, + 0.9455032348632812, + 0.9418827295303345, + 0.9381532669067383, + 0.9343156814575195, + 0.9303709268569946, + 0.9263200759887695, + 0.9221639633178711, + 0.9179036617279053, + 0.913540244102478, + 0.9090747833251953, + 0.9045084714889526, + 0.8998422622680664, + 0.8950774669647217, + 0.8902151584625244, + 0.8852565884590149, + 0.8802029490470886, + 0.8750554919242859, + 0.869815468788147, + 0.8644842505455017, + 0.8590630888938904, + 0.853553295135498, + 0.8479562997817993, + 0.842273473739624, + 0.836506187915802, + 0.8306558728218079, + 0.8247239589691162, + 0.8187118768692017, + 0.8126212358474731, + 0.8064534664154053, + 0.8002099990844727, + 0.793892502784729, + 0.7875025272369385, + 0.7810416221618652, + 0.7745113372802734, + 0.767913281917572, + 0.7612491846084595, + 0.7545205950737, + 0.7477291822433472, + 0.7408767342567444, + 0.7339648008346558, + 0.7269951105117798, + 0.7199694514274597, + 0.7128894925117493, + 0.7057570219039917, + 0.6985738277435303, + 0.6913415789604187, + 0.684062123298645, + 0.6767372488975525, + 0.6693688035011292, + 0.6619585752487183, + 0.6545083522796631, + 0.6470199823379517, + 0.6394953727722168, + 0.6319363117218018, + 0.6243447661399841, + 0.6167224645614624, + 0.6090714335441589, + 
0.601393461227417, + 0.5936904549598694, + 0.5859643220901489, + 0.5782170295715332, + 0.5704504251480103, + 0.5626664161682129, + 0.5548669099807739, + 0.5470539331436157, + 0.5392293334007263, + 0.5313950181007385, + 0.5235530138015747, + 0.5157051682472229, + 0.507853627204895, + 0.5, + 0.4921463429927826, + 0.484294593334198, + 0.4764467477798462, + 0.46860471367836, + 0.4607704281806946, + 0.4529458284378052, + 0.4451328217983246, + 0.437333345413208, + 0.42954933643341064, + 0.4217827320098877, + 0.4140354096889496, + 0.4063093066215515, + 0.3986063003540039, + 0.39092832803726196, + 0.3832772672176361, + 0.37565499544143677, + 0.36806342005729675, + 0.3605043888092041, + 0.35297977924346924, + 0.3454914391040802, + 0.338041216135025, + 0.33063095808029175, + 0.3232625126838684, + 0.3159376382827759, + 0.3086581826210022, + 0.3014259934425354, + 0.2942427396774292, + 0.28711026906967163, + 0.2800303101539612, + 0.2730046510696411, + 0.2660350203514099, + 0.2591230869293213, + 0.25227057933807373, + 0.24547919631004333, + 0.2387506067752838, + 0.23208650946617126, + 0.22548848390579224, + 0.21895819902420044, + 0.2124972641468048, + 0.2061072587966919, + 0.19978976249694824, + 0.1935463547706604, + 0.18737855553627014, + 0.18128788471221924, + 0.17527586221694946, + 0.1693439483642578, + 0.16349363327026367, + 0.15772631764411926, + 0.15204349160194397, + 0.14644649624824524, + 0.1409367322921753, + 0.13551557064056396, + 0.1301843225955963, + 0.12494435906410217, + 0.11979690194129944, + 0.11474326252937317, + 0.10978469252586365, + 0.10492238402366638, + 0.10015755891799927, + 0.09549137949943542, + 0.09092503786087036, + 0.08645960688591003, + 0.08209621906280518, + 0.07783591747283936, + 0.07367980480194092, + 0.06962886452674866, + 0.06568413972854614, + 0.06184655427932739, + 0.0581170916557312, + 0.0544966459274292, + 0.05098611116409302, + 0.04758638143539429, + 0.044298261404037476, + 0.04112258553504944, + 0.038060128688812256, + 0.03511166572570801, + 0.03227788209915161, + 0.02955952286720276, + 0.02695724368095398, + 0.024471670389175415, + 0.02210339903831482, + 0.01985308527946472, + 0.017721205949783325, + 0.015708357095718384, + 0.0138150155544281, + 0.012041598558425903, + 0.010388582944869995, + 0.008856356143951416, + 0.007445335388183594, + 0.006155818700790405, + 0.004988163709640503, + 0.003942638635635376, + 0.0030195116996765137, + 0.0022190213203430176, + 0.0015413165092468262, + 0.000986635684967041, + 0.0005550682544708252, + 0.0002467334270477295, + 0.0000616908073425293 + ] + } + } + }, + { + "operation": { + "name": "log_mel_spectrogram", + "type": "LogMelSpectrum", + "attrs": { + "chunk_size": 30, + "hop_length": 160, + "n_fft": 400, + "n_mel": 80 + } + } + } + ] + } +} \ No newline at end of file diff --git a/test/pp_api_test/c_only_test.h b/test/pp_api_test/c_only_test.h index ed414ace6..1a20b8612 100644 --- a/test/pp_api_test/c_only_test.h +++ b/test/pp_api_test/c_only_test.h @@ -4,7 +4,9 @@ #pragma once #include "ortx_tokenizer.h" +// make sure the C only compiler compatibility only. #include "ortx_processor.h" +#include "ortx_extractor.h" #ifdef __cplusplus diff --git a/test/pp_api_test/test_feature_extraction.cc b/test/pp_api_test/test_feature_extraction.cc new file mode 100644 index 000000000..9c3076daf --- /dev/null +++ b/test/pp_api_test/test_feature_extraction.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
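+// End-to-end check of the speech feature extraction C API: load three audio files, run the extractor
+// defined by data/whisper/feature_extraction.json, and verify the stacked log-mel output shape [3, 80, 3000].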
+ +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "ortx_cpp_helper.h" +#include "shared/api/speech_extractor.h" + +using namespace ort_extensions; + +TEST(ExtractorTest, TestWhisperFeatureExtraction) { + const char* audio_path[] = {"data/jfk.flac", "data/1272-141231-0002.wav", "data/1272-141231-0002.mp3"}; + OrtxObjectPtr raw_audios; + extError_t err = OrtxLoadAudios(ort_extensions::ptr(raw_audios), audio_path, 3); + ASSERT_EQ(err, kOrtxOK); + + OrtxObjectPtr feature_extractor(OrtxCreateSpeechFeatureExtractor, "data/whisper/feature_extraction.json"); + OrtxObjectPtr result; + err = OrtxSpeechLogMel(feature_extractor.get(), raw_audios.get(), ort_extensions::ptr(result)); + ASSERT_EQ(err, kOrtxOK); + + OrtxObjectPtr tensor; + err = OrtxTensorResultGetAt(result.get(), 0, ort_extensions::ptr(tensor)); + ASSERT_EQ(err, kOrtxOK); + + const float* data{}; + const int64_t* shape{}; + size_t num_dims; + err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims); + ASSERT_EQ(err, kOrtxOK); + ASSERT_EQ(num_dims, 3); + ASSERT_EQ(shape[0], 3); + ASSERT_EQ(shape[1], 80); + ASSERT_EQ(shape[2], 3000); +} diff --git a/test/pp_api_test/test_processor.cc b/test/pp_api_test/test_processor.cc index df06e54e8..1fcf132db 100644 --- a/test/pp_api_test/test_processor.cc +++ b/test/pp_api_test/test_processor.cc @@ -7,7 +7,7 @@ #include #include "gtest/gtest.h" -#include "ortx_c_helper.h" +#include "ortx_cpp_helper.h" #include "shared/api/image_processor.h" using namespace ort_extensions; @@ -85,18 +85,18 @@ TEST(ProcessorTest, TestClipImageProcessing) { } ASSERT_EQ(err, kOrtxOK); - OrtxObjectPtr result; + OrtxObjectPtr result; err = OrtxImagePreProcess(processor.get(), raw_images.get(), ort_extensions::ptr(result)); ASSERT_EQ(err, kOrtxOK); - OrtxObjectPtr tensor; - err = OrtxImageGetTensorResult(result.get(), 0, ort_extensions::ptr(tensor)); + OrtxTensor* tensor; + err = OrtxTensorResultGetAt(result.get(), 0, &tensor); ASSERT_EQ(err, kOrtxOK); const float* data{}; const int64_t* shape{}; size_t num_dims; - err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims); + err = OrtxGetTensorDataFloat(tensor, &data, &shape, &num_dims); ASSERT_EQ(err, kOrtxOK); ASSERT_EQ(num_dims, 4); } diff --git a/test/static_test/test_tenor_api.cc b/test/static_test/test_tensor_api.cc similarity index 100% rename from test/static_test/test_tenor_api.cc rename to test/static_test/test_tensor_api.cc diff --git a/test/test_processing.py b/test/test_processing.py index 8552ef722..90940fb8d 100644 --- a/test/test_processing.py +++ b/test/test_processing.py @@ -92,17 +92,17 @@ def test_gpt2_preprocessing(self): merges_file=util.get_test_data_file("data", "gpt2.merges.txt"), ) inputs = tok.forward(test_sentence) - pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx") + pnp.export(tok, test_sentence, opset_version=14, output_path="temp_tok2.onnx") with open("temp_gpt2lmh.onnx", "wb") as f: torch.onnx.export( - gpt2_m, inputs, f, opset_version=12, do_constant_folding=False + gpt2_m, inputs, f, opset_version=14, do_constant_folding=False ) - pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False) + pnp.export(gpt2_m, *inputs, opset_version=14, do_constant_folding=False) full_model = pnp.SequentialProcessingModule(tok, gpt2_m) expected = full_model.forward(test_sentence) model = pnp.export( - full_model, test_sentence, opset_version=12, do_constant_folding=False + full_model, test_sentence, opset_version=14, do_constant_folding=False ) mfunc = 
OrtPyFunction.from_model(model) actuals = mfunc(test_sentence)