diff --git a/CITATION.cff b/CITATION.cff
index 82bcac5a7b750..10b7290022aef 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -3,8 +3,7 @@ title: ONNX Runtime
message: "Please use this information to cite ONNX Runtime in
research or other publications."
authors:
- - affiliation: Microsoft Corporation
- given-names: ONNX Runtime developers
+ - name: ONNX Runtime developers
date-released: 2018-11-29
url: "https://onnxruntime.ai"
repository-code: "https://github.com/microsoft/onnxruntime"
diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake
index a56864ebf4644..8161ea574b8cc 100644
--- a/cmake/adjust_global_compile_flags.cmake
+++ b/cmake/adjust_global_compile_flags.cmake
@@ -92,13 +92,8 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()
-# Enable stream for all the non-minimal build, except for DML. There's currently a bug
-# in the allocation planner when reusing buffers and more than one streams are used that
-# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
-# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
-# safest option for now.
-# https://github.com/microsoft/onnxruntime/issues/19480
-if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
+# Enable streams for all non-minimal builds
+if (NOT onnxruntime_MINIMAL_BUILD)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 85a9bf50460d3..1bb70e9c2ed27 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
"bert/fastertransformer_decoder_attention/*"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
- "bert/fast_gelu_impl.cu"
- "bert/fast_gelu_impl.h"
- "bert/fast_gelu.cc"
- "bert/fast_gelu.h"
"bert/relative_attn_bias.cc"
"bert/relative_attn_bias.h"
"bert/relative_attn_bias_impl.cu"
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
index 715aed7e1d64f..7f3d5d6624b07 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
@@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
private void CanRunInferenceOnAModelWithTensorRT()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
-
+
int deviceId = 0;
string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8ff2135c6b1f6..b0ed68d595c42 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -127,6 +127,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)|
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)|
+|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)|
|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
@@ -606,6 +607,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)|
|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)|
|||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)|
+|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@@ -617,6 +619,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)|
|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**
*out* output:**T**
or
*in* input:**V**
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 31c988f500779..c1cc69edc17d8 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -33,6 +33,8 @@ class Node;
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"
+struct OrtRunOptions;
+
namespace onnxruntime {
/**
@@ -51,6 +53,8 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};
+using RunOptions = OrtRunOptions;
+
enum class DataLayout {
NCHW,
NHWC,
@@ -184,7 +188,7 @@ class IExecutionProvider {
Run may not be finished on device This function should be regarded as the
point after which a new Run would start to submit commands from CPU
*/
- virtual common::Status OnRunStart() { return Status::OK(); }
+ virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }
/**
Called when InferenceSession::Run ended
@@ -192,7 +196,9 @@ class IExecutionProvider {
may not be finished on device This function should be regarded as the point
that all commands of current Run has been submmited by CPU
*/
- virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
+ virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
+ return Status::OK();
+ }
/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
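
Outside the diff proper, here is a minimal sketch of how a derived execution provider would pick up the widened `OnRunStart`/`OnRunEnd` hooks above. The class name `MyExecutionProvider` is hypothetical and the `run_options.h` include path is an assumption:

```cpp
// Sketch only: a hypothetical EP adopting the RunOptions-aware hooks.
#include "core/framework/execution_provider.h"
#include "core/framework/run_options.h"  // assumed location of OrtRunOptions

namespace onnxruntime {

class MyExecutionProvider : public IExecutionProvider {
 public:
  MyExecutionProvider() : IExecutionProvider("MyExecutionProvider") {}

  // Called before InferenceSession::Run submits work for this provider.
  common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) override {
    // Per-Run configuration (e.g. config entries) could be inspected here.
    return Status::OK();
  }

  // Called after all commands of the current Run have been submitted by the CPU.
  common::Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) override {
    if (sync_stream) {
      // Synchronize the device stream associated with this Run, if any.
    }
    return Status::OK();
  }
};

}  // namespace onnxruntime
```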
diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h
index 1fef077860be3..00e7dec5727d1 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_resource.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -19,4 +19,4 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
-};
\ No newline at end of file
+};
diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
index 1f5fcd50e185c..b0a17e175fef3 100644
--- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
@@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor
// Per default it will be set to '0'
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
+
+// Set HTP performance mode for the QNN HTP backend before session run.
+// Options for HTP performance mode: "burst", "balanced", "default", "high_performance",
+// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
+// "sustained_high_performance". Defaults to "default".
+static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
+
+// Set HTP performance mode for the QNN HTP backend after session run.
+static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
+
+// Set RPC control latency for QNN HTP backend
+static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
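
As a hedged usage sketch (not part of this change), the new keys could be set per Run through the public C++ API roughly as follows; the model path, chosen values, and include paths are placeholders/assumptions:

```cpp
// Sketch only: exercise the new QNN run-option keys via the C++ API.
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_run_options_config_keys.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_run_options");
  Ort::SessionOptions session_options;
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

  Ort::RunOptions run_options;
  // Burst while the Run is in flight, then drop to a power-saving mode afterwards.
  run_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, "burst");
  run_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, "low_power_saver");
  run_options.AddConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, "100");

  // run_options would then be passed to session.Run(...) as usual.
  (void)session;
  return 0;
}
```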
diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts
index e8eb0e9babf5a..927953b4f1dd6 100644
--- a/js/node/lib/backend.ts
+++ b/js/node/lib/backend.ts
@@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise {
return new Promise((resolve, reject) => {
- process.nextTick(() => {
+ setImmediate(() => {
try {
resolve(this.#inferenceSession.run(feeds, fetches, options));
} catch (e) {
@@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend {
async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise {
return new Promise((resolve, reject) => {
- process.nextTick(() => {
+ setImmediate(() => {
try {
resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {}));
} catch (e) {
diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock
index 9e20a286c4e27..6f05faf046098 100644
--- a/js/react_native/e2e/yarn.lock
+++ b/js/react_native/e2e/yarn.lock
@@ -3351,9 +3351,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"
ip@^1.1.5:
- version "1.1.8"
- resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
- integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+ version "1.1.9"
+ resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+ integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
is-accessor-descriptor@^0.1.6:
version "0.1.6"
diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock
index 4dca90d7415cf..bbb0c4f3d1e22 100644
--- a/js/react_native/yarn.lock
+++ b/js/react_native/yarn.lock
@@ -3701,9 +3701,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"
ip@^1.1.5:
- version "1.1.8"
- resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
- integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+ version "1.1.9"
+ resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+ integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
is-absolute@^1.0.0:
version "1.0.0"
diff --git a/js/web/README.md b/js/web/README.md
index c75a40ad6da28..906c78a1b7ec4 100644
--- a/js/web/README.md
+++ b/js/web/README.md
@@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f
With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience.
-ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.
+ONNX Runtime Web can run on both CPU and GPU. On the CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into a WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We keep improving op coverage and optimizing performance in the WebGL backend.
See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports.
@@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun
## Documents
-### Developement
+### Development
Refer to the following links for development information:
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index b5b6a2a15cd8c..11c8778b72335 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
-import {biasSnippet, typeSnippet} from './activation_util';
+import {biasSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
const conv2dTransposeCommonSnippet =
- (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
- const type = typeSnippet(innerElementSize, 'f32');
+ (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
+ innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
@@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))];
let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))];
let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))];
- return vec4(v0, v1, v2, v3);
+ return ${type}(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
@@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
declareFunctions += `
- fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} {
+ fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
@@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
{name: 'pads', type: 'i32', length: pads.length}
];
appendActivationUniforms(attributes, uniforms);
+ const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
+ if (elemType !== 'f16' && elemType !== 'f32') {
+ throw new Error(`elemType ${elemType} is not supported.`);
+ }
return `
${utilFunctions('uniforms.result_strides')}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
- ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
+ ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
${
isVec4 ? makeMatMulPackedVec4Source(
- elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
+ elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
- elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
+ elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
undefined, sequentialAccessByThreads)}`;
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
index cfee07a9239d7..a6375847fc42f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
@@ -27,7 +27,7 @@ const createWhereOpProgramShader =
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
- const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
+ const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
return `
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
@@ -38,6 +38,7 @@ const createWhereOpProgramShader =
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
+ let component_c${x} = offset_c${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
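
For clarity (an illustration, not code from the PR), the corrected indexing reads the condition byte using the condition tensor's own component offset rather than the output lane; in scalar form the unpacking looks like this:

```cpp
// Four byte-sized booleans are packed into one 32-bit word, as in the WebGPU
// Where shader above. The component index must come from the condition
// tensor's offset (offset_c % 4), not from the output lane index.
#include <cassert>
#include <cstdint>

static bool UnpackBool(uint32_t word, uint32_t component) {
  return (word & (0xffu << (component * 8))) != 0;
}

int main() {
  // Word packs {true, false, true, false} in components 0..3.
  const uint32_t word = 0x00010001u;
  assert(UnpackBool(word, 0));
  assert(!UnpackBool(word, 1));
  assert(UnpackBool(word, 2));
  assert(!UnpackBool(word, 3));
  return 0;
}
```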
diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc
index 047fd6fd7511b..990120dd3708e 100644
--- a/js/web/test/data/ops/where.jsonc
+++ b/js/web/test/data/ops/where.jsonc
@@ -168,5 +168,39 @@
]
}
]
+ },
+ {
+ "name": "Where with no attributes",
+ "operator": "Where",
+ "attributes": [],
+ "cases": [
+ {
+ "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
+ "inputs": [
+ {
+ "data": [true, false],
+ "dims": [1, 1, 2, 1],
+ "type": "bool"
+ },
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 4],
+ "type": "float32"
+ },
+ {
+ "data": [5, 6, 7, 8, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 4, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
}
]
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index ecc7d4b4a09a5..a4adf5c4ce144 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -627,8 +627,8 @@ export async function runModelTestSet(
try {
const feeds: Record = {};
const outputsMetaInfo: Record = {};
- testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
- testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
+ testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
+ testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
const [start, end, outputs] =
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
if (context.perfData.count === 0) {
diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc
index 556699192d2eb..3e0533dd8b9e5 100644
--- a/onnxruntime/contrib_ops/cpu/activations.cc
+++ b/onnxruntime/contrib_ops/cpu/activations.cc
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/cpu/activation/activations.h"
-#include "activations.h"
+#include "contrib_ops/cpu/activations.h"
namespace onnxruntime {
namespace contrib {
@@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()),
ThresholdedRelu);
-ONNX_OPERATOR_KERNEL_EX(
- Gelu,
- kMSDomain,
- 1,
- kCpuExecutionProvider,
- KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
- Gelu);
-
ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h
index aed4c2229215d..7e64235d3fc3d 100644
--- a/onnxruntime/contrib_ops/cpu/activations.h
+++ b/onnxruntime/contrib_ops/cpu/activations.h
@@ -54,47 +54,6 @@ namespace contrib {
DEFINE_ELE_KERNEL(ScaledTanh);
DEFINE_ELE_KERNEL(ParametricSoftplus);
-template
-class Gelu : public OpKernel {
- public:
- Gelu(const OpKernelInfo& info) : OpKernel(info) {
- }
-
- Status Compute(OpKernelContext* context) const override {
- const Tensor* input = context->Input(0);
- const T* input_data = input->Data();
-
- Tensor* output = context->Output(0, input->Shape());
- T* output_data = output->MutableData();
-
- concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
- int64_t elem_count = input->Shape().Size();
- constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
- int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
- concurrency::ThreadPool::TryBatchParallelFor(
- tp, static_cast(task_count),
- [&](ptrdiff_t task_idx) {
- const auto start = task_idx * length_per_task;
- const T* p_input = input_data + start;
- T* p_output = output_data + start;
- int64_t count = std::min(length_per_task, elem_count - start);
-
- for (int64_t i = 0; i < count; i++) {
- T value = p_input[i];
- p_output[i] = value * static_cast(M_SQRT1_2);
- }
-
- MlasComputeErf(p_output, p_output, narrow(count));
-
- for (int64_t i = 0; i < count; i++) {
- p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
- }
- },
- 0);
- return Status::OK();
- }
-};
-
// Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
// MlasComputeLogistic instead of using Eigen for better perf.
template
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc
index 1a86c5dbece5a..6303858b9bd48 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.cc
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -49,7 +49,6 @@ namespace cuda {
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
-UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h
index ab339f276c2bd..fc9a71b0b7fa1 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.h
@@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise {
float beta_;
};
-template
-class Gelu final : public UnaryElementwise {
- public:
- Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
-
- Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
- MAKE_FUNC_CTX_NULL()
-};
-
template
class QuickGelu final : public UnaryElementwise {
public:
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
index 0c856815fd437..36f33fbb24c18 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
@@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh {
}
};
-template
-struct OP_Gelu : public CtxGelu {
- __device__ __inline__ T operator()(const T& a) const {
- return _Gelu(a);
- }
-};
-
-template <>
-struct OP_Gelu : public CtxGelu {
- __device__ __inline__ half operator()(const half& a) const {
- return static_cast(_Gelu(static_cast(a)));
- }
-};
-
template
struct OP_QuickGelu : public CtxQuickGelu {
__device__ __inline__ T operator()(const T& a) const {
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
index 5d18283a395e3..782d4bf59a5ad 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
@@ -11,14 +11,12 @@ namespace cuda {
typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
-typedef onnxruntime::cuda::CtxNull CtxGelu;
typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
- UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
index 892f5c181a607..e8974a29476b6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -4,9 +4,14 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "fast_gelu.h"
-#include "fast_gelu_impl.h"
+#include "core/providers/cuda/tensor/gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "transformer_common.h"
+#ifdef USE_ROCM
+#include "contrib_ops/rocm/bert/elementwise.h"
+#endif
+#ifdef USE_CUDA
+#include "contrib_ops/cuda/bert/transformer_common.h"
+#endif
namespace onnxruntime {
namespace contrib {
@@ -31,8 +36,10 @@ using namespace ONNX_NAMESPACE;
template
FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
+#ifdef USE_CUDA
const TransformerOptions* options = TransformerOptions::GetInstance();
use_half2_ = !options->DisableHalf2();
+#endif
}
template
@@ -50,6 +57,14 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType::MappedType CudaT;
+#ifdef USE_ROCM
+ return LaunchElementwiseKernel(
+ GetTuningContext(), context->GetComputeStream(),
+ reinterpret_cast(input->Data()), static_cast(input_length),
+ (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length),
+ reinterpret_cast(output->MutableData()));
+#endif
+#ifdef USE_CUDA
return LaunchFastGeluKernel(GetDeviceProp(),
Stream(context),
static_cast(input_length),
@@ -58,6 +73,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const {
(nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr,
reinterpret_cast(output->MutableData()),
use_half2_);
+#endif
}
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
index 3e642a70afef5..d563556593e6e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
@@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel {
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
- bool use_half2_;
+  bool use_half2_;  // Only applicable to the CUDA kernel (not ROCm).
};
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index be8c0dc86c135..57e951d3a68ff 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -203,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
#endif
+#ifdef ENABLE_CUDA_NHWC_OPS
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
+#endif
+
template <>
KernelCreateInfo BuildKernelCreateInfo() {
KernelCreateInfo info;
@@ -408,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
#endif
+#ifdef ENABLE_CUDA_NHWC_OPS
+ BuildKernelCreateInfo,
+#endif
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc
index 4c2999c279e0a..2500de39d3536 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.cc
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc
@@ -9,22 +9,23 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define REGISTER_KERNEL_TYPED(T) \
+#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
- kMSDomain, \
- 1, \
+ DOMAIN, \
+ VERSION, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- GridSample);
+ onnxruntime::contrib::cuda::GridSample);
-REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
-template
-GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
+template
+GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros");
align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0));
@@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
}
}
-template
-Status GridSample::ComputeInternal(OpKernelContext* context) const {
+template
+Status GridSample::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input(0);
const auto& dims_input = X->Shape().GetDims();
const Tensor* Grid = context->Input(1);
@@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
+ using Ch = Channels;
+
TensorShapeVector dims_output(4);
- dims_output[0] = dims_input[0];
- dims_output[1] = dims_input[1];
- dims_output[2] = dims_grid[1];
- dims_output[3] = dims_grid[2];
+ dims_output[Ch::N] = dims_input[Ch::N];
+ dims_output[Ch::C] = dims_input[Ch::C];
+ dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
+ dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
@@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType::MappedType CudaT;
CudaT* Y_data = reinterpret_cast(Y->MutableData());
- GridSampleImpl(
+ GridSampleImpl(
Stream(context),
reinterpret_cast(X->Data()),
reinterpret_cast(Grid->Data()),
@@ -89,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
}
} // namespace cuda
} // namespace contrib
+
+namespace cuda {
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
+} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h
index 08ca58c7cc458..16581bfe77482 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.h
@@ -12,7 +12,7 @@ namespace cuda {
using namespace onnxruntime::cuda;
-template
+template
class GridSample final : public CudaKernel {
public:
explicit GridSample(const OpKernelInfo& info);
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
index 8a391eca7e86a..b23da635bc83d 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
@@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) {
return static_cast(fx);
}
-template
+template
__device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
- int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
+ int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
T pixel = 0.0f;
+
+ auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t {
+ return Layout == LAYOUT_NCHW
+ ? (bIdx * C * H * W + cIdx * H * W + y * W + x)
+ : (bIdx * H * W * C + y * W * C + x * C + cIdx);
+ };
+
if (padding_mode == 0) { // zeros
if (x >= 0 && x < W && y >= 0 && y < H) {
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ pixel = input_data[PixelOffset(x, y)];
}
- } else if (padding_mode == 1) { //border
+ } else if (padding_mode == 1) { // border
x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ pixel = input_data[PixelOffset(x, y)];
} else { // Reflection
- x = (int64_t) GsReflect(x, border[0], border[2]);
- y = (int64_t) GsReflect(y, border[1], border[3]);
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ x = (int64_t)GsReflect(x, border[0], border[2]);
+ y = (int64_t)GsReflect(y, border[1], border[3]);
+ pixel = input_data[PixelOffset(x, y)];
}
return pixel;
}
-__device__ void GsGetCubicCoeffs(float x, float coeffs[4])
-{
+__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) {
float cubic_alpha = -0.75f;
x = abs(x);
coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha);
@@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) {
return pixel;
}
-template
+template
__global__ void _GridSampleKernel(
const T* input_data,
const T* grid_data,
@@ -110,16 +116,32 @@ __global__ void _GridSampleKernel(
{
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
// extract batch index, channel index, y index, x index for current thread
- int BIdx = idx / (C * H_out * W_out );
- int tmpBCnt = BIdx * (C * H_out * W_out);
+ int BIdx, yIdx, xIdx, cIdx;
+ if constexpr (Layout == LAYOUT_NCHW) {
+ BIdx = idx / (C * H_out * W_out);
+ int tmpBCnt = BIdx * (C * H_out * W_out);
+
+ cIdx = (idx - tmpBCnt) / (H_out * W_out);
+ int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
- int cIdx = (idx - tmpBCnt) / (H_out * W_out);
- int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
+ yIdx = (idx - tmpCCnt) / W_out;
+ int tmpHCnt = tmpCCnt + yIdx * W_out;
- int yIdx = (idx - tmpCCnt) / W_out;
- int tmpHCnt = tmpCCnt + yIdx * W_out;
+ xIdx = (idx - tmpHCnt);
+ } else {
+ static_assert(Layout == LAYOUT_NHWC, "Unsupported layout");
- int xIdx = (idx - tmpHCnt);
+ BIdx = idx / (H_out * W_out * C);
+ int tmpBCnt = BIdx * (H_out * W_out * C);
+
+ yIdx = (idx - tmpBCnt) / (W_out * C);
+ int tmpHCnt = tmpBCnt + yIdx * (W_out * C);
+
+ xIdx = (idx - tmpHCnt) / C;
+ int tmpWCnt = tmpHCnt + xIdx * C;
+
+ cIdx = (idx - tmpWCnt);
+ }
int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
T grid_X = grid_data[grid_idx * 2 + 0];
@@ -147,8 +169,9 @@ __global__ void _GridSampleKernel(
if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
if (padding_mode == 1) { // border
- grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
- grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
+ // Clamping must not be done here, see #10607
+ // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
+ // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
} else if (padding_mode == 2) { // reflection
grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
@@ -175,10 +198,10 @@ __global__ void _GridSampleKernel(
w_lb = w_b * w_l;
w_rb = w_b * w_r;
- T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
- T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
- T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
- T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
+ T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
+ T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
+ T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
+ T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
output_data[outIdx] = interpoV;
return;
@@ -186,7 +209,8 @@ __global__ void _GridSampleKernel(
if (mode == 1) { // nearest
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
- output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
+ output_data[outIdx] =
+ PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
return;
}
if (mode == 2) { // bicubic
@@ -195,7 +219,8 @@ __global__ void _GridSampleKernel(
T p[4][4] = {}; // [H][W]
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
- p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
+ p[h][w] =
+ PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
}
}
T dx = grid_x_imgSpace - x0 - 1;
@@ -204,7 +229,7 @@ __global__ void _GridSampleKernel(
}
}
-template
+template
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
@@ -216,17 +241,23 @@ void GridSampleImpl(
const int64_t H_out,
const int64_t W_out,
T* output_data) {
- int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock));
- _GridSampleKernel<<>>(
- input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data);
+ using Ch = Channels;
+
+ int blocksPerGrid = static_cast(
+ ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));
+ _GridSampleKernel<<>>(
+ input_data, grid_data, mode, padding_mode, align_corners,
+ dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W],
+ H_out, W_out, output_data);
}
-#define SPECIALIZED_IMPL(T) \
- template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \
- const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
- const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
+#define SPECIALIZED_IMPL(T, IsNHWC) \
+ template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \
+ const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
+ const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
-SPECIALIZED_IMPL(float)
+SPECIALIZED_IMPL(float, false) // NCHW
+SPECIALIZED_IMPL(float, true) // NHWC
} // namespace cuda
} // namespace contrib
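
As an aside (not part of the kernel), the two offset formulas mirrored by the layout-templated PixelOffset lambda can be sanity-checked on the host; the shape values below are arbitrary:

```cpp
// Host-side check that the NCHW and NHWC offset formulas each address every
// linear slot exactly once for a given shape.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t OffsetNCHW(int64_t n, int64_t c, int64_t y, int64_t x,
                          int64_t C, int64_t H, int64_t W) {
  return n * C * H * W + c * H * W + y * W + x;
}

static int64_t OffsetNHWC(int64_t n, int64_t c, int64_t y, int64_t x,
                          int64_t C, int64_t H, int64_t W) {
  return n * H * W * C + y * W * C + x * C + c;
}

int main() {
  const int64_t N = 2, C = 3, H = 4, W = 5;
  std::vector<int> hits_nchw(N * C * H * W, 0), hits_nhwc(N * C * H * W, 0);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t y = 0; y < H; ++y)
        for (int64_t x = 0; x < W; ++x) {
          ++hits_nchw[OffsetNCHW(n, c, y, x, C, H, W)];
          ++hits_nhwc[OffsetNHWC(n, c, y, x, C, H, W)];
        }
  for (int64_t i = 0; i < N * C * H * W; ++i) {
    assert(hits_nchw[i] == 1 && hits_nhwc[i] == 1);
  }
  return 0;
}
```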
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
index 6df86ce161908..62cd66a48fa84 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
@@ -8,7 +8,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-template
+template
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
deleted file mode 100644
index 9cb414e4e8980..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/fast_gelu.h"
-
-#include "core/providers/rocm/rocm_common.h"
-#include "core/providers/rocm/miopen_common.h"
-#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "contrib_ops/rocm/bert/elementwise.h"
-#include "contrib_ops/rocm/bert/transformer_common.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-#define REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
- FastGelu, \
- kMSDomain, \
- 1, \
- T, \
- kRocmExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
- FastGelu);
-
-REGISTER_KERNEL_TYPED(float)
-REGISTER_KERNEL_TYPED(MLFloat16)
-REGISTER_KERNEL_TYPED(BFloat16)
-
-using namespace ONNX_NAMESPACE;
-
-template
-Status FastGelu::ComputeInternal(OpKernelContext* context) const {
- ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
-
- const Tensor* input = context->Input(0);
- const Tensor* bias = context->Input(1);
- Tensor* output = context->Output(0, input->Shape());
-
- int64_t input_length = input->Shape().Size();
- if (input_length == 0) {
- return Status::OK();
- }
- int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
- typedef typename ToHipType::MappedType HipT;
-
- const HipT* input_buffer = reinterpret_cast(input->Data());
- const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr;
- return LaunchElementwiseKernel(
- GetTuningContext(), context->GetComputeStream(),
- input_buffer, static_cast(input_length),
- bias_buffer, static_cast(bias_length),
- reinterpret_cast(output->MutableData()));
-}
-
-} // namespace rocm
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
deleted file mode 100644
index 42bfe5a0b0246..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/common/common.h"
-#include "core/providers/rocm/rocm_kernel.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-using namespace onnxruntime::rocm;
-
-template
-class FastGelu final : public RocmKernel {
- public:
- FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {}
- Status ComputeInternal(OpKernelContext* ctx) const override;
-};
-
-} // namespace rocm
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
index b3d3e92209b39..c6ca16bfdfc80 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status {
- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
- "Input skip or bias is not supported by triton kernel.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -61,23 +59,36 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
}
// Construct args for launch kernel
struct {
- void* X;
- void* Y;
+ const void* src;
+ const void* skip;
+ const void* bias;
+ void* out;
+ void* add_out;
const void* gamma;
const void* beta;
int hw;
int c;
int c_per_group;
float eps;
+ bool has_skip;
+ bool has_bias;
+ bool broadcast_skip;
} args = {
- (void*)params->src,
+ (const void*)params->src,
+ (const void*)params->skip,
+ (const void*)params->bias,
(void*)params->dst,
+ (void*)params->skip_workspace,
(const void*)params->gamma,
(const void*)params->beta,
params->hw,
params->c,
params->channels_per_group,
- params->epsilon};
+ params->epsilon,
+ params->skip != nullptr,
+ params->bias != nullptr,
+ params->broadcast_skip,
+ };
// Grid dim is (batch_count, groups, 1)
return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
index 5368cb1cf635b..5ba96ebc117f0 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -12,13 +12,19 @@
@triton.jit
def group_norm_kernel(
input_ptr,
+ skip_ptr,
+ bias_ptr,
output_ptr,
+ add_out_ptr,
gamma_ptr,
beta_ptr,
img_size,
c,
c_per_group,
eps,
+ has_skip,
+ has_bias,
+ broadcast_skip,
BLOCK_SIZE: tl.constexpr,
HW_SIZE: tl.constexpr,
ACTIVATION_SILU: tl.constexpr,
@@ -36,14 +42,35 @@ def group_norm_kernel(
offsets = hw[:, None] * c + cols[None, :]
mask = (cols < c_per_group)[None, :]
+ bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ if has_skip:
+ add_out_ptr += row_x * stride + row_y * c_per_group
+ if broadcast_skip:
+ broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
+ bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+ else:
+ skip_ptr += row_x * stride + row_y * c_per_group
+ if has_bias:
+ bias_ptr += row_y * c_per_group
+ bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+
# Calculate mean and variance
_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
_square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ if has_skip and not broadcast_skip:
+ s_ptr = skip_ptr + i * HW_SIZE * c
+ s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ a += s
+ if has_bias or broadcast_skip:
+ a += bias
_sum += a
_square_sum += a * a
+ if has_skip:
+ add_y_ptr = add_out_ptr + i * HW_SIZE * c
+ tl.store(add_y_ptr + offsets, a, mask=mask)
# Set axis=None (or leave it unspecified) to reduce all axes.
# TODO: In older Triton we have to reduce an axis at a time, but in our case
@@ -57,9 +84,13 @@ def group_norm_kernel(
gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)
beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
- x_ptr = input_ptr + i * HW_SIZE * c
y_ptr = output_ptr + i * HW_SIZE * c
- x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ if has_skip:
+ add_y_ptr = add_out_ptr + i * HW_SIZE * c
+ x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ else:
+ x_ptr = input_ptr + i * HW_SIZE * c
+ x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - group_mean) * rstd
y = x_hat * gamma + beta
if ACTIVATION_SILU:
@@ -77,7 +108,7 @@ def group_norm_kernel(
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8, 16]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
-sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32"
+sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
group_pattern = "GroupNormTriton_{}_{}"
@@ -88,7 +119,7 @@ def get_function_table():
silu_suffix = "Silu" if silu else "Pass"
name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
group = group_pattern.format(silu_suffix, dtype)
- sig = sig_pattern.format(dtype, dtype)
+ sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
kwargs = {
"num_warps": warp,
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index ea7a6432a7507..158ab8ed610f4 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -182,7 +182,6 @@ class PlannerImpl {
// upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node
// upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream
InlinedHashMap> dependence_graph_;
- InlinedHashMap> value_consumer_map_;
InlinedHashMap value_node_map_;
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@@ -295,7 +294,7 @@ class PlannerImpl {
}
#endif
- // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
+  // Find if there exists some input tensor that we can use in-place for the output_arg_num-th output of the node.
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
bool* is_strided_tensor) {
*is_strided_tensor = false;
@@ -530,6 +529,7 @@ class PlannerImpl {
// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
+ for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i;
}
bool HasExternalOutputs(const Node& node) const {
@@ -1065,7 +1065,8 @@ class PlannerImpl {
// build the consumer list for each value
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
- value_consumer_map_.reserve(num_ml_values);
+ InlinedHashMap> value_consumer_map;
+ value_consumer_map.reserve(num_ml_values);
// iterate each stream from back, so the first element is the last consumer in single stream case
for (auto& stream : stream_nodes_) {
@@ -1078,10 +1079,10 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
- auto origin = Buffer(value_idx);
- if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
+ auto origin = AllocPlan(value_idx).reused_buffer;
+ if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
- value_consumer_map_[origin].insert(node_index);
+ value_consumer_map[origin].insert(node_index);
}
}
return Status::OK();
@@ -1138,8 +1139,8 @@ class PlannerImpl {
std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
- value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
- value_consumer_map_[output_idx_global].end());
+ value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+ value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
found_reusable = true;
break;
@@ -1168,8 +1169,8 @@ class PlannerImpl {
allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
- value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
- value_consumer_map_[output_idx_global].end());
+ value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+ value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
@@ -1187,11 +1188,11 @@ class PlannerImpl {
OrtValueIndex input_arg_index{};
if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
- if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
+ if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = input_arg_index;
- value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
- value_consumer_map_[output_idx_global].end());
+ value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
+ value_consumer_map[output_idx_global].end());
reused.insert(input_arg_index);
}
}
@@ -1266,7 +1267,7 @@ class PlannerImpl {
}
bool all_covered = true;
- for (auto consumer : value_consumer_map_[output_idx_global]) {
+ for (auto consumer : value_consumer_map[output_idx_global]) {
if (deps->find(consumer) == deps->end()) {
all_covered = false;
break;
@@ -1277,9 +1278,9 @@ class PlannerImpl {
allocation_plan[downstream_value].reused_buffer = output_idx_global;
get_reused = true;
// add new consumer for the value to be reused
- value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
- value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
- value_consumer_map_[downstream_value].end());
+ value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
+ value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
+ value_consumer_map[downstream_value].end());
node_iter = size_iter->second.erase(node_iter);
if (size_iter->second.empty()) {
local_iter->second.erase(size_iter);
@@ -1342,8 +1343,9 @@ class PlannerImpl {
ort_value_usecount.reserve(ort_value_info_.size());
#endif
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
- // compute use count first
+      // compute use count first. TODO(leca): it is enough to call ComputeReuseCount() only once!
ORT_RETURN_IF_ERROR(ComputeReuseCount());
+ for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
if (i == 0) {
for (auto ort_value_info : ort_value_info_) {
@@ -1693,8 +1695,8 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
- auto origin = Buffer(value_idx);
- if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
+ auto origin = AllocPlan(value_idx).reused_buffer;
+ if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumers[origin].push_back(node_index);
}
@@ -1889,7 +1891,7 @@ class PlannerImpl {
// 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op.
// for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream.
// in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching
- OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type();
+ OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type();
WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device);
if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) {
if (node_to_notification.find(node_index) == node_to_notification.end()) {
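For context on the planner change above: when an output reuses another value's buffer, the reused value has to stay alive until every consumer of the aliasing output has run, which is why the consumer sets are merged. Below is a minimal, self-contained sketch of that bookkeeping with simplified types; it is illustrative only and not the planner's real interface.

```cpp
#include <cstddef>
#include <unordered_map>
#include <unordered_set>

using OrtValueIndex = int;
using NodeIndex = std::size_t;
using ValueConsumerMap = std::unordered_map<OrtValueIndex, std::unordered_set<NodeIndex>>;

// When `reusing_output` aliases the buffer of `reused_value`, the nodes that consume the
// output effectively consume the reused buffer too, so their indices are folded into the
// reused value's consumer set before lifetimes are computed.
void MergeConsumersOnReuse(ValueConsumerMap& value_consumer_map,
                           OrtValueIndex reused_value,
                           OrtValueIndex reusing_output) {
  const auto& output_consumers = value_consumer_map[reusing_output];
  value_consumer_map[reused_value].insert(output_consumers.begin(), output_consumers.end());
}
```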
diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc
index 875e7f395bfa8..dd7f4d35b34bd 100644
--- a/onnxruntime/core/framework/stream_execution_context.cc
+++ b/onnxruntime/core/framework/stream_execution_context.cc
@@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
}
#ifdef USE_CANN
+ // Leave it to CANN EP to fill the gap if they want to use run_options
+ static onnxruntime::RunOptions run_options;
// For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool,
// which is different from CUDA Runtime API, but similar to CUDA Driver API.
auto& execution_providers = ctx.GetSessionState().GetExecutionProviders();
for (auto& xp : execution_providers) {
- auto status = xp->OnRunStart();
+ auto status = xp->OnRunStart(run_options);
if (!status.IsOK()) {
ctx.SetStatus(status);
return;
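The run_options plumbing above changes the IExecutionProvider hooks that several providers override later in this patch (OnRunStart gains a RunOptions parameter, and OnRunEnd becomes OnRunEnd(bool, RunOptions)). A rough sketch of the new overrides for a provider that does not yet consume run_options is shown below; MyExecutionProvider is a hypothetical name and the body is illustrative only.

```cpp
// Sketch only: assumes the usual onnxruntime provider headers are available.
class MyExecutionProvider : public IExecutionProvider {
 public:
  // Previously OnRunStart(); the per-call RunOptions is now passed through.
  Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) override {
    // Per-run setup (e.g. binding the device for the calling thread) goes here.
    return Status::OK();
  }

  // Previously OnRunEnd(bool); providers that ignore run_options only need the signature update.
  Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) override {
    return Status::OK();
  }
};
```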
diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
index 4505d4afdf1e0..a8717b99a8750 100644
--- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
+++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
@@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
}
#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
+// TODO(mtavenrath) generate list from registered kernels using nhwc domain
const std::unordered_set& GetCUDALayoutSensitiveOps() {
static std::unordered_set cuda_nhwc_ops = []() {
return std::unordered_set{
@@ -41,6 +42,7 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() {
"MaxPool",
"GlobalAveragePool",
"AveragePool",
+ "GridSample",
};
}();
return cuda_nhwc_ops;
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc
index 752b742805a7c..9a242919665bb 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() {
}
// All threads share the same context and stream
-Status CANNExecutionProvider::OnRunStart() {
+Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id));
return Status::OK();
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h
index 63ae980869c65..d83bd88d6958f 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.h
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.h
@@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider {
explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info);
virtual ~CANNExecutionProvider();
- Status OnRunStart() override;
+ Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
template <typename T>
Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 813fdc54ecd0d..48e4617b33b4d 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1035,6 +1035,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
@@ -2562,6 +2563,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc
new file mode 100644
index 0000000000000..d55973eda180f
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc
@@ -0,0 +1,108 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/common.h"
+#include "core/common/narrow.h"
+#include "core/framework/op_kernel.h"
+#include "core/util/math_cpuonly.h"
+#include "core/mlas/inc/mlas.h"
+
+#include "core/platform/threadpool.h"
+#include
+#include "core/providers/cpu/element_wise_ranged_transform.h"
+#include "core/providers/cpu/tensor/gelu.h"
+
+using onnxruntime::narrow;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+
+// May revisit the implementations to support inplace computation, if needed.
+
+ONNX_CPU_OPERATOR_KERNEL(
+ Gelu,
+ 20,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ Gelu<float>);
+
+#ifndef DISABLE_CONTRIB_OPS
+namespace contrib {
+ONNX_OPERATOR_KERNEL_EX(
+ Gelu,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ Gelu<float>);
+}
+#endif
+
+template <typename T>
+Status Gelu<T>::Compute(OpKernelContext* context) const {
+ const Tensor* input = context->Input<Tensor>(0);
+ const T* input_data = input->Data<T>();
+
+ Tensor* output = context->Output(0, input->Shape());
+ T* output_data = output->MutableData<T>();
+
+ concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+ int64_t elem_count = input->Shape().Size();
+ constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
+ int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+
+ if (approximation_algorithm_ == "tanh") {
+ // FastGelu allows an optional bias. Here we split the input data into chunks. Each chunk
+ // has N elements (except the last chunk), and the thread pool processes the chunks in parallel.
+ // N = 4096 is selected based on performance test results on input shape 1x128x768.
+ // FastGelu uses approximation for Gelu. The formula is 0.5 * (1 + Tanh(x * (C * x * x + B))) * x.
+ static constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI)
+ static constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0 / M_PI)
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ tp, static_cast<int32_t>(task_count),
+ [&](ptrdiff_t task_idx) {
+ const auto start = task_idx * length_per_task;
+ const T* p_input = input_data + start;
+ T* p_output = output_data + start;
+ int64_t count = std::min(length_per_task, elem_count - start);
+
+ for (int64_t i = 0; i < count; i++) {
+ T value = p_input[i];
+ p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
+ }
+
+ MlasComputeTanh(p_output, p_output, narrow<size_t>(count));
+
+ for (int64_t i = 0; i < count; i++) {
+ p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+ }
+ },
+ 0);
+ return Status::OK();
+ } else if (approximation_algorithm_ == "none") {
+ concurrency::ThreadPool::TryBatchParallelFor(
+ tp, static_cast<int32_t>(task_count),
+ [&](ptrdiff_t task_idx) {
+ const auto start = task_idx * length_per_task;
+ const T* p_input = input_data + start;
+ T* p_output = output_data + start;
+ int64_t count = std::min(length_per_task, elem_count - start);
+
+ for (int64_t i = 0; i < count; i++) {
+ T value = p_input[i];
+ p_output[i] = value * static_cast<T>(M_SQRT1_2);
+ }
+
+ MlasComputeErf(p_output, p_output, narrow<size_t>(count));
+
+ for (int64_t i = 0; i < count; i++) {
+ p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+ }
+ },
+ 0);
+ return Status::OK();
+ }
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
+}
+
+} // namespace onnxruntime
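As a sanity check on the two code paths in the kernel above, the following standalone snippet compares exact Gelu (the erf form used when approximate="none") against the tanh approximation (approximate="tanh"); the constants B and C mirror the ones in the kernel. This is illustrative only and not part of the patch.

```cpp
#include <cmath>
#include <cstdio>

// Exact Gelu: 0.5 * x * (1 + erf(x / sqrt(2)))
float GeluExact(float x) {
  constexpr float kInvSqrt2 = 0.70710678118654752440f;  // 1 / sqrt(2)
  return 0.5f * x * (1.0f + std::erf(x * kInvSqrt2));
}

// Tanh approximation: 0.5 * x * (1 + tanh(x * (C * x * x + B)))
float GeluTanh(float x) {
  constexpr float B = 0.7978845608028654f;    // sqrt(2 / pi)
  constexpr float C = 0.035677408136300125f;  // 0.044715 * sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(x * (C * x * x + B)));
}

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
    std::printf("x=%5.2f  exact=%9.6f  tanh=%9.6f\n", x, GeluExact(x), GeluTanh(x));
  }
  return 0;
}
```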
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h
new file mode 100644
index 0000000000000..13238028d878a
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace onnxruntime {
+
+template <typename T>
+class Gelu final : public OpKernel {
+ public:
+ explicit Gelu(const OpKernelInfo& info) : OpKernel(info) {
+ approximation_algorithm_ = info.GetAttrOrDefault<std::string>("approximate", "none");
+ }
+ Status Compute(OpKernelContext* ctx) const override;
+
+ private:
+ std::string approximation_algorithm_;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 48a952e6dd98f..00783bcbc2665 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -386,7 +386,7 @@ Status CUDAExecutionProvider::Sync() const {
return Status::OK();
}
-Status CUDAExecutionProvider::OnRunStart() {
+Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
// always set CUDA device when session::Run() in case it runs in a worker thread
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
@@ -396,7 +396,7 @@ Status CUDAExecutionProvider::OnRunStart() {
return Status::OK();
}
-Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
+Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
@@ -1256,6 +1256,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1329,6 +1330,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape);
#endif
+// Opset 20
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
+
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
return {};
@@ -2143,6 +2149,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,
// Opset 17
BuildKernelCreateInfo,
@@ -2222,6 +2229,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+
+ // Opset 20
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
index 55f0b5570e0ee..5f62f313b86a2 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
Status Sync() const override;
- Status OnRunStart() override;
+ Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
- Status OnRunEnd(bool sync_stream) override;
+ Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
DataLayout GetPreferredLayout() const override;
@@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg);
~PerThreadContext();
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);
cublasHandle_t CublasHandle() const {
return cublas_handle_;
diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h
index fdd14dedad47e..2cbeb13696270 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.h
+++ b/onnxruntime/core/providers/cuda/cudnn_common.h
@@ -24,12 +24,12 @@ class CudnnTensor final {
operator cudnnTensorDescriptor_t() const { return tensor_; }
+ Status CreateTensorIfNeeded();
+
template <typename T>
static cudnnDataType_t GetDataType();
private:
- Status CreateTensorIfNeeded();
-
cudnnTensorDescriptor_t tensor_;
};
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index 99c1f48e21c74..b61b104790fe5 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -9,40 +9,49 @@ namespace onnxruntime {
namespace cuda {
template <typename T>
-void CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnn_desc,
- const int pseudo_layer,
- const cudnnTensorDescriptor_t x_desc,
- const cudnnFilterDescriptor_t w_desc,
- const cudnnFilterDescriptor_t filter_desc,
- const void* reorganized_w_data,
- const int lin_layer_id,
- const T* pos,
- int& offset,
- bool is_matrix,
- cudaStream_t cuda_stream) const {
+Status CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnn_desc,
+ const int pseudo_layer,
+ size_t reorganized_w_data_size,
+ const void* reorganized_w_data,
+ const int lin_layer_id,
+ const T* pos,
+ int& offset,
+ bool is_matrix,
+ cudaStream_t cuda_stream) const {
int numDims;
- std::vector<int> matDims(3);
+ std::array<int, 3> matDims;
+ std::array<int, 3> strideA;
cudnnDataType_t dt;
- cudnnTensorFormat_t tf;
T* mem_offset;
- if (is_matrix) {
- cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
- } else {
- cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
- }
+ CudnnTensor tensor_desc_matrix, tensor_desc_bias;
+ ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded());
+ ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded());
- cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data());
+ T *mem_offset_matrix, *mem_offset_bias;
+ CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams(
+ handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data,
+ lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias));
+ CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor(
+ is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data()));
+
+ mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias;
int count = matDims[0] * matDims[1] * matDims[2];
+
+ if (strideA[0] != count) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed");
+ }
CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream));
+
offset += count;
+
+ return Status::OK();
}
template <typename T>
Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
const cudnnRNNDescriptor_t rnn_desc,
- const cudnnTensorDescriptor_t x_desc,
- const cudnnFilterDescriptor_t w_desc,
+ size_t reorganized_w_data_size,
void* reorganized_w_data,
const T* W_data,
const T* R_data,
@@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
int w_offset = 0;
int r_offset = 0;
int bias_offset = 0;
- CudnnFilterDescriptor filter_desc;
for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) {
for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(
+ cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream));
if (B_data != nullptr) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream));
}
}
for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream));
if (B_data != nullptr) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream));
}
}
}
@@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
template <typename T>
Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
+ size_t& reorganized_w_data_size_in_bytes,
IAllocatorUniquePtr<void>& reorganized_w_data,
CudnnFilterDescriptor& target_w_desc,
CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const {
@@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
TensorShapeVector dims_w({w_size, 1, 1});
ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType()));
- TensorShapeVector fake_dims_x({1, input_size, 1});
- CudnnTensor fake_x_desc;
- ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType()));
-
// Prepare the weight data
- reorganized_w_data = GetScratchBuffer<void>(w_size * sizeof(T), ort_stream);
+ reorganized_w_data_size_in_bytes = w_size * sizeof(T);
+ reorganized_w_data = GetScratchBuffer<void>(reorganized_w_data_size_in_bytes, ort_stream);
// In many cases, this allocation is bigger than needed, leaving part of
- // the buffer unintialized. non-zero garbage data leads to wrong result
+ // the buffer uninitialized. non-zero garbage data leads to wrong result
// in call to cudnnRNNForwardInference()
// TODO! refine allocation size for each case.
cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
- cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream);
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream));
const T* W_data = W->Data();
const T* R_data = R->Data();
@@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
auto* ort_cuda_stream = dynamic_cast<CudaStream*>(ort_stream);
cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle();
- ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc,
- reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream));
+ ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc,
+ reorganized_w_data_size_in_bytes, reorganized_w_data.get(),
+ W_data, R_data, B_data, cuda_stream));
return Status::OK();
}
@@ -128,22 +140,31 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R);
bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B);
+ bool has_bias = B != nullptr;
+
if (get_W && get_R) {
CudnnRNN tmp_rnn_desc;
- ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(),
+ auto proj_size = hidden_size_;
+ ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size
hidden_size_,
+ proj_size,
RNN_NUM_LAYERS,
cudnn_dropout_desc_,
cudnn_direction_mode_,
rnn_mode_,
- CudnnTensor::GetDataType(),
- GetDeviceProp()));
+ has_bias,
+ CudnnTensor::GetDataType()));
if (get_B) {
- ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr));
+ ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B,
+ w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
+ tmp_rnn_desc, nullptr));
} else {
- ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr));
+ ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr,
+ w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
+ tmp_rnn_desc, nullptr));
}
cudaStreamSynchronize(nullptr);
+
weight_cached_ = true;
}
@@ -158,17 +179,72 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
ORT_ENFORCE(nullptr != X);
// optional inputs
- const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens); // [batch_size]
- const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_]
+ // [batch_size]
+ const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens);
+ // initial hidden. [num_directions_, batch_size, hidden_size_]
+ const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h);
const Tensor* initial_c(nullptr);
if (rnn_mode_ == CUDNN_LSTM) {
- initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_]
+ // initial cell. [num_directions_, batch_size, hidden_size_]
+ initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c);
}
+ size_t proj_size = hidden_size_;
int64_t seq_length = X->Shape()[0];
int64_t batch_size = X->Shape()[1];
int64_t input_size = X->Shape()[2];
+ // when sequence_lens is not provided, treat every sequence as full length and expand to [batch_size]
+ std::vector<int32_t> sequence_lengths_temp;
+ if (!sequence_lens) {
+ sequence_lengths_temp.resize(batch_size, gsl::narrow_cast<int32_t>(seq_length));
+ }
+
+ const int32_t* sequence_lens_data = (sequence_lens == nullptr)
+ ? sequence_lengths_temp.data()
+ : sequence_lens->Data<int32_t>();
+
+ // cuDNN doesn't support zero-length sequences inside the batch; find them and set their length to 1.
+ // A ZeroMask kernel later resets the result to 0 for those sequences.
+ int64_t zero_seq_count = 0;
+ std::vector<int32_t> zero_seq_index_cache(batch_size, 0);
+
+ CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, batch_size);
+ int32_t* seq_len_array = sequence_lens_buffer.CpuPtr();
+
+ // 0-len sequences are not supported by cuDNN.
+ // Replace them by sequences of len 1 and mask them out with SetZeroSequences
+ for (int i = 0; i < batch_size; ++i) {
+ if (0 == sequence_lens_data[i]) {
+ seq_len_array[i] = 1;
+ zero_seq_index_cache[zero_seq_count] = i;
+ ++zero_seq_count;
+ } else {
+ seq_len_array[i] = sequence_lens_data[i];
+ }
+ }
+
+ // Calculate the zero position cache for reverse direction if it's bidirectional
+ // The cache is used for Y_h or Y_c, and for the first time step of Y; the other steps of Y don't
+ // need it since the zero-length sequences were bumped to length 1
+ if (zero_seq_count && num_directions_ > 1) {
+ zero_seq_index_cache.resize(zero_seq_count * num_directions_);
+ for (int64_t i = 0; i < zero_seq_count; ++i) {
+ zero_seq_index_cache[static_cast<size_t>(zero_seq_count) + i] =
+ static_cast<int32_t>(batch_size + zero_seq_index_cache[i]);
+ }
+ zero_seq_count *= num_directions_;
+ }
+
+ // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus it must
+ // be copied to the GPU always.
+ ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
+ // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must
+ // be copied to the GPU only for the ReverseBySequence kernels.
+ // if (reverse_) {
+ // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
+ // }
+
// optional outputs
TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_});
TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_});
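To make the zero-length-sequence handling in the hunk above concrete, here is a small standalone rendition of the preprocessing: zero entries are bumped to 1 so cuDNN accepts them, their batch positions are recorded, and for a bidirectional run the recorded positions are duplicated with a batch_size offset so the reverse-direction outputs can be zeroed as well. Names are illustrative, not the kernel's real interface.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct ZeroSeqInfo {
  std::vector<int32_t> seq_len_array;         // lengths handed to cuDNN (no zeros)
  std::vector<int32_t> zero_seq_index_cache;  // output positions to zero afterwards
};

ZeroSeqInfo PrepareSequenceLengths(const std::vector<int32_t>& sequence_lens, int64_t num_directions) {
  const int64_t batch_size = static_cast<int64_t>(sequence_lens.size());
  ZeroSeqInfo info;
  info.seq_len_array.resize(batch_size);
  std::vector<int32_t> zero_index(batch_size, 0);
  int64_t zero_seq_count = 0;
  for (int64_t i = 0; i < batch_size; ++i) {
    if (sequence_lens[i] == 0) {
      info.seq_len_array[i] = 1;  // cuDNN rejects zero-length sequences
      zero_index[zero_seq_count++] = static_cast<int32_t>(i);
    } else {
      info.seq_len_array[i] = sequence_lens[i];
    }
  }
  zero_index.resize(zero_seq_count);
  if (zero_seq_count > 0 && num_directions > 1) {
    // Mirror the positions for the reverse direction, offset by batch_size.
    zero_index.resize(zero_seq_count * num_directions);
    for (int64_t i = 0; i < zero_seq_count; ++i) {
      zero_index[zero_seq_count + i] = static_cast<int32_t>(batch_size) + zero_index[i];
    }
  }
  info.zero_seq_index_cache = std::move(zero_index);
  return info;
}

int main() {
  auto info = PrepareSequenceLengths({3, 0, 2, 0}, /*num_directions=*/2);
  // seq_len_array -> 3 1 2 1; zero_seq_index_cache -> 1 3 5 7
  for (int32_t v : info.seq_len_array) std::printf("%d ", v);
  std::printf("| ");
  for (int32_t v : info.zero_seq_index_cache) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}
```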
@@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy);
Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc);
- std::vector<int64_t> dims_x({batch_size, input_size, 1});
- std::vector<int64_t> dims_y({batch_size, hidden_size_ * num_directions_, 1});
-
- CudnnTensor x_desc_temp;
- ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType()));
- CudnnTensor y_desc_temp;
- ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType()));
- std::vector x_desc(seq_length, x_desc_temp);
- std::vector y_desc(seq_length, y_desc_temp);
-
- CudnnTensor hx_desc;
- CudnnTensor cx_desc;
- CudnnTensor y_h_desc;
- CudnnTensor y_c_desc;
- ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType()));
- ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType()));
- ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType()));
- ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType()));
-
IAllocatorUniquePtr<T> x_reversed_data;
const T* x_data = X->Data<T>();
if (reverse_) {
@@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
x_reversed_data = GetScratchBuffer<T>(seq_length * batch_size * input_size, ctx->GetComputeStream());
ReverseBySequence(Stream(ctx),
gsl::narrow_cast<int32_t>(seq_length),
+ sequence_lens_buffer.GpuPtr(),
gsl::narrow_cast<int32_t>(batch_size),
gsl::narrow_cast<int32_t>(input_size),
reinterpret_cast<const CudaT*>(x_data),
@@ -226,115 +284,82 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
y_data = y_alloc_data.get();
}
- const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data<int32_t>();
+ const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
+ bool has_bias = B != nullptr;
CudnnRNN rnn_desc;
- ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx),
+ ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size,
hidden_size_,
+ proj_size,
RNN_NUM_LAYERS,
cudnn_dropout_desc_,
cudnn_direction_mode_,
rnn_mode_,
- CudnnTensor::GetDataType(),
- GetDeviceProp()));
+ has_bias,
+ CudnnTensor::GetDataType()));
// Prepare the weight data
+ size_t w_data_size_in_bytes = 0;
IAllocatorUniquePtr<void> w_data;
CudnnFilterDescriptor w_desc;
if (!weight_cached_) {
const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
- ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream()));
+ ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc,
+ rnn_desc, ctx->GetComputeStream()));
}
- // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
- CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));
+ CudnnDataTensor x_desc1;
+ ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size,
+ input_size, seq_len_array));
+ CudnnDataTensor y_desc1;
+ ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size,
+ ((rnn_mode_ == CUDNN_LSTM) ? proj_size : hidden_size_) * num_directions_,
+ seq_len_array));
- size_t workspace_bytes;
- CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes));
- auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream());
- int64_t zero_seq_count = 0;
- std::vector